Skip to content

Commit 58c5991

Browse files
committed
iter
1 parent 4280682 commit 58c5991

File tree

3 files changed

+74
-36
lines changed

3 files changed

+74
-36
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java

Lines changed: 59 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
1313
import org.apache.lucene.index.FieldInfo;
1414
import org.apache.lucene.index.FloatVectorValues;
15+
import org.apache.lucene.index.MergeState;
1516
import org.apache.lucene.index.SegmentWriteState;
1617
import org.apache.lucene.internal.hppc.IntArrayList;
1718
import org.apache.lucene.store.IndexOutput;
@@ -47,7 +48,7 @@ public DefaultIVFVectorsWriter(SegmentWriteState state, FlatVectorsWriter rawVec
4748
}
4849

4950
@Override
50-
protected long[] buildAndWritePostingsLists(
51+
long[] buildAndWritePostingsLists(
5152
FieldInfo fieldInfo,
5253
CentroidSupplier centroidSupplier,
5354
FloatVectorValues floatVectorValues,
@@ -203,25 +204,24 @@ static void writeCentroids(float[][] centroids, FieldInfo fieldInfo, float[] glo
203204
float[] centroidScratch = new float[fieldInfo.getVectorDimension()];
204205
// TODO do we want to store these distances as well for future use?
205206
// TODO: this sorting operation tanks recall for some reason, works fine for small numbers of vectors like in single segment
206-
// need to investigate this further
207-
// float[] distances = new float[centroids.length];
208-
// for (int i = 0; i < centroids.length; i++) {
209-
// distances[i] = VectorUtil.squareDistance(centroids[i], globalCentroid);
210-
// }
211-
// // sort the centroids by distance to globalCentroid, nearest (smallest distance), to furthest
212-
// // (largest)
213-
// for (int i = 0; i < centroids.length; i++) {
214-
// for (int j = i + 1; j < centroids.length; j++) {
215-
// if (distances[i] > distances[j]) {
216-
// float[] tmp = centroids[i];
217-
// centroids[i] = centroids[j];
218-
// centroids[j] = tmp;
219-
// float tmpDistance = distances[i];
220-
// distances[i] = distances[j];
221-
// distances[j] = tmpDistance;
222-
// }
223-
// }
224-
// }
207+
// need to investigate this further
208+
float[] distances = new float[centroids.length];
209+
for (int i = 0; i < centroids.length; i++) {
210+
distances[i] = VectorUtil.squareDistance(centroids[i], globalCentroid);
211+
}
212+
// sort the centroids by distance to globalCentroid, nearest (smallest distance), to furthest (largest)
213+
for (int i = 0; i < centroids.length; i++) {
214+
for (int j = i + 1; j < centroids.length; j++) {
215+
if (distances[i] > distances[j]) {
216+
float[] tmp = centroids[i];
217+
centroids[i] = centroids[j];
218+
centroids[j] = tmp;
219+
float tmpDistance = distances[i];
220+
distances[i] = distances[j];
221+
distances[j] = tmpDistance;
222+
}
223+
}
224+
}
225225
for (float[] centroid : centroids) {
226226
System.arraycopy(centroid, 0, centroidScratch, 0, centroid.length);
227227
OptimizedScalarQuantizer.QuantizationResult result = osq.scalarQuantize(
@@ -239,6 +239,41 @@ static void writeCentroids(float[][] centroids, FieldInfo fieldInfo, float[] glo
239239
}
240240
}
241241

242+
CentroidAssignments calculateAndWriteCentroids(
243+
FieldInfo fieldInfo,
244+
FloatVectorValues floatVectorValues,
245+
IndexOutput centroidOutput,
246+
MergeState mergeState,
247+
float[] globalCentroid
248+
) throws IOException {
249+
// TODO: take advantage of prior generated clusters from mergeState in the future
250+
return calculateAndWriteCentroids(
251+
fieldInfo,
252+
floatVectorValues,
253+
centroidOutput,
254+
mergeState.infoStream,
255+
globalCentroid,
256+
false
257+
);
258+
}
259+
260+
CentroidAssignments calculateAndWriteCentroids(
261+
FieldInfo fieldInfo,
262+
FloatVectorValues floatVectorValues,
263+
IndexOutput centroidOutput,
264+
InfoStream infoStream,
265+
float[] globalCentroid
266+
) throws IOException {
267+
return calculateAndWriteCentroids(
268+
fieldInfo,
269+
floatVectorValues,
270+
centroidOutput,
271+
infoStream,
272+
globalCentroid,
273+
true
274+
);
275+
}
276+
242277
/**
243278
* Calculate the centroids for the given field and write them to the given centroid output.
244279
* We use the {@link HierarchicalKMeans} algorithm to partition the space of all vectors across merging segments
@@ -252,8 +287,7 @@ static void writeCentroids(float[][] centroids, FieldInfo fieldInfo, float[] glo
252287
* @return the vector assignments, soar assignments, and if asked the centroids themselves that were computed
253288
* @throws IOException if an I/O error occurs
254289
*/
255-
@Override
256-
protected CentroidAssignments calculateAndWriteCentroids(
290+
CentroidAssignments calculateAndWriteCentroids(
257291
FieldInfo fieldInfo,
258292
FloatVectorValues floatVectorValues,
259293
IndexOutput centroidOutput,
@@ -270,10 +304,9 @@ protected CentroidAssignments calculateAndWriteCentroids(
270304
short[] assignments = kMeansResult.assignments();
271305
short[] soarAssignments = kMeansResult.soarAssignments();
272306

273-
// TODO: previously for flush we were doing this over the vectors not the centroids,
274-
// right off this produces good recall but need to do further evaluation
275-
// VectorUtil.calculateCentroid(fieldWriter.delegate().getVectors(), globalCentroid);
276-
// TODO: push this logic into vector util
307+
// TODO: for flush we are doing this over the vectors and here centroids which seems duplicative
308+
// preliminary tests suggest recall is good using only centroids but need to do further evaluation
309+
// TODO: push this logic into vector util?
277310
for (float[] centroid : centroids) {
278311
for (int j = 0; j < centroid.length; j++) {
279312
globalCentroid[j] += centroid[j];

server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -121,16 +121,23 @@ public final KnnFieldVectorsWriter<?> addField(FieldInfo fieldInfo) throws IOExc
121121
return rawVectorDelegate;
122122
}
123123

124-
protected abstract CentroidAssignments calculateAndWriteCentroids(
124+
abstract CentroidAssignments calculateAndWriteCentroids(
125+
FieldInfo fieldInfo,
126+
FloatVectorValues floatVectorValues,
127+
IndexOutput centroidOutput,
128+
MergeState mergeState,
129+
float[] globalCentroid
130+
) throws IOException;
131+
132+
abstract CentroidAssignments calculateAndWriteCentroids(
125133
FieldInfo fieldInfo,
126134
FloatVectorValues floatVectorValues,
127135
IndexOutput centroidOutput,
128136
InfoStream infoStream,
129-
float[] globalCentroid,
130-
boolean cacheCentroids
137+
float[] globalCentroid
131138
) throws IOException;
132139

133-
protected abstract long[] buildAndWritePostingsLists(
140+
abstract long[] buildAndWritePostingsLists(
134141
FieldInfo fieldInfo,
135142
CentroidSupplier centroidSupplier,
136143
FloatVectorValues floatVectorValues,
@@ -154,8 +161,7 @@ public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
154161
floatVectorValues,
155162
ivfCentroids,
156163
segmentWriteState.infoStream,
157-
globalCentroid,
158-
true
164+
globalCentroid
159165
);
160166

161167
CentroidSupplier centroidSupplier = new OnHeapCentroidSupplier(centroidAssignments.cachedCentroids());
@@ -267,9 +273,8 @@ public final void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) thro
267273
fieldInfo,
268274
floatVectorValues,
269275
centroidTemp,
270-
mergeState.infoStream,
271-
calculatedGlobalCentroid,
272-
false
276+
mergeState,
277+
calculatedGlobalCentroid
273278
);
274279
numCentroids = centroidAssignments.numCentroids();
275280

server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ KMeansResult kMeansHierarchical(final FloatVectorValuesSlice vectors, final int
8383
short[] assignments = new short[vectors.size()];
8484

8585
KMeans kmeans = new KMeans(m, maxIterations);
86-
float[][] centroids = kmeans.pickInitialCentroids(vectors, m, k);
86+
float[][] centroids = KMeans.pickInitialCentroids(vectors, m, k);
8787
KMeansResult kMeansResult = new KMeansResult(centroids);
8888
kmeans.cluster(vectors, kMeansResult);
8989

0 commit comments

Comments
 (0)