diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java
index 304cc57284227..47c6bb99eabb5 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java
@@ -48,7 +48,7 @@ public DefaultIVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVect
     }
 
     @Override
-    CentroidQueryScorer getCentroidScorer(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] targetQuery)
+    CentroidIterator getCentroidIterator(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] targetQuery)
         throws IOException {
         final FieldEntry fieldEntry = fields.get(fieldInfo.number);
         final float globalCentroidDp = fieldEntry.globalCentroidDp();
@@ -65,90 +65,71 @@ CentroidQueryScorer getCentroidScorer(FieldInfo fieldInfo, int numCentroids, Ind
             quantized[i] = (byte) scratch[i];
         }
         final ES91Int4VectorsScorer scorer = ESVectorUtil.getES91Int4VectorsScorer(centroids, fieldInfo.getVectorDimension());
-        return new CentroidQueryScorer() {
-            int currentCentroid = -1;
-            long postingListOffset;
-            private final float[] centroidCorrectiveValues = new float[3];
-            private final long quantizeCentroidsLength = (long) numCentroids * (fieldInfo.getVectorDimension() + 3 * Float.BYTES
-                + Short.BYTES);
-
+        // Score every centroid eagerly; the iterator returned below drains this queue via pop().
+        final NeighborQueue queue = new NeighborQueue(numCentroids, true);
+        centroids.seek(0L);
+        final float[] centroidCorrectiveValues = new float[3];
+        for (int i = 0; i < numCentroids; i++) {
+            final float qcDist = scorer.int4DotProduct(quantized);
+            centroids.readFloats(centroidCorrectiveValues, 0, 3);
+            final int quantizedCentroidComponentSum = Short.toUnsignedInt(centroids.readShort());
+            float score = int4QuantizedScore(
+                qcDist,
+                queryParams,
+                fieldInfo.getVectorDimension(),
+                centroidCorrectiveValues,
+                quantizedCentroidComponentSum,
+                globalCentroidDp,
+                fieldInfo.getVectorSimilarityFunction()
+            );
+            queue.add(i, score);
+        }
+        // The scoring loop leaves the input positioned right after the last centroid's data; the posting
+        // list offsets are read relative to this position, one long per centroid ordinal.
+        final long offset = centroids.getFilePointer();
+        return new CentroidIterator() {
             @Override
-            public int size() {
-                return numCentroids;
+            public boolean hasNext() {
+                return queue.size() > 0;
             }
 
             @Override
-            public long postingListOffset(int centroidOrdinal) throws IOException {
-                if (centroidOrdinal != currentCentroid) {
-                    centroids.seek(quantizeCentroidsLength + (long) Long.BYTES * centroidOrdinal);
-                    postingListOffset = centroids.readLong();
-                    currentCentroid = centroidOrdinal;
-                }
-                return postingListOffset;
-            }
-
-            public void bulkScore(NeighborQueue queue) throws IOException {
-                // TODO: bulk score centroids like we do with posting lists
-                centroids.seek(0L);
-                for (int i = 0; i < numCentroids; i++) {
-                    queue.add(i, score());
-                }
-            }
-
-            private float score() throws IOException {
-                final float qcDist = scorer.int4DotProduct(quantized);
-                centroids.readFloats(centroidCorrectiveValues, 0, 3);
-                final int quantizedCentroidComponentSum = Short.toUnsignedInt(centroids.readShort());
-                return int4QuantizedScore(
-                    qcDist,
-                    queryParams,
-                    fieldInfo.getVectorDimension(),
-                    centroidCorrectiveValues,
-                    quantizedCentroidComponentSum,
-                    globalCentroidDp,
-                    fieldInfo.getVectorSimilarityFunction()
-                );
-            }
-
-            // TODO can we do this in off-heap blocks?
-            private float int4QuantizedScore(
-                float qcDist,
-                OptimizedScalarQuantizer.QuantizationResult queryCorrections,
-                int dims,
-                float[] targetCorrections,
-                int targetComponentSum,
-                float centroidDp,
-                VectorSimilarityFunction similarityFunction
-            ) {
-                float ax = targetCorrections[0];
-                // Here we assume `lx` is simply bit vectors, so the scaling isn't necessary
-                float lx = (targetCorrections[1] - ax) * FOUR_BIT_SCALE;
-                float ay = queryCorrections.lowerInterval();
-                float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE;
-                float y1 = queryCorrections.quantizedComponentSum();
-                float score = ax * ay * dims + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * qcDist;
-                if (similarityFunction == EUCLIDEAN) {
-                    score = queryCorrections.additionalCorrection() + targetCorrections[2] - 2 * score;
-                    return Math.max(1 / (1f + score), 0);
-                } else {
-                    // For cosine and max inner product, we need to apply the additional correction, which is
-                    // assumed to be the non-centered dot-product between the vector and the centroid
-                    score += queryCorrections.additionalCorrection() + targetCorrections[2] - centroidDp;
-                    if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
-                        return VectorUtil.scaleMaxInnerProductScore(score);
-                    }
-                    return Math.max((1f + score) / 2f, 0);
-                }
+            public long nextPostingListOffset() throws IOException {
+                int centroidOrdinal = queue.pop();
+                centroids.seek(offset + (long) Long.BYTES * centroidOrdinal);
+                return centroids.readLong();
             }
         };
     }
 
-    @Override
-    NeighborQueue scorePostingLists(FieldInfo fieldInfo, KnnCollector knnCollector, CentroidQueryScorer centroidQueryScorer, int nProbe)
-        throws IOException {
-        NeighborQueue neighborQueue = new NeighborQueue(centroidQueryScorer.size(), true);
-        centroidQueryScorer.bulkScore(neighborQueue);
-        return neighborQueue;
+    // TODO can we do this in off-heap blocks?
+    private static float int4QuantizedScore(
+        float qcDist,
+        OptimizedScalarQuantizer.QuantizationResult queryCorrections,
+        int dims,
+        float[] targetCorrections,
+        int targetComponentSum,
+        float centroidDp,
+        VectorSimilarityFunction similarityFunction
+    ) {
+        float ax = targetCorrections[0];
+        float lx = (targetCorrections[1] - ax) * FOUR_BIT_SCALE;
+        float ay = queryCorrections.lowerInterval();
+        float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE;
+        float y1 = queryCorrections.quantizedComponentSum();
+        float score = ax * ay * dims + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * qcDist;
+        if (similarityFunction == EUCLIDEAN) {
+            score = queryCorrections.additionalCorrection() + targetCorrections[2] - 2 * score;
+            return Math.max(1 / (1f + score), 0);
+        } else {
+            // For cosine and max inner product, we need to apply the additional correction, which is
+            // assumed to be the non-centered dot-product between the vector and the centroid
+            score += queryCorrections.additionalCorrection() + targetCorrections[2] - centroidDp;
+            if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
+                return VectorUtil.scaleMaxInnerProductScore(score);
+            }
+            return Math.max((1f + score) / 2f, 0);
+        }
     }
 
     @Override
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java
index 01cced04a9fcc..b570bd83f36e4 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java
@@ -31,7 +31,6 @@
 import org.apache.lucene.util.BitSet;
 import org.apache.lucene.util.Bits;
 import org.apache.lucene.util.FixedBitSet;
-import org.apache.lucene.util.hnsw.NeighborQueue;
 import org.elasticsearch.core.IOUtils;
 import org.elasticsearch.search.vectors.IVFKnnSearchStrategy;
 
@@ -89,7 +88,7 @@ protected IVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsR
         }
     }
 
