Skip to content

Commit 4eaa2ea

Browse files
authored
Replace CentroidQueryScorer with CentroidIterator (#131824)
This new interface hides the centroid scoring strategy from the base class.
1 parent 4dba49d commit 4eaa2ea

File tree

2 files changed

+68
-106
lines changed

2 files changed

+68
-106
lines changed

server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java

Lines changed: 55 additions & 77 deletions
Original file line number | Diff line number | Diff line change
@@ -48,7 +48,7 @@ public DefaultIVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVect
4848
}
4949

5050
@Override
51-
CentroidQueryScorer getCentroidScorer(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] targetQuery)
51+
CentroidIterator getCentroidIterator(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] targetQuery)
5252
throws IOException {
5353
final FieldEntry fieldEntry = fields.get(fieldInfo.number);
5454
final float globalCentroidDp = fieldEntry.globalCentroidDp();
@@ -65,90 +65,68 @@ CentroidQueryScorer getCentroidScorer(FieldInfo fieldInfo, int numCentroids, Ind
6565
quantized[i] = (byte) scratch[i];
6666
}
6767
final ES91Int4VectorsScorer scorer = ESVectorUtil.getES91Int4VectorsScorer(centroids, fieldInfo.getVectorDimension());
68-
return new CentroidQueryScorer() {
69-
int currentCentroid = -1;
70-
long postingListOffset;
71-
private final float[] centroidCorrectiveValues = new float[3];
72-
private final long quantizeCentroidsLength = (long) numCentroids * (fieldInfo.getVectorDimension() + 3 * Float.BYTES
73-
+ Short.BYTES);
74-
68+
NeighborQueue queue = new NeighborQueue(fieldEntry.numCentroids(), true);
69+
centroids.seek(0L);
70+
final float[] centroidCorrectiveValues = new float[3];
71+
for (int i = 0; i < numCentroids; i++) {
72+
final float qcDist = scorer.int4DotProduct(quantized);
73+
centroids.readFloats(centroidCorrectiveValues, 0, 3);
74+
final int quantizedCentroidComponentSum = Short.toUnsignedInt(centroids.readShort());
75+
float score = int4QuantizedScore(
76+
qcDist,
77+
queryParams,
78+
fieldInfo.getVectorDimension(),
79+
centroidCorrectiveValues,
80+
quantizedCentroidComponentSum,
81+
globalCentroidDp,
82+
fieldInfo.getVectorSimilarityFunction()
83+
);
84+
queue.add(i, score);
85+
}
86+
final long offset = centroids.getFilePointer();
87+
return new CentroidIterator() {
7588
@Override
76-
public int size() {
77-
return numCentroids;
89+
public boolean hasNext() {
90+
return queue.size() > 0;
7891
}
7992

8093
@Override
81-
public long postingListOffset(int centroidOrdinal) throws IOException {
82-
if (centroidOrdinal != currentCentroid) {
83-
centroids.seek(quantizeCentroidsLength + (long) Long.BYTES * centroidOrdinal);
84-
postingListOffset = centroids.readLong();
85-
currentCentroid = centroidOrdinal;
86-
}
87-
return postingListOffset;
88-
}
89-
90-
public void bulkScore(NeighborQueue queue) throws IOException {
91-
// TODO: bulk score centroids like we do with posting lists
92-
centroids.seek(0L);
93-
for (int i = 0; i < numCentroids; i++) {
94-
queue.add(i, score());
95-
}
96-
}
97-
98-
private float score() throws IOException {
99-
final float qcDist = scorer.int4DotProduct(quantized);
100-
centroids.readFloats(centroidCorrectiveValues, 0, 3);
101-
final int quantizedCentroidComponentSum = Short.toUnsignedInt(centroids.readShort());
102-
return int4QuantizedScore(
103-
qcDist,
104-
queryParams,
105-
fieldInfo.getVectorDimension(),
106-
centroidCorrectiveValues,
107-
quantizedCentroidComponentSum,
108-
globalCentroidDp,
109-
fieldInfo.getVectorSimilarityFunction()
110-
);
111-
}
112-
113-
// TODO can we do this in off-heap blocks?
114-
private float int4QuantizedScore(
115-
float qcDist,
116-
OptimizedScalarQuantizer.QuantizationResult queryCorrections,
117-
int dims,
118-
float[] targetCorrections,
119-
int targetComponentSum,
120-
float centroidDp,
121-
VectorSimilarityFunction similarityFunction
122-
) {
123-
float ax = targetCorrections[0];
124-
// Here we assume `lx` is simply bit vectors, so the scaling isn't necessary
125-
float lx = (targetCorrections[1] - ax) * FOUR_BIT_SCALE;
126-
float ay = queryCorrections.lowerInterval();
127-
float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE;
128-
float y1 = queryCorrections.quantizedComponentSum();
129-
float score = ax * ay * dims + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * qcDist;
130-
if (similarityFunction == EUCLIDEAN) {
131-
score = queryCorrections.additionalCorrection() + targetCorrections[2] - 2 * score;
132-
return Math.max(1 / (1f + score), 0);
133-
} else {
134-
// For cosine and max inner product, we need to apply the additional correction, which is
135-
// assumed to be the non-centered dot-product between the vector and the centroid
136-
score += queryCorrections.additionalCorrection() + targetCorrections[2] - centroidDp;
137-
if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
138-
return VectorUtil.scaleMaxInnerProductScore(score);
139-
}
140-
return Math.max((1f + score) / 2f, 0);
141-
}
94+
public long nextPostingListOffset() throws IOException {
95+
int centroidOrdinal = queue.pop();
96+
centroids.seek(offset + (long) Long.BYTES * centroidOrdinal);
97+
return centroids.readLong();
14298
}
14399
};
144100
}
145101

146-
@Override
147-
NeighborQueue scorePostingLists(FieldInfo fieldInfo, KnnCollector knnCollector, CentroidQueryScorer centroidQueryScorer, int nProbe)
148-
throws IOException {
149-
NeighborQueue neighborQueue = new NeighborQueue(centroidQueryScorer.size(), true);
150-
centroidQueryScorer.bulkScore(neighborQueue);
151-
return neighborQueue;
102+
// TODO can we do this in off-heap blocks?
103+
private float int4QuantizedScore(
104+
float qcDist,
105+
OptimizedScalarQuantizer.QuantizationResult queryCorrections,
106+
int dims,
107+
float[] targetCorrections,
108+
int targetComponentSum,
109+
float centroidDp,
110+
VectorSimilarityFunction similarityFunction
111+
) {
112+
float ax = targetCorrections[0];
113+
float lx = (targetCorrections[1] - ax) * FOUR_BIT_SCALE;
114+
float ay = queryCorrections.lowerInterval();
115+
float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE;
116+
float y1 = queryCorrections.quantizedComponentSum();
117+
float score = ax * ay * dims + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * qcDist;
118+
if (similarityFunction == EUCLIDEAN) {
119+
score = queryCorrections.additionalCorrection() + targetCorrections[2] - 2 * score;
120+
return Math.max(1 / (1f + score), 0);
121+
} else {
122+
// For cosine and max inner product, we need to apply the additional correction, which is
123+
// assumed to be the non-centered dot-product between the vector and the centroid
124+
score += queryCorrections.additionalCorrection() + targetCorrections[2] - centroidDp;
125+
if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
126+
return VectorUtil.scaleMaxInnerProductScore(score);
127+
}
128+
return Math.max((1f + score) / 2f, 0);
129+
}
152130
}
153131

154132
@Override

server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java

Lines changed: 13 additions & 29 deletions
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,6 @@
3131
import org.apache.lucene.util.BitSet;
3232
import org.apache.lucene.util.Bits;
3333
import org.apache.lucene.util.FixedBitSet;
34-
import org.apache.lucene.util.hnsw.NeighborQueue;
3534
import org.elasticsearch.core.IOUtils;
3635
import org.elasticsearch.search.vectors.IVFKnnSearchStrategy;
3736

@@ -89,7 +88,7 @@ protected IVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsR
8988
}
9089
}
9190

