1212import org .apache .lucene .codecs .hnsw .FlatVectorsWriter ;
1313import org .apache .lucene .index .FieldInfo ;
1414import org .apache .lucene .index .FloatVectorValues ;
15+ import org .apache .lucene .index .MergeState ;
1516import org .apache .lucene .index .SegmentWriteState ;
1617import org .apache .lucene .internal .hppc .IntArrayList ;
1718import org .apache .lucene .store .IndexOutput ;
@@ -47,7 +48,7 @@ public DefaultIVFVectorsWriter(SegmentWriteState state, FlatVectorsWriter rawVec
4748 }
4849
4950 @ Override
50- protected long [] buildAndWritePostingsLists (
51+ long [] buildAndWritePostingsLists (
5152 FieldInfo fieldInfo ,
5253 CentroidSupplier centroidSupplier ,
5354 FloatVectorValues floatVectorValues ,
@@ -203,25 +204,24 @@ static void writeCentroids(float[][] centroids, FieldInfo fieldInfo, float[] glo
203204 float [] centroidScratch = new float [fieldInfo .getVectorDimension ()];
204205 // TODO do we want to store these distances as well for future use?
205206 // TODO: this sorting operation tanks recall for some reason, works fine for small numbers of vectors like in single segment
206- // need to investigate this further
207- // float[] distances = new float[centroids.length];
208- // for (int i = 0; i < centroids.length; i++) {
209- // distances[i] = VectorUtil.squareDistance(centroids[i], globalCentroid);
210- // }
211- // // sort the centroids by distance to globalCentroid, nearest (smallest distance), to furthest
212- // // (largest)
213- // for (int i = 0; i < centroids.length; i++) {
214- // for (int j = i + 1; j < centroids.length; j++) {
215- // if (distances[i] > distances[j]) {
216- // float[] tmp = centroids[i];
217- // centroids[i] = centroids[j];
218- // centroids[j] = tmp;
219- // float tmpDistance = distances[i];
220- // distances[i] = distances[j];
221- // distances[j] = tmpDistance;
222- // }
223- // }
224- // }
207+ // need to investigate this further
208+ float [] distances = new float [centroids .length ];
209+ for (int i = 0 ; i < centroids .length ; i ++) {
210+ distances [i ] = VectorUtil .squareDistance (centroids [i ], globalCentroid );
211+ }
212+ // sort the centroids by distance to globalCentroid, nearest (smallest distance), to furthest (largest)
213+ for (int i = 0 ; i < centroids .length ; i ++) {
214+ for (int j = i + 1 ; j < centroids .length ; j ++) {
215+ if (distances [i ] > distances [j ]) {
216+ float [] tmp = centroids [i ];
217+ centroids [i ] = centroids [j ];
218+ centroids [j ] = tmp ;
219+ float tmpDistance = distances [i ];
220+ distances [i ] = distances [j ];
221+ distances [j ] = tmpDistance ;
222+ }
223+ }
224+ }
225225 for (float [] centroid : centroids ) {
226226 System .arraycopy (centroid , 0 , centroidScratch , 0 , centroid .length );
227227 OptimizedScalarQuantizer .QuantizationResult result = osq .scalarQuantize (
@@ -239,6 +239,41 @@ static void writeCentroids(float[][] centroids, FieldInfo fieldInfo, float[] glo
239239 }
240240 }
241241
242+ CentroidAssignments calculateAndWriteCentroids (
243+ FieldInfo fieldInfo ,
244+ FloatVectorValues floatVectorValues ,
245+ IndexOutput centroidOutput ,
246+ MergeState mergeState ,
247+ float [] globalCentroid
248+ ) throws IOException {
249+ // TODO: take advantage of prior generated clusters from mergeState in the future
250+ return calculateAndWriteCentroids (
251+ fieldInfo ,
252+ floatVectorValues ,
253+ centroidOutput ,
254+ mergeState .infoStream ,
255+ globalCentroid ,
256+ false
257+ );
258+ }
259+
260+ CentroidAssignments calculateAndWriteCentroids (
261+ FieldInfo fieldInfo ,
262+ FloatVectorValues floatVectorValues ,
263+ IndexOutput centroidOutput ,
264+ InfoStream infoStream ,
265+ float [] globalCentroid
266+ ) throws IOException {
267+ return calculateAndWriteCentroids (
268+ fieldInfo ,
269+ floatVectorValues ,
270+ centroidOutput ,
271+ infoStream ,
272+ globalCentroid ,
273+ true
274+ );
275+ }
276+
242277 /**
243278 * Calculate the centroids for the given field and write them to the given centroid output.
244279 * We use the {@link HierarchicalKMeans} algorithm to partition the space of all vectors across merging segments
@@ -252,8 +287,7 @@ static void writeCentroids(float[][] centroids, FieldInfo fieldInfo, float[] glo
252287 * @return the vector assignments, soar assignments, and if asked the centroids themselves that were computed
253288 * @throws IOException if an I/O error occurs
254289 */
255- @ Override
256- protected CentroidAssignments calculateAndWriteCentroids (
290+ CentroidAssignments calculateAndWriteCentroids (
257291 FieldInfo fieldInfo ,
258292 FloatVectorValues floatVectorValues ,
259293 IndexOutput centroidOutput ,
@@ -270,10 +304,9 @@ protected CentroidAssignments calculateAndWriteCentroids(
270304 short [] assignments = kMeansResult .assignments ();
271305 short [] soarAssignments = kMeansResult .soarAssignments ();
272306
273- // TODO: previously for flush we were doing this over the vectors not the centroids,
274- // right off this produces good recall but need to do further evaluation
275- // VectorUtil.calculateCentroid(fieldWriter.delegate().getVectors(), globalCentroid);
276- // TODO: push this logic into vector util
307+ // TODO: for flush we are doing this over the vectors and here centroids which seems duplicative
308+ // preliminary tests suggest recall is good using only centroids but need to do further evaluation
309+ // TODO: push this logic into vector util?
277310 for (float [] centroid : centroids ) {
278311 for (int j = 0 ; j < centroid .length ; j ++) {
279312 globalCentroid [j ] += centroid [j ];
0 commit comments