-    abstract CentroidQueryScorer getCentroidScorer(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] target)
+    abstract CentroidIterator getCentroidIterator(FieldInfo fieldInfo, int numCentroids, IndexInput centroids, float[] target)
         throws IOException;
 
     private static IndexInput openDataInput(
@@ -236,22 +235,16 @@ public final void search(String field, float[] target, KnnCollector knnCollector
         }
 
         FieldEntry entry = fields.get(fieldInfo.number);
-        CentroidQueryScorer centroidQueryScorer = getCentroidScorer(
-            fieldInfo,
-            entry.numCentroids,
-            entry.centroidSlice(ivfCentroids),
-            target
-        );
         if (nProbe == DYNAMIC_NPROBE) {
             // empirically based, and a good dynamic to get decent recall while scaling a la "efSearch"
             // scaling by the number of centroids vs. the nearest neighbors requested
             // not perfect, but a comparative heuristic.
             // we might want to utilize the total vector count as well, but this is a good start
-            nProbe = (int) Math.round(Math.log10(centroidQueryScorer.size()) * Math.sqrt(knnCollector.k()));
+            nProbe = (int) Math.round(Math.log10(entry.numCentroids) * Math.sqrt(knnCollector.k()));
             // clip to be between 1 and the number of centroids
-            nProbe = Math.max(Math.min(nProbe, centroidQueryScorer.size()), 1);
+            nProbe = Math.max(Math.min(nProbe, entry.numCentroids), 1);
         }
-        final NeighborQueue centroidQueue = scorePostingLists(fieldInfo, knnCollector, centroidQueryScorer, nProbe);
+        CentroidIterator centroidIterator = getCentroidIterator(fieldInfo, entry.numCentroids, entry.centroidSlice(ivfCentroids), target);
         PostingVisitor scorer = getPostingVisitor(fieldInfo, ivfClusters, target, needsScoring);
         int centroidsVisited = 0;
         long expectedDocs = 0;
@@ -260,22 +253,22 @@ public final void search(String field, float[] target, KnnCollector knnCollector
         // Note, numCollected is doing the bare minimum here.
         // TODO do we need to handle nested doc counts similarly to how we handle
         // filtering? E.g. keep exploring until we hit an expected number of parent documents vs. child vectors?
-        while (centroidQueue.size() > 0 && (centroidsVisited < nProbe || knnCollectorImpl.numCollected() < knnCollector.k())) {
+        while (centroidIterator.hasNext() && (centroidsVisited < nProbe || knnCollectorImpl.numCollected() < knnCollector.k())) {
             ++centroidsVisited;
             // todo do we actually need to know the score???
-            int centroidOrdinal = centroidQueue.pop();
+            long offset = centroidIterator.nextPostingListOffset();
             // todo do we need direct access to the raw centroid???, this is used for quantizing, maybe hydrating and quantizing
             // is enough?
-            expectedDocs += scorer.resetPostingsScorer(centroidQueryScorer.postingListOffset(centroidOrdinal));
+            expectedDocs += scorer.resetPostingsScorer(offset);
             actualDocs += scorer.visit(knnCollector);
         }
         if (acceptDocs != null) {
             float unfilteredRatioVisited = (float) expectedDocs / numVectors;
             int filteredVectors = (int) Math.ceil(numVectors * percentFiltered);
             float expectedScored = Math.min(2 * filteredVectors * unfilteredRatioVisited, expectedDocs / 2f);
-            while (centroidQueue.size() > 0 && (actualDocs < expectedScored || actualDocs < knnCollector.k())) {
-                int centroidOrdinal = centroidQueue.pop();
-                scorer.resetPostingsScorer(centroidQueryScorer.postingListOffset(centroidOrdinal));
+            while (centroidIterator.hasNext() && (actualDocs < expectedScored || actualDocs < knnCollector.k())) {
+                long offset = centroidIterator.nextPostingListOffset();
+                scorer.resetPostingsScorer(offset);
                 actualDocs += scorer.visit(knnCollector);
             }
         }
@@ -294,13 +287,6 @@ public final void search(String field, byte[] target, KnnCollector knnCollector,
         }
     }
 
-    abstract NeighborQueue scorePostingLists(
-        FieldInfo fieldInfo,
-        KnnCollector knnCollector,
-        CentroidQueryScorer centroidQueryScorer,
-        int nProbe
-    ) throws IOException;
-
     @Override
     public void close() throws IOException {
         IOUtils.close(rawVectorsReader, ivfCentroids, ivfClusters);
@@ -323,12 +309,13 @@ IndexInput centroidSlice(IndexInput centroidFile) throws IOException {
     abstract PostingVisitor getPostingVisitor(FieldInfo fieldInfo, IndexInput postingsLists, float[] target, IntPredicate needsScoring)
         throws IOException;
 
-    interface CentroidQueryScorer {
-        int size();
-
-        long postingListOffset(int centroidOrdinal) throws IOException;
+    /** Hands out the posting list offsets of a field's centroids one at a time, in an order chosen by the implementation. */
+    interface CentroidIterator {
+        /** Returns {@code true} while at least one centroid has not been handed out yet. */
+        boolean hasNext();
 
-        void bulkScore(NeighborQueue queue) throws IOException;
+        /** Returns the posting list offset of the next centroid; callers are expected to check {@link #hasNext()} first. */
+        long nextPostingListOffset() throws IOException;
     }
 
     interface PostingVisitor {