Skip to content

Commit 47395ff

Browse files
authored
Only collect bulk scored vectors when exceeding min competitive (#132293)
We should not bother collecting vectors that are not competitive. This PR adjusts the scoring interfaces to include the `max` score returned from the block of scored vectors. Then, we will attempt to collect that block if the max score of that block is competitive. This gives a nice speed improvement when querying many probes.
1 parent 390e4f2 commit 47395ff

File tree

3 files changed

+57
-26
lines changed

3 files changed

+57
-26
lines changed

libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES91OSQVectorsScorer.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ public float score(
141141
*
142142
* <p>The results are stored in the provided scores array.
143143
*/
144-
public void scoreBulk(
144+
public float scoreBulk(
145145
byte[] q,
146146
float queryLowerInterval,
147147
float queryUpperInterval,
@@ -158,6 +158,7 @@ public void scoreBulk(
158158
targetComponentSums[i] = Short.toUnsignedInt(in.readShort());
159159
}
160160
in.readFloats(additionalCorrections, 0, BULK_SIZE);
161+
float maxScore = Float.NEGATIVE_INFINITY;
161162
for (int i = 0; i < BULK_SIZE; i++) {
162163
scores[i] = score(
163164
queryLowerInterval,
@@ -172,6 +173,10 @@ public void scoreBulk(
172173
additionalCorrections[i],
173174
scores[i]
174175
);
176+
if (scores[i] > maxScore) {
177+
maxScore = scores[i];
178+
}
175179
}
180+
return maxScore;
176181
}
177182
}

libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ private void quantizeScore256Bulk(byte[] q, int count, float[] scores) throws IO
352352
}
353353

354354
@Override
355-
public void scoreBulk(
355+
public float scoreBulk(
356356
byte[] q,
357357
float queryLowerInterval,
358358
float queryUpperInterval,
@@ -366,7 +366,7 @@ public void scoreBulk(
366366
// 128 / 8 == 16
367367
if (length >= 16 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) {
368368
if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 256) {
369-
score256Bulk(
369+
return score256Bulk(
370370
q,
371371
queryLowerInterval,
372372
queryUpperInterval,
@@ -376,9 +376,8 @@ public void scoreBulk(
376376
centroidDp,
377377
scores
378378
);
379-
return;
380379
} else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE == 128) {
381-
score128Bulk(
380+
return score128Bulk(
382381
q,
383382
queryLowerInterval,
384383
queryUpperInterval,
@@ -388,10 +387,9 @@ public void scoreBulk(
388387
centroidDp,
389388
scores
390389
);
391-
return;
392390
}
393391
}
394-
super.scoreBulk(
392+
return super.scoreBulk(
395393
q,
396394
queryLowerInterval,
397395
queryUpperInterval,
@@ -403,7 +401,7 @@ public void scoreBulk(
403401
);
404402
}
405403

406-
private void score128Bulk(
404+
private float score128Bulk(
407405
byte[] q,
408406
float queryLowerInterval,
409407
float queryUpperInterval,
@@ -420,6 +418,7 @@ private void score128Bulk(
420418
float ay = queryLowerInterval;
421419
float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE;
422420
float y1 = queryComponentSum;
421+
float maxScore = Float.NEGATIVE_INFINITY;
423422
for (; i < limit; i += FLOAT_SPECIES_128.length()) {
424423
var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES_128, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
425424
var lx = FloatVector.fromMemorySegment(
@@ -453,6 +452,7 @@ private void score128Bulk(
453452
if (similarityFunction == EUCLIDEAN) {
454453
res = res.mul(-2).add(additionalCorrections).add(queryAdditionalCorrection).add(1f);
455454
res = FloatVector.broadcast(FLOAT_SPECIES_128, 1).div(res).max(0);
455+
maxScore = res.reduceLanes(VectorOperators.MAX);
456456
res.intoArray(scores, i);
457457
} else {
458458
// For cosine and max inner product, we need to apply the additional correction, which is
@@ -463,17 +463,20 @@ private void score128Bulk(
463463
// not sure how to do it better
464464
for (int j = 0; j < FLOAT_SPECIES_128.length(); j++) {
465465
scores[i + j] = VectorUtil.scaleMaxInnerProductScore(scores[i + j]);
466+
maxScore = Math.max(maxScore, scores[i + j]);
466467
}
467468
} else {
468469
res = res.add(1f).mul(0.5f).max(0);
469470
res.intoArray(scores, i);
471+
maxScore = res.reduceLanes(VectorOperators.MAX);
470472
}
471473
}
472474
}
473475
in.seek(offset + 14L * BULK_SIZE);
476+
return maxScore;
474477
}
475478

476-
private void score256Bulk(
479+
private float score256Bulk(
477480
byte[] q,
478481
float queryLowerInterval,
479482
float queryUpperInterval,
@@ -490,6 +493,7 @@ private void score256Bulk(
490493
float ay = queryLowerInterval;
491494
float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE;
492495
float y1 = queryComponentSum;
496+
float maxScore = Float.NEGATIVE_INFINITY;
493497
for (; i < limit; i += FLOAT_SPECIES_256.length()) {
494498
var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES_256, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
495499
var lx = FloatVector.fromMemorySegment(
@@ -523,6 +527,7 @@ private void score256Bulk(
523527
if (similarityFunction == EUCLIDEAN) {
524528
res = res.mul(-2).add(additionalCorrections).add(queryAdditionalCorrection).add(1f);
525529
res = FloatVector.broadcast(FLOAT_SPECIES_256, 1).div(res).max(0);
530+
maxScore = res.reduceLanes(VectorOperators.MAX);
526531
res.intoArray(scores, i);
527532
} else {
528533
// For cosine and max inner product, we need to apply the additional correction, which is
@@ -533,13 +538,16 @@ private void score256Bulk(
533538
// not sure how to do it better
534539
for (int j = 0; j < FLOAT_SPECIES_256.length(); j++) {
535540
scores[i + j] = VectorUtil.scaleMaxInnerProductScore(scores[i + j]);
541+
maxScore = Math.max(maxScore, scores[i + j]);
536542
}
537543
} else {
538544
res = res.add(1f).mul(0.5f).max(0);
545+
maxScore = res.reduceLanes(VectorOperators.MAX);
539546
res.intoArray(scores, i);
540547
}
541548
}
542549
}
543550
in.seek(offset + 14L * BULK_SIZE);
551+
return maxScore;
544552
}
545553
}

server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java

Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -372,7 +372,8 @@ public int resetPostingsScorer(long offset) throws IOException {
372372
return vectors;
373373
}
374374

375-
void scoreIndividually(int offset) throws IOException {
375+
float scoreIndividually(int offset) throws IOException {
376+
float maxScore = Float.NEGATIVE_INFINITY;
376377
// score individually, first the quantized byte chunk
377378
for (int j = 0; j < BULK_SIZE; j++) {
378379
int doc = docIdsScratch[j + offset];
@@ -407,8 +408,35 @@ void scoreIndividually(int offset) throws IOException {
407408
correctionsAdd[j],
408409
scores[j]
409410
);
411+
if (scores[j] > maxScore) {
412+
maxScore = scores[j];
413+
}
414+
}
415+
}
416+
return maxScore;
417+
}
418+
419+
private static int filterDocs(int[] docIds, int offset, IntPredicate needsScoring) {
420+
int filtered = 0;
421+
for (int i = 0; i < ES91OSQVectorsScorer.BULK_SIZE; i++) {
422+
if (needsScoring.test(docIds[offset + i]) == false) {
423+
docIds[offset + i] = -1;
424+
filtered++;
425+
}
426+
}
427+
return filtered;
428+
}
429+
430+
private static int collect(int[] docIds, int offset, KnnCollector knnCollector, float[] scores) {
431+
int scoredDocs = 0;
432+
for (int i = 0; i < ES91OSQVectorsScorer.BULK_SIZE; i++) {
433+
int doc = docIds[offset + i];
434+
if (doc != -1) {
435+
scoredDocs++;
436+
knnCollector.collect(doc, scores[i]);
410437
}
411438
}
439+
return scoredDocs;
412440
}
413441

414442
@Override
@@ -418,23 +446,17 @@ public int visit(KnnCollector knnCollector) throws IOException {
418446
int limit = vectors - BULK_SIZE + 1;
419447
int i = 0;
420448
for (; i < limit; i += BULK_SIZE) {
421-
int docsToScore = BULK_SIZE;
422-
for (int j = 0; j < BULK_SIZE; j++) {
423-
int doc = docIdsScratch[i + j];
424-
if (needsScoring.test(doc) == false) {
425-
docIdsScratch[i + j] = -1;
426-
docsToScore--;
427-
}
428-
}
449+
int docsToScore = BULK_SIZE - filterDocs(docIdsScratch, i, needsScoring);
429450
if (docsToScore == 0) {
430451
continue;
431452
}
432453
quantizeQueryIfNecessary();
433454
indexInput.seek(slicePos + i * quantizedByteLength);
455+
float maxScore = Float.NEGATIVE_INFINITY;
434456
if (docsToScore < BULK_SIZE / 2) {
435-
scoreIndividually(i);
457+
maxScore = scoreIndividually(i);
436458
} else {
437-
osqVectorsScorer.scoreBulk(
459+
maxScore = osqVectorsScorer.scoreBulk(
438460
quantizedQueryScratch,
439461
queryCorrections.lowerInterval(),
440462
queryCorrections.upperInterval(),
@@ -445,12 +467,8 @@ public int visit(KnnCollector knnCollector) throws IOException {
445467
scores
446468
);
447469
}
448-
for (int j = 0; j < BULK_SIZE; j++) {
449-
int doc = docIdsScratch[i + j];
450-
if (doc != -1) {
451-
scoredDocs++;
452-
knnCollector.collect(doc, scores[j]);
453-
}
470+
if (knnCollector.minCompetitiveSimilarity() < maxScore) {
471+
scoredDocs += collect(docIdsScratch, i, knnCollector, scores);
454472
}
455473
}
456474
// process tail

0 commit comments

Comments
 (0)