92-
abstract CentroidQueryScorer getCentroidScorer(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] target)
91+
abstract CentroidIterator getCentroidIterator(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] target)
9392
throws IOException;
9493

9594
private static IndexInput openDataInput(
@@ -236,22 +235,16 @@ public final void search(String field, float[] target, KnnCollector knnCollector
236235
}
237236

238237
FieldEntry entry = fields.get(fieldInfo.number);
239-
CentroidQueryScorer centroidQueryScorer = getCentroidScorer(
240-
fieldInfo,
241-
entry.numCentroids,
242-
entry.centroidSlice(ivfCentroids),
243-
target
244-
);
245238
if (nProbe == DYNAMIC_NPROBE) {
246239
// empirically based, and a good dynamic to get decent recall while scaling a la "efSearch"
247240
// scaling by the number of centroids vs. the nearest neighbors requested
248241
// not perfect, but a comparative heuristic.
249242
// we might want to utilize the total vector count as well, but this is a good start
250-
nProbe = (int) Math.round(Math.log10(centroidQueryScorer.size()) * Math.sqrt(knnCollector.k()));
243+
nProbe = (int) Math.round(Math.log10(entry.numCentroids) * Math.sqrt(knnCollector.k()));
251244
// clip to be between 1 and the number of centroids
252-
nProbe = Math.max(Math.min(nProbe, centroidQueryScorer.size()), 1);
245+
nProbe = Math.max(Math.min(nProbe, entry.numCentroids), 1);
253246
}
254-
final NeighborQueue centroidQueue = scorePostingLists(fieldInfo, knnCollector, centroidQueryScorer, nProbe);
247+
CentroidIterator centroidIterator = getCentroidIterator(fieldInfo, entry.numCentroids, entry.centroidSlice(ivfCentroids), target);
255248
PostingVisitor scorer = getPostingVisitor(fieldInfo, ivfClusters, target, needsScoring);
256249
int centroidsVisited = 0;
257250
long expectedDocs = 0;
@@ -260,22 +253,22 @@ public final void search(String field, float[] target, KnnCollector knnCollector
260253
// Note, numCollected is doing the bare minimum here.
261254
// TODO do we need to handle nested doc counts similarly to how we handle
262255
// filtering? E.g. keep exploring until we hit an expected number of parent documents vs. child vectors?
263-
while (centroidQueue.size() > 0 && (centroidsVisited < nProbe || knnCollectorImpl.numCollected() < knnCollector.k())) {
256+
while (centroidIterator.hasNext() && (centroidsVisited < nProbe || knnCollectorImpl.numCollected() < knnCollector.k())) {
264257
++centroidsVisited;
265258
// todo do we actually need to know the score???
266-
int centroidOrdinal = centroidQueue.pop();
259+
long offset = centroidIterator.nextPostingListOffset();
267260
// todo do we need direct access to the raw centroid???, this is used for quantizing, maybe hydrating and quantizing
268261
// is enough?
269-
expectedDocs += scorer.resetPostingsScorer(centroidQueryScorer.postingListOffset(centroidOrdinal));
262+
expectedDocs += scorer.resetPostingsScorer(offset);
270263
actualDocs += scorer.visit(knnCollector);
271264
}
272265
if (acceptDocs != null) {
273266
float unfilteredRatioVisited = (float) expectedDocs / numVectors;
274267
int filteredVectors = (int) Math.ceil(numVectors * percentFiltered);
275268
float expectedScored = Math.min(2 * filteredVectors * unfilteredRatioVisited, expectedDocs / 2f);
276-
while (centroidQueue.size() > 0 && (actualDocs < expectedScored || actualDocs < knnCollector.k())) {
277-
int centroidOrdinal = centroidQueue.pop();
278-
scorer.resetPostingsScorer(centroidQueryScorer.postingListOffset(centroidOrdinal));
269+
while (centroidIterator.hasNext() && (actualDocs < expectedScored || actualDocs < knnCollector.k())) {
270+
long offset = centroidIterator.nextPostingListOffset();
271+
scorer.resetPostingsScorer(offset);
279272
actualDocs += scorer.visit(knnCollector);
280273
}
281274
}
@@ -294,13 +287,6 @@ public final void search(String field, byte[] target, KnnCollector knnCollector,
294287
}
295288
}
296289

297-
abstract NeighborQueue scorePostingLists(
298-
FieldInfo fieldInfo,
299-
KnnCollector knnCollector,
300-
CentroidQueryScorer centroidQueryScorer,
301-
int nProbe
302-
) throws IOException;
303-
304290
@Override
305291
public void close() throws IOException {
306292
IOUtils.close(rawVectorsReader, ivfCentroids, ivfClusters);
@@ -323,12 +309,10 @@ IndexInput centroidSlice(IndexInput centroidFile) throws IOException {
323309
abstract PostingVisitor getPostingVisitor(FieldInfo fieldInfo, IndexInput postingsLists, float[] target, IntPredicate needsScoring)
324310
throws IOException;
325311

326-
interface CentroidQueryScorer {
327-
int size();
328-
329-
long postingListOffset(int centroidOrdinal) throws IOException;
312+
interface CentroidIterator {
313+
boolean hasNext();
330314

331-
void bulkScore(NeighborQueue queue) throws IOException;
315+
long nextPostingListOffset() throws IOException;
332316
}
333317

334318
interface PostingVisitor {

0 commit comments

Comments
 (0)