1111
1212import org .apache .lucene .index .FloatVectorValues ;
1313import org .apache .lucene .util .VectorUtil ;
14+ import org .apache .lucene .util .hnsw .IntToIntFunction ;
15+ import org .elasticsearch .index .codec .vectors .SampleReader ;
1416import org .elasticsearch .simdvec .ESVectorUtil ;
1517
1618import java .io .IOException ;
@@ -74,12 +76,12 @@ static float[][] pickInitialCentroids(FloatVectorValues vectors, int centroidCou
7476 return centroids ;
7577 }
7678
77- private boolean stepLloyd (
79+ private static boolean stepLloyd (
7880 FloatVectorValues vectors ,
81+ IntToIntFunction translateOrd ,
7982 float [][] centroids ,
8083 float [][] nextCentroids ,
8184 int [] assignments ,
82- int sampleSize ,
8385 List <int []> neighborhoods
8486 ) throws IOException {
8587 boolean changed = false ;
@@ -90,17 +92,18 @@ private boolean stepLloyd(
9092 Arrays .fill (nextCentroid , 0.0f );
9193 }
9294
93- for (int i = 0 ; i < sampleSize ; i ++) {
94- float [] vector = vectors .vectorValue (i );
95- final int assignment = assignments [i ];
95+ for (int idx = 0 ; idx < vectors .size (); idx ++) {
96+ float [] vector = vectors .vectorValue (idx );
97+ int vectorOrd = translateOrd .apply (idx );
98+ final int assignment = assignments [vectorOrd ];
9699 final int bestCentroidOffset ;
97100 if (neighborhoods != null ) {
98101 bestCentroidOffset = getBestCentroidFromNeighbours (centroids , vector , assignment , neighborhoods .get (assignment ));
99102 } else {
100103 bestCentroidOffset = getBestCentroid (centroids , vector );
101104 }
102105 if (assignment != bestCentroidOffset ) {
103- assignments [i ] = bestCentroidOffset ;
106+ assignments [vectorOrd ] = bestCentroidOffset ;
104107 changed = true ;
105108 }
106109 centroidCounts [bestCentroidOffset ]++;
@@ -121,7 +124,7 @@ private boolean stepLloyd(
121124 return changed ;
122125 }
123126
124- int getBestCentroidFromNeighbours (float [][] centroids , float [] vector , int centroidIdx , int [] centroidOffsets ) {
127+ private static int getBestCentroidFromNeighbours (float [][] centroids , float [] vector , int centroidIdx , int [] centroidOffsets ) {
125128 int bestCentroidOffset = centroidIdx ;
126129 assert centroidIdx >= 0 && centroidIdx < centroids .length ;
127130 float minDsq = VectorUtil .squareDistance (vector , centroids [centroidIdx ]);
@@ -135,7 +138,7 @@ int getBestCentroidFromNeighbours(float[][] centroids, float[] vector, int centr
135138 return bestCentroidOffset ;
136139 }
137140
138- int getBestCentroid (float [][] centroids , float [] vector ) {
141+ private static int getBestCentroid (float [][] centroids , float [] vector ) {
139142 int bestCentroidOffset = 0 ;
140143 float minDsq = Float .MAX_VALUE ;
141144 for (int i = 0 ; i < centroids .length ; i ++) {
@@ -281,24 +284,34 @@ void cluster(FloatVectorValues vectors, KMeansIntermediate kMeansIntermediate, b
281284 }
282285 }
283286
284- void cluster (FloatVectorValues vectors , KMeansIntermediate kMeansIntermediate , List <int []> neighborhoods ) throws IOException {
287+ private void cluster (FloatVectorValues vectors , KMeansIntermediate kMeansIntermediate , List <int []> neighborhoods ) throws IOException {
285288 float [][] centroids = kMeansIntermediate .centroids ();
286289 int k = centroids .length ;
287290 int n = vectors .size ();
288291
289292 if (k == 1 || k >= n ) {
290293 return ;
291294 }
292-
295+ IntToIntFunction translateOrd = i -> i ;
296+ FloatVectorValues sampledVectors = vectors ;
297+ if (sampleSize < n ) {
298+ sampledVectors = SampleReader .createSampleReader (vectors , sampleSize , 42L );
299+ translateOrd = sampledVectors ::ordToDoc ;
300+ }
293301 int [] assignments = kMeansIntermediate .assignments ();
294302 assert assignments .length == n ;
295303 float [][] nextCentroids = new float [centroids .length ][vectors .dimension ()];
296304 for (int i = 0 ; i < maxIterations ; i ++) {
297- if (stepLloyd (vectors , centroids , nextCentroids , assignments , sampleSize , neighborhoods ) == false ) {
305+ // The vectors passed here may be a sample, so their ordinals must be translated before indexing assignments
306+ if (stepLloyd (sampledVectors , translateOrd , centroids , nextCentroids , assignments , neighborhoods ) == false ) {
298307 break ;
299308 }
300309 }
301- stepLloyd (vectors , centroids , nextCentroids , assignments , vectors .size (), neighborhoods );
310+ // If we sampled, make one final pass over the full set of vectors to finalize the centroids
311+ if (sampleSize < n ) {
312+ // No ordinal translation is needed here because we are using the full set of vectors
313+ stepLloyd (vectors , i -> i , centroids , nextCentroids , assignments , neighborhoods );
314+ }
302315 }
303316
304317 /**
0 commit comments