diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES91OSQVectorsScorer.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES91OSQVectorsScorer.java index 58df8bb03e0cb..c9ea4f255acce 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES91OSQVectorsScorer.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES91OSQVectorsScorer.java @@ -141,7 +141,7 @@ public float score( * *

The results are stored in the provided scores array. */ - public void scoreBulk( + public float scoreBulk( byte[] q, float queryLowerInterval, float queryUpperInterval, @@ -158,6 +158,7 @@ public void scoreBulk( targetComponentSums[i] = Short.toUnsignedInt(in.readShort()); } in.readFloats(additionalCorrections, 0, BULK_SIZE); + float maxScore = Float.NEGATIVE_INFINITY; for (int i = 0; i < BULK_SIZE; i++) { scores[i] = score( queryLowerInterval, @@ -172,6 +173,10 @@ public void scoreBulk( additionalCorrections[i], scores[i] ); + if (scores[i] > maxScore) { + maxScore = scores[i]; + } } + return maxScore; } } diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java index 4b899cf987600..bf5062e7b1fb6 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java @@ -352,7 +352,7 @@ private void quantizeScore256Bulk(byte[] q, int count, float[] scores) throws IO } @Override - public void scoreBulk( + public float scoreBulk( byte[] q, float queryLowerInterval, float queryUpperInterval, @@ -366,7 +366,7 @@ public void scoreBulk( // 128 / 8 == 16 if (length >= 16 && PanamaESVectorUtilSupport.HAS_FAST_INTEGER_VECTORS) { if (PanamaESVectorUtilSupport.VECTOR_BITSIZE >= 256) { - score256Bulk( + return score256Bulk( q, queryLowerInterval, queryUpperInterval, @@ -376,9 +376,8 @@ public void scoreBulk( centroidDp, scores ); - return; } else if (PanamaESVectorUtilSupport.VECTOR_BITSIZE == 128) { - score128Bulk( + return score128Bulk( q, queryLowerInterval, queryUpperInterval, @@ -388,10 +387,9 @@ public void scoreBulk( centroidDp, scores ); - return; } } - super.scoreBulk( + return super.scoreBulk( q, queryLowerInterval, queryUpperInterval, @@ -403,7 +401,7 @@ public void scoreBulk( ); } - private void score128Bulk( + private float score128Bulk( byte[] q, float queryLowerInterval, float queryUpperInterval, @@ -420,6 +418,7 @@ private void score128Bulk( float ay = queryLowerInterval; float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE; float y1 = queryComponentSum; + float maxScore = Float.NEGATIVE_INFINITY; for (; i < limit; i += FLOAT_SPECIES_128.length()) { var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES_128, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN); var lx = FloatVector.fromMemorySegment( @@ -453,6 +452,7 @@ private void score128Bulk( if (similarityFunction == EUCLIDEAN) { res = res.mul(-2).add(additionalCorrections).add(queryAdditionalCorrection).add(1f); res = FloatVector.broadcast(FLOAT_SPECIES_128, 1).div(res).max(0); + maxScore = res.reduceLanes(VectorOperators.MAX); res.intoArray(scores, i); } else { // For cosine and max inner product, we need to apply the additional correction, which is @@ -463,17 +463,20 @@ private void score128Bulk( // not sure how to do it better for (int j = 0; j < FLOAT_SPECIES_128.length(); j++) { scores[i + j] = VectorUtil.scaleMaxInnerProductScore(scores[i + j]); + maxScore = Math.max(maxScore, scores[i + j]); } } else { res = res.add(1f).mul(0.5f).max(0); res.intoArray(scores, i); + maxScore = res.reduceLanes(VectorOperators.MAX); } } } in.seek(offset + 14L * BULK_SIZE); + return maxScore; } - private void score256Bulk( + private float score256Bulk( byte[] q, float queryLowerInterval, float queryUpperInterval, @@ -490,6 +493,7 @@ private void score256Bulk( float ay = queryLowerInterval; float ly = (queryUpperInterval - ay) * FOUR_BIT_SCALE; float y1 = queryComponentSum; + float maxScore = Float.NEGATIVE_INFINITY; for (; i < limit; i += FLOAT_SPECIES_256.length()) { var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES_256, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN); var lx = FloatVector.fromMemorySegment( @@ -523,6 +527,7 @@ private void score256Bulk( if (similarityFunction == EUCLIDEAN) { res = res.mul(-2).add(additionalCorrections).add(queryAdditionalCorrection).add(1f); res = FloatVector.broadcast(FLOAT_SPECIES_256, 1).div(res).max(0); + maxScore = res.reduceLanes(VectorOperators.MAX); res.intoArray(scores, i); } else { // For cosine and max inner product, we need to apply the additional correction, which is @@ -533,13 +538,16 @@ private void score256Bulk( // not sure how to do it better for (int j = 0; j < FLOAT_SPECIES_256.length(); j++) { scores[i + j] = VectorUtil.scaleMaxInnerProductScore(scores[i + j]); + maxScore = Math.max(maxScore, scores[i + j]); } } else { res = res.add(1f).mul(0.5f).max(0); + maxScore = res.reduceLanes(VectorOperators.MAX); res.intoArray(scores, i); } } } in.seek(offset + 14L * BULK_SIZE); + return maxScore; } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java index 7f11698661206..39daa3c64dc12 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java @@ -372,7 +372,8 @@ public int resetPostingsScorer(long offset) throws IOException { return vectors; } - void scoreIndividually(int offset) throws IOException { + float scoreIndividually(int offset) throws IOException { + float maxScore = Float.NEGATIVE_INFINITY; // score individually, first the quantized byte chunk for (int j = 0; j < BULK_SIZE; j++) { int doc = docIdsScratch[j + offset]; @@ -407,8 +408,35 @@ void scoreIndividually(int offset) throws IOException { correctionsAdd[j], scores[j] ); + if (scores[j] > maxScore) { + maxScore = scores[j]; + } + } + } + return maxScore; + } + + private static int filterDocs(int[] docIds, int offset, IntPredicate needsScoring) { + int filtered = 0; + for (int i = 0; i < ES91OSQVectorsScorer.BULK_SIZE; i++) { + if (needsScoring.test(docIds[offset + i]) == false) { + docIds[offset + i] = -1; + filtered++; + } + } + return filtered; + } + + private static int collect(int[] docIds, int offset, KnnCollector knnCollector, float[] scores) { + int scoredDocs = 0; + for (int i = 0; i < ES91OSQVectorsScorer.BULK_SIZE; i++) { + int doc = docIds[offset + i]; + if (doc != -1) { + scoredDocs++; + knnCollector.collect(doc, scores[i]); } } + return scoredDocs; } @Override @@ -418,23 +446,17 @@ public int visit(KnnCollector knnCollector) throws IOException { int limit = vectors - BULK_SIZE + 1; int i = 0; for (; i < limit; i += BULK_SIZE) { - int docsToScore = BULK_SIZE; - for (int j = 0; j < BULK_SIZE; j++) { - int doc = docIdsScratch[i + j]; - if (needsScoring.test(doc) == false) { - docIdsScratch[i + j] = -1; - docsToScore--; - } - } + int docsToScore = BULK_SIZE - filterDocs(docIdsScratch, i, needsScoring); if (docsToScore == 0) { continue; } quantizeQueryIfNecessary(); indexInput.seek(slicePos + i * quantizedByteLength); + float maxScore = Float.NEGATIVE_INFINITY; if (docsToScore < BULK_SIZE / 2) { - scoreIndividually(i); + maxScore = scoreIndividually(i); } else { - osqVectorsScorer.scoreBulk( + maxScore = osqVectorsScorer.scoreBulk( quantizedQueryScratch, queryCorrections.lowerInterval(), queryCorrections.upperInterval(), @@ -445,12 +467,8 @@ public int visit(KnnCollector knnCollector) throws IOException { scores ); } - for (int j = 0; j < BULK_SIZE; j++) { - int doc = docIdsScratch[i + j]; - if (doc != -1) { - scoredDocs++; - knnCollector.collect(doc, scores[j]); - } + if (knnCollector.minCompetitiveSimilarity() < maxScore) { + scoredDocs += collect(docIdsScratch, i, knnCollector, scores); } } // process tail