Skip to content

Commit 3e88c51

Browse files
committed
iter
1 parent 8d25046 commit 3e88c51

File tree

4 files changed

+96
-85
lines changed

4 files changed

+96
-85
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java

Lines changed: 69 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,42 @@ public DefaultIVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVect
4646
super(state, rawVectorsReader);
4747
}
4848

49+
private abstract static class BaseCentroidQueryScorer implements CentroidQueryScorer {
50+
51+
// TODO can we do this in off-heap blocks?
52+
float int4QuantizedScore(
53+
float qcDist,
54+
OptimizedScalarQuantizer.QuantizationResult queryCorrections,
55+
int dims,
56+
float[] targetCorrections,
57+
int targetComponentSum,
58+
float centroidDp,
59+
VectorSimilarityFunction similarityFunction
60+
) {
61+
float ax = targetCorrections[0];
62+
// Here we assume `lx` is simply a bit vector, so the scaling isn't necessary
63+
float lx = (targetCorrections[1] - ax) * FOUR_BIT_SCALE;
64+
float ay = queryCorrections.lowerInterval();
65+
float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE;
66+
float y1 = queryCorrections.quantizedComponentSum();
67+
float score = ax * ay * dims + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * qcDist;
68+
if (similarityFunction == EUCLIDEAN) {
69+
score = queryCorrections.additionalCorrection() + targetCorrections[2] - 2 * score;
70+
return Math.max(1 / (1f + score), 0);
71+
} else {
72+
// For cosine and max inner product, we need to apply the additional correction, which is
73+
// assumed to be the non-centered dot-product between the vector and the centroid
74+
score += queryCorrections.additionalCorrection() + targetCorrections[2] - centroidDp;
75+
if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
76+
return VectorUtil.scaleMaxInnerProductScore(score);
77+
}
78+
return Math.max((1f + score) / 2f, 0);
79+
}
80+
}
81+
}
82+
83+
private abstract static class ParentCentroidQueryScorer extends BaseCentroidQueryScorer implements CentroidWChildrenQueryScorer {}
84+
4985
@Override
5086
CentroidQueryScorer getCentroidScorer(
5187
FieldInfo fieldInfo,
@@ -65,7 +101,7 @@ CentroidQueryScorer getCentroidScorer(
65101
fieldEntry.globalCentroid()
66102
);
67103
final ES91Int4VectorsScorer scorer = ESVectorUtil.getES91Int4VectorsScorer(centroids, fieldInfo.getVectorDimension());
68-
return new CentroidQueryScorer() {
104+
return new BaseCentroidQueryScorer() {
69105
int currentCentroid = -1;
70106
private final float[] centroid = new float[fieldInfo.getVectorDimension()];
71107
private final float[] centroidCorrectiveValues = new float[3];
@@ -90,6 +126,7 @@ public float[] centroid(int centroidOrdinal) throws IOException {
90126
return centroid;
91127
}
92128

129+
@Override
93130
public void bulkScore(NeighborQueue queue) throws IOException {
94131
// TODO: bulk score centroids like we do with posting lists
95132
centroids.seek(quantizedCentroidsOffset);
@@ -121,44 +158,16 @@ private float score() throws IOException {
121158
fieldInfo.getVectorSimilarityFunction()
122159
);
123160
}
124-
125-
// TODO can we do this in off-heap blocks?
126-
private float int4QuantizedScore(
127-
float qcDist,
128-
OptimizedScalarQuantizer.QuantizationResult queryCorrections,
129-
int dims,
130-
float[] targetCorrections,
131-
int targetComponentSum,
132-
float centroidDp,
133-
VectorSimilarityFunction similarityFunction
134-
) {
135-
float ax = targetCorrections[0];
136-
// Here we assume `lx` is simply a bit vector, so the scaling isn't necessary
137-
float lx = (targetCorrections[1] - ax) * FOUR_BIT_SCALE;
138-
float ay = queryCorrections.lowerInterval();
139-
float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE;
140-
float y1 = queryCorrections.quantizedComponentSum();
141-
float score = ax * ay * dims + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * qcDist;
142-
if (similarityFunction == EUCLIDEAN) {
143-
score = queryCorrections.additionalCorrection() + targetCorrections[2] - 2 * score;
144-
return Math.max(1 / (1f + score), 0);
145-
} else {
146-
// For cosine and max inner product, we need to apply the additional correction, which is
147-
// assumed to be the non-centered dot-product between the vector and the centroid
148-
score += queryCorrections.additionalCorrection() + targetCorrections[2] - centroidDp;
149-
if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
150-
return VectorUtil.scaleMaxInnerProductScore(score);
151-
}
152-
return Math.max((1f + score) / 2f, 0);
153-
}
154-
}
155161
};
156162
}
157163

158-
// FIXME: clean up duplicative code between the scorers
159164
@Override
160-
ParentCentroidQueryScorer getParentCentroidScorer(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] targetQuery)
161-
throws IOException {
165+
ParentCentroidQueryScorer getParentCentroidScorer(
166+
FieldInfo fieldInfo,
167+
int numParentCentroids,
168+
IndexInput centroids,
169+
float[] targetQuery
170+
) throws IOException {
162171
FieldEntry fieldEntry = fields.get(fieldInfo.number);
163172
float[] globalCentroid = fieldEntry.globalCentroid();
164173
float globalCentroidDp = fieldEntry.globalCentroidDp();
@@ -183,15 +192,15 @@ ParentCentroidQueryScorer getParentCentroidScorer(FieldInfo fieldInfo, int numCe
183192

184193
@Override
185194
public int size() {
186-
return numCentroids;
195+
return numParentCentroids;
187196
}
188197

189198
@Override
190199
public float[] centroid(int centroidOrdinal) throws IOException {
191200
throw new UnsupportedOperationException("can't score at the parent level");
192201
}
193202

194-
private void readQuantizedAndRawCentroid(int centroidOrdinal) throws IOException {
203+
private void readChildDetails(int centroidOrdinal) throws IOException {
195204
if (centroidOrdinal == currentCentroid) {
196205
return;
197206
}
@@ -201,28 +210,29 @@ private void readQuantizedAndRawCentroid(int centroidOrdinal) throws IOException
201210
currentCentroid = centroidOrdinal;
202211
}
203212

213+
@Override
204214
public int getChildCentroidStart(int centroidOrdinal) throws IOException {
205-
readQuantizedAndRawCentroid(centroidOrdinal);
215+
readChildDetails(centroidOrdinal);
206216
return childCentroidStart;
207217
}
208218

219+
@Override
209220
public int getChildCount(int centroidOrdinal) throws IOException {
210-
readQuantizedAndRawCentroid(centroidOrdinal);
221+
readChildDetails(centroidOrdinal);
211222
return childCount;
212223
}
213224

214225
@Override
215226
public void bulkScore(NeighborQueue queue) throws IOException {
216227
// TODO: bulk score centroids like we do with posting lists
217228
centroids.seek(0L);
218-
for (int i = 0; i < numCentroids; i++) {
229+
for (int i = 0; i < numParentCentroids; i++) {
219230
queue.add(i, score());
220231
}
221232
}
222233

223234
@Override
224235
public void bulkScore(NeighborQueue queue, int start, int end) throws IOException {
225-
// FIXME: this never gets used ... I wonder if we just need an entirely different interface for this
226236
// TODO: bulk score centroids like we do with posting lists
227237
centroids.seek(parentNodeByteSize * start);
228238
for (int i = start; i < end; i++) {
@@ -235,10 +245,10 @@ private float score() throws IOException {
235245
centroids.readFloats(centroidCorrectiveValues, 0, 3);
236246
final int quantizedCentroidComponentSum = Short.toUnsignedInt(centroids.readShort());
237247

238-
// FIXME: move these now? to a different place in the file?
248+
// TODO: should we consider a different format, such as moving these to the beginning of the file, to benefit bulk reads
239249
// TODO: cache these at this point when scoring since we'll likely read many of them?
240-
centroids.readInt(); // child partition start
241-
centroids.readInt(); // child partition count
250+
// child partition start, child partition count
251+
centroids.skipBytes(Integer.BYTES * 2);
242252

243253
return int4QuantizedScore(
244254
qcDist,
@@ -250,46 +260,27 @@ private float score() throws IOException {
250260
fieldInfo.getVectorSimilarityFunction()
251261
);
252262
}
253-
254-
// TODO can we do this in off-heap blocks?
255-
private float int4QuantizedScore(
256-
float qcDist,
257-
OptimizedScalarQuantizer.QuantizationResult queryCorrections,
258-
int dims,
259-
float[] targetCorrections,
260-
int targetComponentSum,
261-
float centroidDp,
262-
VectorSimilarityFunction similarityFunction
263-
) {
264-
float ax = targetCorrections[0];
265-
// Here we assume `lx` is simply bit vectors, so the scaling isn't necessary
266-
float lx = (targetCorrections[1] - ax) * FOUR_BIT_SCALE;
267-
float ay = queryCorrections.lowerInterval();
268-
float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE;
269-
float y1 = queryCorrections.quantizedComponentSum();
270-
float score = ax * ay * dims + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * qcDist;
271-
if (similarityFunction == EUCLIDEAN) {
272-
score = queryCorrections.additionalCorrection() + targetCorrections[2] - 2 * score;
273-
return Math.max(1 / (1f + score), 0);
274-
} else {
275-
// For cosine and max inner product, we need to apply the additional correction, which is
276-
// assumed to be the non-centered dot-product between the vector and the centroid
277-
score += queryCorrections.additionalCorrection() + targetCorrections[2] - centroidDp;
278-
if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
279-
return VectorUtil.scaleMaxInnerProductScore(score);
280-
}
281-
return Math.max((1f + score) / 2f, 0);
282-
}
283-
}
284263
};
285264
}
286265

266+
@Override
267+
NeighborQueue scorePostingLists(
268+
FieldInfo fieldInfo,
269+
KnnCollector knnCollector,
270+
CentroidQueryScorer centroidQueryScorer,
271+
int nProbe,
272+
int start,
273+
int count
274+
) throws IOException {
275+
NeighborQueue neighborQueue = new NeighborQueue(count, true);
276+
centroidQueryScorer.bulkScore(neighborQueue, start, start + count);
277+
return neighborQueue;
278+
}
279+
287280
@Override
288281
NeighborQueue scorePostingLists(FieldInfo fieldInfo, KnnCollector knnCollector, CentroidQueryScorer centroidQueryScorer, int nProbe)
289282
throws IOException {
290-
NeighborQueue neighborQueue = new NeighborQueue(centroidQueryScorer.size(), true);
291-
centroidQueryScorer.bulkScore(neighborQueue);
292-
return neighborQueue;
283+
return scorePostingLists(fieldInfo, knnCollector, centroidQueryScorer, nProbe, 0, centroidQueryScorer.size());
293284
}
294285

295286
@Override

server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,7 @@ CentroidAssignments calculateAndWriteCentroids(
296296

297297
List<CentroidPartition> centroidPartitions = new ArrayList<>();
298298

299+
// TODO: make this configurable
299300
if (centroids.length > DEFAULT_VECTORS_PER_CLUSTER) {
300301
// TODO: sort by global centroids as well
301302
// TODO: have this take a function instead of just an int[] for sorting

server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,12 @@ protected IVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsR
8989
}
9090
}
9191

92-
abstract ParentCentroidQueryScorer getParentCentroidScorer(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] target)
93-
throws IOException;
92+
abstract CentroidWChildrenQueryScorer getParentCentroidScorer(
93+
FieldInfo fieldInfo,
94+
int numCentroids,
95+
IndexInput centroids,
96+
float[] target
97+
) throws IOException;
9498

9599
abstract CentroidQueryScorer getCentroidScorer(
96100
FieldInfo fieldInfo,
@@ -272,7 +276,7 @@ public final void search(String field, float[] target, KnnCollector knnCollector
272276
long expectedDocs = 0;
273277
long actualDocs = 0;
274278

275-
ParentCentroidQueryScorer parentCentroidQueryScorer = getParentCentroidScorer(
279+
CentroidWChildrenQueryScorer parentCentroidQueryScorer = getParentCentroidScorer(
276280
fieldInfo,
277281
entry.parentCentroidCount,
278282
entry.centroidSlice(ivfCentroids),
@@ -300,9 +304,14 @@ public final void search(String field, float[] target, KnnCollector knnCollector
300304
childCentroidCount = parentCentroidQueryScorer.getChildCount(parentCentroidOrdinal);
301305
}
302306

303-
// FIXME: create a start / end aware scorePostingLists
304-
NeighborQueue centroidQueue = new NeighborQueue(childCentroidCount, true);
305-
centroidQueryScorer.bulkScore(centroidQueue, childCentroidOrdinal, childCentroidOrdinal + childCentroidCount);
307+
NeighborQueue centroidQueue = scorePostingLists(
308+
fieldInfo,
309+
knnCollector,
310+
centroidQueryScorer,
311+
nProbe,
312+
childCentroidOrdinal,
313+
childCentroidCount
314+
);
306315

307316
PostingVisitor scorer = getPostingVisitor(fieldInfo, ivfClusters, target, needsScoring);
308317
// initially we visit only the "centroids to search"
@@ -345,6 +354,15 @@ public final void search(String field, byte[] target, KnnCollector knnCollector,
345354
}
346355
}
347356

357+
abstract NeighborQueue scorePostingLists(
358+
FieldInfo fieldInfo,
359+
KnnCollector knnCollector,
360+
CentroidQueryScorer centroidQueryScorer,
361+
int nProbe,
362+
int start,
363+
int end
364+
) throws IOException;
365+
348366
abstract NeighborQueue scorePostingLists(
349367
FieldInfo fieldInfo,
350368
KnnCollector knnCollector,
@@ -375,7 +393,7 @@ IndexInput centroidSlice(IndexInput centroidFile) throws IOException {
375393
abstract PostingVisitor getPostingVisitor(FieldInfo fieldInfo, IndexInput postingsLists, float[] target, IntPredicate needsScoring)
376394
throws IOException;
377395

378-
interface ParentCentroidQueryScorer extends CentroidQueryScorer {
396+
interface CentroidWChildrenQueryScorer extends CentroidQueryScorer {
379397
int getChildCentroidStart(int centroidOrdinal) throws IOException;
380398

381399
int getChildCount(int centroidOrdinal) throws IOException;

server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ public KMeansResult cluster(FloatVectorValues vectors, int targetSize) throws IO
6868
// partition the space
6969
KMeansIntermediate kMeansIntermediate = clusterAndSplit(vectors, targetSize, 0);
7070
if (kMeansIntermediate.centroids().length > 1 && kMeansIntermediate.centroids().length < vectors.size()) {
71+
// TODO: are we oversampling here?
7172
float f = Math.min((float) samplesPerCluster / targetSize, 1.0f);
7273
int localSampleSize = (int) (f * vectors.size());
7374
KMeansLocal kMeansLocal = new KMeansLocal(localSampleSize, maxIterations, clustersPerNeighborhood, DEFAULT_SOAR_LAMBDA);

0 commit comments

Comments
 (0)