From 8b1e3b47a73e506cef274523e2bc99bdb61efc26 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Tue, 29 Apr 2025 10:20:46 -0400 Subject: [PATCH 01/11] Adding experimental IVF format --- .../ES91OSQVectorsScorer.java | 2 +- .../elasticsearch/simdvec/ESVectorUtil.java | 7 + .../DefaultESVectorizationProvider.java | 1 + .../ESVectorizationProvider.java | 1 + .../ES91OSQVectorScorerTests.java | 1 + .../vectors/DefaultIVFVectorsReader.java | 430 +++++++++ .../vectors/DefaultIVFVectorsWriter.java | 891 ++++++++++++++++++ .../index/codec/vectors/IVFVectorsFormat.java | 92 ++ .../index/codec/vectors/IVFVectorsReader.java | 474 ++++++++++ .../index/codec/vectors/IVFVectorsWriter.java | 498 ++++++++++ .../index/codec/vectors/NeighborQueue.java | 162 ++++ 11 files changed, 2558 insertions(+), 1 deletion(-) rename libs/simdvec/src/main/java/org/elasticsearch/simdvec/{internal/vectorization => }/ES91OSQVectorsScorer.java (99%) create mode 100644 server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java create mode 100644 server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java create mode 100644 server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormat.java create mode 100644 server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java create mode 100644 server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java create mode 100644 server/src/main/java/org/elasticsearch/index/codec/vectors/NeighborQueue.java diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorsScorer.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES91OSQVectorsScorer.java similarity index 99% rename from libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorsScorer.java rename to 
libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES91OSQVectorsScorer.java index 839e5f29a1148..be55c48dbe441 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorsScorer.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES91OSQVectorsScorer.java @@ -6,7 +6,7 @@ * your election, the "Elastic License 2.0", the "GNU Affero General Public * License v3.0 only", or the "Server Side Public License, v 1". */ -package org.elasticsearch.simdvec.internal.vectorization; +package org.elasticsearch.simdvec; import org.apache.lucene.index.VectorSimilarityFunction; import org.apache.lucene.store.IndexInput; diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java index 41bf6ff58d144..9212d5c83bd6a 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java @@ -9,11 +9,13 @@ package org.elasticsearch.simdvec; +import org.apache.lucene.store.IndexInput; import org.apache.lucene.util.BitUtil; import org.apache.lucene.util.Constants; import org.elasticsearch.simdvec.internal.vectorization.ESVectorUtilSupport; import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider; +import java.io.IOException; import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodHandles; import java.lang.invoke.MethodType; @@ -41,6 +43,11 @@ public class ESVectorUtil { private static final ESVectorUtilSupport IMPL = ESVectorizationProvider.getInstance().getVectorUtilSupport(); + + public static ES91OSQVectorsScorer getES91OSQVectorsScorer(IndexInput input, int dimension) throws IOException { + return ESVectorizationProvider.getInstance().newES91OSQVectorsScorer(input, dimension); + } + public static long ipByteBinByte(byte[] q, byte[] d) { if (q.length != d.length * B_QUERY) { throw new 
IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= " + B_QUERY + " x " + d.length); diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorizationProvider.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorizationProvider.java index e8ff6f83f2172..51a78d3cd6c37 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorizationProvider.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorizationProvider.java @@ -10,6 +10,7 @@ package org.elasticsearch.simdvec.internal.vectorization; import org.apache.lucene.store.IndexInput; +import org.elasticsearch.simdvec.ES91OSQVectorsScorer; import java.io.IOException; diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java index 7f4e62f156a36..8c040484c7c03 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java @@ -10,6 +10,7 @@ package org.elasticsearch.simdvec.internal.vectorization; import org.apache.lucene.store.IndexInput; +import org.elasticsearch.simdvec.ES91OSQVectorsScorer; import java.io.IOException; import java.util.Objects; diff --git a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorScorerTests.java b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorScorerTests.java index 53b14ae4910c0..5544c0686fa5f 100644 --- a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorScorerTests.java +++ 
b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorScorerTests.java @@ -16,6 +16,7 @@ import org.apache.lucene.store.IndexOutput; import org.apache.lucene.store.MMapDirectory; import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; +import org.elasticsearch.simdvec.ES91OSQVectorsScorer; import static org.hamcrest.Matchers.lessThan; diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java new file mode 100644 index 0000000000000..7524ae3558faa --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java @@ -0,0 +1,430 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + +package org.elasticsearch.index.codec.vectors; + +import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.QUERY_BITS; +import static org.apache.lucene.index.VectorSimilarityFunction.COSINE; +import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; +import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT; +import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.discretize; +import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.transposeHalfByte; +import static org.elasticsearch.simdvec.ES91OSQVectorsScorer.BULK_SIZE; + +import java.io.IOException; +import java.util.function.IntPredicate; +import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.ArrayUtil; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.NeighborQueue; +import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; +import org.elasticsearch.simdvec.ES91OSQVectorsScorer; +import org.elasticsearch.simdvec.ESVectorUtil; + +/** + * Default implementation of {@link IVFVectorsReader}. It scores the posting lists centroids using + * brute force and then scores the top ones using the posting list. 
+ */ +public class DefaultIVFVectorsReader extends IVFVectorsReader { + private static final float FOUR_BIT_SCALE = 1f / ((1 << 4) - 1); + + public DefaultIVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader) + throws IOException { + super(state, rawVectorsReader); + } + + @Override + protected CentroidQueryScorer getCentroidScorer( + FieldInfo fieldInfo, + int numCentroids, + IndexInput centroids, + float[] targetQuery, + IndexInput clusters) + throws IOException { + FieldEntry fieldEntry = fields.get(fieldInfo.number); + float[] globalCentroid = fieldEntry.globalCentroid(); + float globalCentroidDp = fieldEntry.globalCentroidDp(); + OptimizedScalarQuantizer scalarQuantizer = + new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); + byte[] quantized = new byte[targetQuery.length]; + float[] targetScratch = ArrayUtil.copyArray(targetQuery); + OptimizedScalarQuantizer.QuantizationResult queryParams = + scalarQuantizer.scalarQuantize(targetScratch, quantized, (byte) 4, globalCentroid); + return new CentroidQueryScorer() { + int currentCentroid = -1; + private final byte[] quantizedCentroid = new byte[fieldInfo.getVectorDimension()]; + private final float[] centroid = new float[fieldInfo.getVectorDimension()]; + private final float[] centroidCorrectiveValues = new float[3]; + private int quantizedCentroidComponentSum; + private final long centroidByteSize = + fieldInfo.getVectorDimension() + 3 * Float.BYTES + Short.BYTES; + + @Override + public int size() { + return numCentroids; + } + + @Override + public float[] centroid(int centroidOrdinal) throws IOException { + readQuantizedCentroid(centroidOrdinal); + return centroid; + } + + private void readQuantizedCentroid(int centroidOrdinal) throws IOException { + if (centroidOrdinal == currentCentroid) { + return; + } + centroids.seek(centroidOrdinal * centroidByteSize); + quantizedCentroidComponentSum = + readQuantizedValue(centroids, quantizedCentroid, centroidCorrectiveValues); + 
centroids.seek( + numCentroids * centroidByteSize + + (long) Float.BYTES * quantizedCentroid.length * centroidOrdinal); + centroids.readFloats(centroid, 0, centroid.length); + currentCentroid = centroidOrdinal; + } + + @Override + public float score(int centroidOrdinal) throws IOException { + readQuantizedCentroid(centroidOrdinal); + return int4QuantizedScore( + quantized, + queryParams, + fieldInfo.getVectorDimension(), + quantizedCentroid, + centroidCorrectiveValues, + quantizedCentroidComponentSum, + globalCentroidDp, + fieldInfo.getVectorSimilarityFunction()); + } + }; + } + + @Override + protected FloatVectorValues getCentroids( + IndexInput indexInput, int numCentroids, FieldInfo info) { + FieldEntry entry = fields.get(info.number); + if (entry == null) { + return null; + } + return new OffHeapCentroidFloatVectorValues( + numCentroids, indexInput, info.getVectorDimension()); + } + + @Override + NeighborQueue scorePostingLists( + FieldInfo fieldInfo, + KnnCollector knnCollector, + CentroidQueryScorer centroidQueryScorer, + int nProbe) + throws IOException { + NeighborQueue neighborQueue = new NeighborQueue(centroidQueryScorer.size(), true); + // TODO Off heap scoring for quantized centroids? + for (int centroid = 0; centroid < centroidQueryScorer.size(); centroid++) { + neighborQueue.add(centroid, centroidQueryScorer.score(centroid)); + } + return neighborQueue; + } + + @Override + PostingVisitor getPostingVisitor( + FieldInfo fieldInfo, IndexInput indexInput, float[] target, IntPredicate needsScoring) + throws IOException { + FieldEntry entry = fields.get(fieldInfo.number); + return new MemorySegmentPostingsVisitor(target, indexInput, entry, fieldInfo, needsScoring); + } + + // TODO can we do this in off-heap blocks? 
+  /**
+   * Estimates the similarity between a quantized query and a quantized centroid from their
+   * integer dot product plus the per-vector corrective terms.
+   *
+   * @param quantizedQuery quantized query components
+   * @param queryCorrections corrective terms produced when the query was quantized
+   * @param dims number of vector dimensions
+   * @param binaryCode quantized centroid components
+   * @param targetCorrections centroid corrective terms; index 0 is used as the lower interval,
+   *     index 1 as the upper interval and index 2 as the additional correction
+   * @param targetComponentSum sum of the centroid's quantized components
+   * @param centroidDp dot-product term subtracted on the non-Euclidean path
+   * @param similarityFunction similarity used to map the estimate onto a score
+   * @return the estimated similarity score
+   */
+  static float int4QuantizedScore(
+      byte[] quantizedQuery,
+      OptimizedScalarQuantizer.QuantizationResult queryCorrections,
+      int dims,
+      byte[] binaryCode,
+      float[] targetCorrections,
+      int targetComponentSum,
+      float centroidDp,
+      VectorSimilarityFunction similarityFunction) {
+    float qcDist = VectorUtil.int4DotProduct(quantizedQuery, binaryCode);
+    float ax = targetCorrections[0];
+    // Both interval widths (centroid `lx` and query `ly`) are rescaled by FOUR_BIT_SCALE
+    float lx = (targetCorrections[1] - ax) * FOUR_BIT_SCALE;
+    float ay = queryCorrections.lowerInterval();
+    float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE;
+    float y1 = queryCorrections.quantizedComponentSum();
+    float score =
+        ax * ay * dims + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * qcDist;
+    if (similarityFunction == EUCLIDEAN) {
+      score = queryCorrections.additionalCorrection() + targetCorrections[2] - 2 * score;
+      return Math.max(1 / (1f + score), 0);
+    } else {
+      // For cosine and max inner product, we need to apply the additional correction, which is
+      // assumed to be the non-centered dot-product between the vector and the centroid
+      score += queryCorrections.additionalCorrection() + targetCorrections[2] - centroidDp;
+      if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
+        return VectorUtil.scaleMaxInnerProductScore(score);
+      }
+      return Math.max((1f + score) / 2f, 0);
+    }
+  }
+
+  /**
+   * Random-access view over the raw float centroids of a field. The raw centroids are read from
+   * the region that follows the {@code numCentroids} quantized centroid records in the input.
+   * The array returned by {@link #vectorValue(int)} is shared and overwritten by later calls.
+   */
+  static class OffHeapCentroidFloatVectorValues extends FloatVectorValues {
+    private final int numCentroids;
+    private final IndexInput input;
+    private final int dimension;
+    private final float[] centroid;
+    private final long centroidByteSize;
+    private int ord = -1;
+
+    OffHeapCentroidFloatVectorValues(int numCentroids, IndexInput input, int dimension) {
+      this.numCentroids = numCentroids;
+      this.input = input;
+      this.dimension = dimension;
+      this.centroid = new float[dimension];
+      // size of one quantized centroid record: components, three float corrections, one short
+      this.centroidByteSize = dimension + 3 * Float.BYTES + Short.BYTES;
+    }
+
+    /**
+     * Returns the raw centroid for {@code ord}.
+     *
+     * @throws IllegalArgumentException if {@code ord} is outside {@code [0, numCentroids)}
+     */
+    @Override
+    public float[] vectorValue(int ord) throws IOException {
+      if (ord < 0 || ord >= numCentroids) {
+        // the upper bound is exclusive, so report a half-open interval
+        throw new IllegalArgumentException("ord must be in [0, " + numCentroids + ")");
+      }
+      if (ord == this.ord) {
+        return centroid;
+      }
+      readQuantizedCentroid(ord);
+      return centroid;
+    }
+
+    /** Loads the raw floats of {@code centroidOrdinal} unless it is already the cached one. */
+    private void readQuantizedCentroid(int centroidOrdinal) throws IOException {
+      if (centroidOrdinal == ord) {
+        return;
+      }
+      // skip all quantized records, then index into the raw float section
+      input.seek(
+          numCentroids * centroidByteSize + (long) Float.BYTES * dimension * centroidOrdinal);
+      input.readFloats(centroid, 0, centroid.length);
+      ord = centroidOrdinal;
+    }
+
+    @Override
+    public int dimension() {
+      return dimension;
+    }
+
+    @Override
+    public int size() {
+      return numCentroids;
+    }
+
+    @Override
+    public FloatVectorValues copy() throws IOException {
+      return new OffHeapCentroidFloatVectorValues(numCentroids, input.clone(), dimension);
+    }
+  }
+
+  private static class MemorySegmentPostingsVisitor implements PostingVisitor {
+    final long quantizedByteLength;
+    final IndexInput indexInput;
+    final float[] target;
+    final FieldEntry entry;
+    final FieldInfo fieldInfo;
+    final IntPredicate needsScoring;
+    private final ES91OSQVectorsScorer osqVectorsScorer;
+    final float[] scores = new float[BULK_SIZE];
+    final float[] correctionsLower = new float[BULK_SIZE];
+    final float[] correctionsUpper = new float[BULK_SIZE];
+    final int[] correctionsSum = new int[BULK_SIZE];
+    final float[] correctionsAdd = new float[BULK_SIZE];
+
+    int[] docIdsScratch = new int[0];
+    int vectors;
+    boolean quantized = false;
+    float centroidDp;
+    float[] centroid;
+    long slicePos;
+    OptimizedScalarQuantizer.QuantizationResult queryCorrections;
+    DocIdsWriter docIdsWriter = new DocIdsWriter();
+
+    final float[] scratch;
+    final byte[] quantizationScratch;
+    final byte[] quantizedQueryScratch;
+    final OptimizedScalarQuantizer quantizer;
+    final float[] correctiveValues = new float[3];
+    final long quantizedVectorByteSize;
+
+    MemorySegmentPostingsVisitor(
+        float[] target,
+
IndexInput indexInput, + FieldEntry entry, + FieldInfo fieldInfo, + IntPredicate needsScoring) + throws IOException { + this.target = target; + this.indexInput = indexInput; + this.entry = entry; + this.fieldInfo = fieldInfo; + this.needsScoring = needsScoring; + + scratch = new float[target.length]; + quantizationScratch = new byte[target.length]; + final int discretizedDimensions = discretize(fieldInfo.getVectorDimension(), 64); + quantizedQueryScratch = new byte[QUERY_BITS * discretizedDimensions / 8]; + quantizedByteLength = discretizedDimensions / 8 + (Float.BYTES * 3) + Short.BYTES; + quantizedVectorByteSize = (discretizedDimensions / 8); + quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); + osqVectorsScorer = + ESVectorUtil.getES91OSQVectorsScorer(indexInput, fieldInfo.getVectorDimension()); + } + + @Override + public int resetPostingsScorer(int centroidOrdinal, float[] centroid) throws IOException { + quantized = false; + indexInput.seek(entry.postingListOffsets()[centroidOrdinal]); + vectors = indexInput.readVInt(); + centroidDp = Float.intBitsToFloat(indexInput.readInt()); + this.centroid = centroid; + // read the doc ids + docIdsScratch = vectors > docIdsScratch.length ? 
new int[vectors] : docIdsScratch; + docIdsWriter.readInts(indexInput, vectors, docIdsScratch); + slicePos = indexInput.getFilePointer(); + return vectors; + } + + void scoreIndividually(int offset) throws IOException { + // score individually, first the quantized byte chunk + for (int j = 0; j < BULK_SIZE; j++) { + int doc = docIdsScratch[j + offset]; + if (doc != -1) { + indexInput.seek( + slicePos + (offset * quantizedByteLength) + (j * quantizedVectorByteSize)); + float qcDist = osqVectorsScorer.quantizeScore(quantizedQueryScratch); + scores[j] = qcDist; + } + } + // read in all corrections + indexInput.seek( + slicePos + (offset * quantizedByteLength) + (BULK_SIZE * quantizedVectorByteSize)); + indexInput.readFloats(correctionsLower, 0, BULK_SIZE); + indexInput.readFloats(correctionsUpper, 0, BULK_SIZE); + for (int j = 0; j < BULK_SIZE; j++) { + correctionsSum[j] = Short.toUnsignedInt(indexInput.readShort()); + } + indexInput.readFloats(correctionsAdd, 0, BULK_SIZE); + // Now apply corrections + for (int j = 0; j < BULK_SIZE; j++) { + int doc = docIdsScratch[offset + j]; + if (doc != -1) { + scores[j] = + osqVectorsScorer.score( + queryCorrections, + fieldInfo.getVectorSimilarityFunction(), + centroidDp, + correctionsLower[j], + correctionsUpper[j], + correctionsSum[j], + correctionsAdd[j], + scores[j]); + } + } + } + + @Override + public int visit(KnnCollector knnCollector) throws IOException { + // block processing + int scoredDocs = 0; + int limit = vectors - BULK_SIZE + 1; + int i = 0; + for (; i < limit; i += BULK_SIZE) { + int docsToScore = BULK_SIZE; + for (int j = 0; j < BULK_SIZE; j++) { + int doc = docIdsScratch[i + j]; + if (needsScoring.test(doc) == false) { + docIdsScratch[i + j] = -1; + docsToScore--; + } + } + if (docsToScore == 0) { + continue; + } + quantizeQueryIfNecessary(); + indexInput.seek(slicePos + i * quantizedByteLength); + if (docsToScore < BULK_SIZE / 2) { + scoreIndividually(i); + } else { + osqVectorsScorer.scoreBulk( + 
quantizedQueryScratch, + queryCorrections, + fieldInfo.getVectorSimilarityFunction(), + centroidDp, + scores); + } + for (int j = 0; j < BULK_SIZE; j++) { + int doc = docIdsScratch[i + j]; + if (doc != -1) { + scoredDocs++; + knnCollector.collect(doc, scores[j]); + } + } + } + // process tail + for (; i < vectors; i++) { + int doc = docIdsScratch[i]; + if (needsScoring.test(doc)) { + quantizeQueryIfNecessary(); + indexInput.seek(slicePos + i * quantizedByteLength); + float qcDist = osqVectorsScorer.quantizeScore(quantizedQueryScratch); + indexInput.readFloats(correctiveValues, 0, 3); + final int quantizedComponentSum = Short.toUnsignedInt(indexInput.readShort()); + float score = + osqVectorsScorer.score( + queryCorrections, + fieldInfo.getVectorSimilarityFunction(), + centroidDp, + correctiveValues[0], + correctiveValues[1], + quantizedComponentSum, + correctiveValues[2], + qcDist); + scoredDocs++; + knnCollector.collect(doc, score); + } + } + knnCollector.incVisitedCount(scoredDocs); + return scoredDocs; + } + + private void quantizeQueryIfNecessary() { + if (quantized == false) { + System.arraycopy(target, 0, scratch, 0, target.length); + if (fieldInfo.getVectorSimilarityFunction() == COSINE) { + VectorUtil.l2normalize(scratch); + } + queryCorrections = + quantizer.scalarQuantize(scratch, quantizationScratch, (byte) 4, centroid); + transposeHalfByte(quantizationScratch, quantizedQueryScratch); + quantized = true; + } + } + } + + static int readQuantizedValue(IndexInput indexInput, byte[] binaryValue, float[] corrections) + throws IOException { + assert corrections.length == 3; + indexInput.readBytes(binaryValue, 0, binaryValue.length); + corrections[0] = Float.intBitsToFloat(indexInput.readInt()); + corrections[1] = Float.intBitsToFloat(indexInput.readInt()); + corrections[2] = Float.intBitsToFloat(indexInput.readInt()); + return Short.toUnsignedInt(indexInput.readShort()); + } +} diff --git 
a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java new file mode 100644 index 0000000000000..2a5b66544fbb7 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java @@ -0,0 +1,891 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.codec.vectors; + +import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.INDEX_BITS; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; +import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.discretize; +import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.packAsBinary; +import static org.elasticsearch.index.codec.vectors.IVFVectorsFormat.IVF_VECTOR_COMPONENT; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.ArrayList; +import java.util.List; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.MergeState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.internal.hppc.IntArrayList; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; +import 
org.apache.lucene.util.FixedBitSet; +import org.apache.lucene.util.InfoStream; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; +import org.elasticsearch.simdvec.ES91OSQVectorsScorer; +import org.elasticsearch.simdvec.ESVectorUtil; + +/** + * Default implementation of {@link IVFVectorsWriter}. It uses {@link KMeans} algorithm to + * partition the vector space, and then stores the centroids and posting lists in a sequential + * fashion. + */ +public class DefaultIVFVectorsWriter extends IVFVectorsWriter { + + static final float SOAR_LAMBDA = 1.0f; + // What percentage of the centroids do we do a second check on for SOAR assignment + static final float EXT_SOAR_LIMIT_CHECK_RATIO = 0.10f; + + private final int vectorPerCluster; + + private final OptimizedScalarQuantizer.QuantizationResult[] corrections = + new OptimizedScalarQuantizer.QuantizationResult[ES91OSQVectorsScorer.BULK_SIZE]; + + public DefaultIVFVectorsWriter( + SegmentWriteState state, FlatVectorsWriter rawVectorDelegate, int vectorPerCluster) + throws IOException { + super(state, rawVectorDelegate); + this.vectorPerCluster = vectorPerCluster; + } + + @Override + CentroidAssignmentScorer calculateAndWriteCentroids( + FieldInfo fieldInfo, + FloatVectorValues floatVectorValues, + IndexOutput centroidOutput, + float[] globalCentroid) + throws IOException { + if (floatVectorValues.size() == 0) { + return CentroidAssignmentScorer.EMPTY; + } + // calculate the centroids + int maxNumClusters = ((floatVectorValues.size() - 1) / vectorPerCluster) + 1; + int desiredClusters = + (int) + Math.max( + maxNumClusters / 16.0, + Math.max(Math.sqrt(floatVectorValues.size()), maxNumClusters)); + if (floatVectorValues.size() / desiredClusters > vectorPerCluster) { + desiredClusters = ((floatVectorValues.size() - 1) / vectorPerCluster) + 1; + } + final KMeans.Results kMeans = + KMeans.cluster( + floatVectorValues, + desiredClusters, + false, + 42L, + 
KMeans.KmeansInitializationMethod.PLUS_PLUS, + null, + fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE, + 1, + 15, + desiredClusters * 256); + float[][] centroids = kMeans.centroids(); + // write them + writeCentroids(centroids, fieldInfo, globalCentroid, centroidOutput); + return new OnHeapCentroidAssignmentScorer(centroids); + } + + @Override + long[] buildAndWritePostingsLists( + FieldInfo fieldInfo, + InfoStream infoStream, + CentroidAssignmentScorer randomCentroidScorer, + FloatVectorValues floatVectorValues, + IndexOutput postingsOutput) + throws IOException { + IntArrayList[] clusters = new IntArrayList[randomCentroidScorer.size()]; + for (int i = 0; i < randomCentroidScorer.size(); i++) { + clusters[i] = new IntArrayList(floatVectorValues.size() / randomCentroidScorer.size() / 4); + } + assignCentroids(randomCentroidScorer, floatVectorValues, clusters); + if (infoStream.isEnabled(IVF_VECTOR_COMPONENT)) { + printClusterQualityStatistics(clusters, infoStream); + } + // write the posting lists + final long[] offsets = new long[randomCentroidScorer.size()]; + OptimizedScalarQuantizer quantizer = + new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); + BinarizedFloatVectorValues binarizedByteVectorValues = + new BinarizedFloatVectorValues(floatVectorValues, quantizer); + DocIdsWriter docIdsWriter = new DocIdsWriter(); + for (int i = 0; i < randomCentroidScorer.size(); i++) { + float[] centroid = randomCentroidScorer.centroid(i); + binarizedByteVectorValues.centroid = centroid; + // TODO sort by distance to the centroid + IntArrayList cluster = clusters[i]; + // TODO align??? 
+      offsets[i] = postingsOutput.getFilePointer();
+      int size = cluster.size();
+      postingsOutput.writeVInt(size);
+      postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
+      // TODO we might want to consider putting the docIds in a separate file
+      //  to aid with only having to fetch vectors from slower storage when they are required
+      //  keeping them in the same file indicates we pull the entire file into cache
+      docIdsWriter.writeDocIds(
+          j -> floatVectorValues.ordToDoc(cluster.get(j)), cluster.size(), postingsOutput);
+      writePostingList(cluster, postingsOutput, binarizedByteVectorValues);
+    }
+    return offsets;
+  }
+
+  /**
+   * Writes the quantized vectors of one cluster to {@code postingsOutput}.
+   *
+   * <p>Full groups of {@code ES91OSQVectorsScorer.BULK_SIZE} vectors are written column-wise:
+   * all quantized vectors first, then all lower intervals, all upper intervals, all component
+   * sums (as shorts) and all additional corrections. The remaining tail vectors are written one
+   * record at a time as vector bytes, lower interval, upper interval, additional correction and
+   * component sum (short).
+   *
+   * @param cluster vector ordinals assigned to the cluster
+   * @param postingsOutput output receiving the posting list
+   * @param binarizedByteVectorValues source of the quantized vectors and their corrective terms
+   * @throws IOException if writing to {@code postingsOutput} fails
+   */
+  private void writePostingList(
+      IntArrayList cluster,
+      IndexOutput postingsOutput,
+      BinarizedFloatVectorValues binarizedByteVectorValues)
+      throws IOException {
+    int limit = cluster.size() - ES91OSQVectorsScorer.BULK_SIZE + 1;
+    int cidx = 0;
+    // Write vectors in bulks of ES91OSQVectorsScorer.BULK_SIZE.
+    for (; cidx < limit; cidx += ES91OSQVectorsScorer.BULK_SIZE) {
+      for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
+        int ord = cluster.get(cidx + j);
+        byte[] binaryValue = binarizedByteVectorValues.vectorValue(ord);
+        // write vector
+        postingsOutput.writeBytes(binaryValue, 0, binaryValue.length);
+        corrections[j] = binarizedByteVectorValues.getCorrectiveTerms(ord);
+      }
+      // write corrections
+      for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
+        postingsOutput.writeInt(Float.floatToIntBits(corrections[j].lowerInterval()));
+      }
+      for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
+        postingsOutput.writeInt(Float.floatToIntBits(corrections[j].upperInterval()));
+      }
+      for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
+        int targetComponentSum = corrections[j].quantizedComponentSum();
+        assert targetComponentSum >= 0 && targetComponentSum <= 0xffff;
+        postingsOutput.writeShort((short) targetComponentSum);
+      }
+      for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) {
+        postingsOutput.writeInt(Float.floatToIntBits(corrections[j].additionalCorrection()));
+      }
+    }
+    // write tail
+    for (; cidx < cluster.size(); cidx++) {
+      int ord = cluster.get(cidx);
+      byte[] binaryValue = binarizedByteVectorValues.vectorValue(ord);
+      OptimizedScalarQuantizer.QuantizationResult tailCorrections =
+          binarizedByteVectorValues.getCorrectiveTerms(ord);
+      // Each tail vector must be written exactly once: the tail path of
+      // DefaultIVFVectorsReader seeks to slicePos + i * quantizedByteLength and reads a single
+      // record per vector, so a second copy would shift every following tail record.
+      postingsOutput.writeBytes(binaryValue, 0, binaryValue.length);
+      postingsOutput.writeInt(Float.floatToIntBits(tailCorrections.lowerInterval()));
+      postingsOutput.writeInt(Float.floatToIntBits(tailCorrections.upperInterval()));
+      postingsOutput.writeInt(Float.floatToIntBits(tailCorrections.additionalCorrection()));
+      assert tailCorrections.quantizedComponentSum() >= 0
+          && tailCorrections.quantizedComponentSum() <= 0xffff;
+      postingsOutput.writeShort((short) tailCorrections.quantizedComponentSum());
+    }
+  }
+
+  @Override
+  CentroidAssignmentScorer createCentroidScorer(
+      IndexInput centroidsInput, int numCentroids, FieldInfo fieldInfo, float[] globalCentroid)
+      throws IOException {
+    return new OffHeapCentroidAssignmentScorer(centroidsInput, numCentroids, fieldInfo);
+  }
+
+  static void writeCentroids(
+      float[][] centroids, FieldInfo fieldInfo, float[] globalCentroid, IndexOutput centroidOutput)
+      throws IOException {
+    final OptimizedScalarQuantizer osq =
+        new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
+    byte[] quantizedScratch = new byte[fieldInfo.getVectorDimension()];
+    float[] centroidScratch = new float[fieldInfo.getVectorDimension()];
+    // TODO do we want to store these distances as well for future use?
+ float[] distances = new float[centroids.length]; + for (int i = 0; i < centroids.length; i++) { + distances[i] = VectorUtil.squareDistance(centroids[i], globalCentroid); + } + // sort the centroids by distance to globalCentroid, nearest (smallest distance), to furthest + // (largest) + for (int i = 0; i < centroids.length; i++) { + for (int j = i + 1; j < centroids.length; j++) { + if (distances[i] > distances[j]) { + float[] tmp = centroids[i]; + centroids[i] = centroids[j]; + centroids[j] = tmp; + float tmpDistance = distances[i]; + distances[i] = distances[j]; + distances[j] = tmpDistance; + } + } + } + for (float[] centroid : centroids) { + System.arraycopy(centroid, 0, centroidScratch, 0, centroid.length); + OptimizedScalarQuantizer.QuantizationResult result = + osq.scalarQuantize(centroidScratch, quantizedScratch, (byte) 4, globalCentroid); + writeQuantizedValue(centroidOutput, quantizedScratch, result); + } + final ByteBuffer buffer = + ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES) + .order(ByteOrder.LITTLE_ENDIAN); + for (float[] centroid : centroids) { + buffer.asFloatBuffer().put(centroid); + centroidOutput.writeBytes(buffer.array(), buffer.array().length); + } + } + + record SegmentCentroid(int segment, int centroid, int centroidSize) {} + + @Override + protected int calculateAndWriteCentroids( + FieldInfo fieldInfo, + FloatVectorValues floatVectorValues, + IndexOutput temporaryCentroidOutput, + MergeState mergeState, + float[] globalCentroid) + throws IOException { + if (floatVectorValues.size() == 0) { + return 0; + } + int desiredClusters = ((floatVectorValues.size() - 1) / vectorPerCluster) + 1; + // init centroids from merge state + List centroidList = new ArrayList<>(); + List segmentCentroids = new ArrayList<>(desiredClusters); + + int segmentIdx = 0; + long startTime = System.nanoTime(); + for (var reader : mergeState.knnVectorsReaders) { + IVFVectorsReader ivfVectorsReader = IVFVectorsFormat.getIVFReader(reader, 
fieldInfo.name); + if (ivfVectorsReader == null) { + continue; + } + + FloatVectorValues centroid = ivfVectorsReader.getCentroids(fieldInfo); + centroidList.add(centroid); + for (int i = 0; i < centroid.size(); i++) { + int size = ivfVectorsReader.centroidSize(fieldInfo.name, i); + segmentCentroids.add(new SegmentCentroid(segmentIdx, i, size)); + } + segmentIdx++; + } + + // sort centroid list by floatvector size + FloatVectorValues baseSegment = centroidList.get(0); + for (var l : centroidList) { + if (l.size() > baseSegment.size()) { + baseSegment = l; + } + } + float[] scratch = new float[fieldInfo.getVectorDimension()]; + float minimumDistance = Float.MAX_VALUE; + for (int j = 0; j < baseSegment.size(); j++) { + System.arraycopy(baseSegment.vectorValue(j), 0, scratch, 0, baseSegment.dimension()); + for (int k = j + 1; k < baseSegment.size(); k++) { + float d = VectorUtil.squareDistance(scratch, baseSegment.vectorValue(k)); + if (d < minimumDistance) { + minimumDistance = d; + } + } + } + if (mergeState.infoStream.isEnabled(IVF_VECTOR_COMPONENT)) { + mergeState.infoStream.message( + IVF_VECTOR_COMPONENT, + "Agglomerative cluster min distance: " + + minimumDistance + + " From biggest segment: " + + baseSegment.size()); + } + int[] labels = new int[segmentCentroids.size()]; + // loop over segments + int clusterIdx = 0; + // keep track of all inter-centroid distances, + // using less than centroid * centroid space (e.g. 
not keeping track of duplicates) + for (int i = 0; i < segmentCentroids.size(); i++) { + if (labels[i] == 0) { + clusterIdx += 1; + labels[i] = clusterIdx; + } + SegmentCentroid segmentCentroid = segmentCentroids.get(i); + System.arraycopy( + centroidList.get(segmentCentroid.segment()).vectorValue(segmentCentroid.centroid), + 0, + scratch, + 0, + baseSegment.dimension()); + for (int j = i + 1; j < segmentCentroids.size(); j++) { + float d = + VectorUtil.squareDistance( + scratch, + centroidList + .get(segmentCentroids.get(j).segment()) + .vectorValue(segmentCentroids.get(j).centroid())); + if (d < minimumDistance / 2) { + if (labels[j] == 0) { + labels[j] = labels[i]; + } else { + for (int k = 0; k < labels.length; k++) { + if (labels[k] == labels[j]) { + labels[k] = labels[i]; + } + } + } + } + } + } + float[][] initCentroids = new float[clusterIdx][fieldInfo.getVectorDimension()]; + int[] sum = new int[clusterIdx]; + for (int i = 0; i < segmentCentroids.size(); i++) { + SegmentCentroid segmentCentroid = segmentCentroids.get(i); + int label = labels[i]; + FloatVectorValues segment = centroidList.get(segmentCentroid.segment()); + float[] vector = segment.vectorValue(segmentCentroid.centroid); + for (int j = 0; j < vector.length; j++) { + initCentroids[label - 1][j] += (vector[j] * segmentCentroid.centroidSize); + } + sum[label - 1] += segmentCentroid.centroidSize; + } + for (int i = 0; i < initCentroids.length; i++) { + for (int j = 0; j < initCentroids[i].length; j++) { + initCentroids[i][j] /= sum[i]; + } + } + if (mergeState.infoStream.isEnabled(IVF_VECTOR_COMPONENT)) { + mergeState.infoStream.message( + IVF_VECTOR_COMPONENT, + "Agglomerative cluster time ms: " + ((System.nanoTime() - startTime) / 1000000.0)); + mergeState.infoStream.message( + IVF_VECTOR_COMPONENT, + "Gathered initCentroids:" + initCentroids.length + " for desired: " + desiredClusters); + } + + // FIXME: still split to get to desired cluster count? 
+ // FIXME: need a way to maintain the original mapping ... update KMeans to allow maintaining + // that mapping + // FIXME: go update the assignCentroids code to respect that mapping from prior centroid to next + // centroid (via the scorer?) + // FIXME: run a custom version of kmeans that adjusts the centroids that were split related to + // only the sets of vectors that were previously associated with the prior centroids + // FIXME: compare this kmeans outcome with a lot of iterations with the outcome of the process + // detailed above; ideally a large run of kmeans is approximated by the above algorithm + long nanoTime = System.nanoTime(); + final KMeans.Results kMeans = + KMeans.cluster( + floatVectorValues, + desiredClusters, + false, + 42L, + KMeans.KmeansInitializationMethod.PLUS_PLUS, + initCentroids, + fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE, + 1, + 5, + desiredClusters * 64); + if (mergeState.infoStream.isEnabled(IVF_VECTOR_COMPONENT)) { + mergeState.infoStream.message( + IVF_VECTOR_COMPONENT, "KMeans time ms: " + ((System.nanoTime() - nanoTime) / 1000000.0)); + } + float[][] centroids = kMeans.centroids(); + + // write them + writeCentroids(centroids, fieldInfo, globalCentroid, temporaryCentroidOutput); + return centroids.length; + } + + @Override + long[] buildAndWritePostingsLists( + FieldInfo fieldInfo, + CentroidAssignmentScorer centroidAssignmentScorer, + FloatVectorValues floatVectorValues, + IndexOutput postingsOutput, + MergeState mergeState) + throws IOException { + IntArrayList[] clusters = new IntArrayList[centroidAssignmentScorer.size()]; + for (int i = 0; i < centroidAssignmentScorer.size(); i++) { + clusters[i] = + new IntArrayList(floatVectorValues.size() / centroidAssignmentScorer.size() / 4); + } + long nanoTime = System.nanoTime(); + assignCentroidsMerge( + centroidAssignmentScorer, floatVectorValues, mergeState, fieldInfo.name, clusters); + if (mergeState.infoStream.isEnabled(IVF_VECTOR_COMPONENT)) { 
+ mergeState.infoStream.message( + IVF_VECTOR_COMPONENT, + "assignCentroids time ms: " + ((System.nanoTime() - nanoTime) / 1000000.0)); + } + + if (mergeState.infoStream.isEnabled(IVF_VECTOR_COMPONENT)) { + printClusterQualityStatistics(clusters, mergeState.infoStream); + } + // write the posting lists + final long[] offsets = new long[centroidAssignmentScorer.size()]; + OptimizedScalarQuantizer quantizer = + new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); + BinarizedFloatVectorValues binarizedByteVectorValues = + new BinarizedFloatVectorValues(floatVectorValues, quantizer); + DocIdsWriter docIdsWriter = new DocIdsWriter(); + for (int i = 0; i < centroidAssignmentScorer.size(); i++) { + float[] centroid = centroidAssignmentScorer.centroid(i); + binarizedByteVectorValues.centroid = centroid; + // TODO: sort by distance to the centroid + IntArrayList cluster = clusters[i]; + // TODO align??? + offsets[i] = postingsOutput.getFilePointer(); + int size = cluster.size(); + postingsOutput.writeVInt(size); + postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid))); + // TODO we might want to consider putting the docIds in a separate file + // to aid with only having to fetch vectors from slower storage when they are required + // keeping them in the same file indicates we pull the entire file into cache + docIdsWriter.writeDocIds( + j -> floatVectorValues.ordToDoc(cluster.get(j)), size, postingsOutput); + writePostingList(cluster, postingsOutput, binarizedByteVectorValues); + } + return offsets; + } + + private static void printClusterQualityStatistics( + IntArrayList[] clusters, InfoStream infoStream) { + float min = Float.MAX_VALUE; + float max = Float.MIN_VALUE; + float mean = 0; + float m2 = 0; + // iteratively compute the variance & mean + int count = 0; + for (IntArrayList cluster : clusters) { + count += 1; + if (cluster == null) { + continue; + } + float delta = cluster.size() - mean; + mean += delta / count; + 
m2 += delta * (cluster.size() - mean); + min = Math.min(min, cluster.size()); + max = Math.max(max, cluster.size()); + } + float variance = m2 / (clusters.length - 1); + infoStream.message( + IVF_VECTOR_COMPONENT, + "Centroid count: " + + clusters.length + + " min: " + + min + + " max: " + + max + + " mean: " + + mean + + " stdDev: " + + Math.sqrt(variance) + + " variance: " + + variance); + } + + static void assignCentroids( + CentroidAssignmentScorer scorer, FloatVectorValues vectors, IntArrayList[] clusters) + throws IOException { + short numCentroids = (short) scorer.size(); + // If soar > 0, then we actually need to apply the projection, otherwise, its just the second + // nearest centroid + // we at most will look at the EXT_SOAR_LIMIT_CHECK_RATIO nearest centroids if possible + int soarToCheck = (int) (numCentroids * EXT_SOAR_LIMIT_CHECK_RATIO); + int soarClusterCheckCount = Math.min(numCentroids - 1, soarToCheck); + // if lambda is `0`, that just means overspill to the second nearest, so we will only check the + // second nearest + if (SOAR_LAMBDA == 0) { + soarClusterCheckCount = Math.min(1, soarClusterCheckCount); + } + NeighborQueue neighborsToCheck = new NeighborQueue(soarClusterCheckCount + 1, true); + OrdScoreIterator ordScoreIterator = new OrdScoreIterator(soarClusterCheckCount + 1); + float[] scratch = new float[vectors.dimension()]; + for (int docID = 0; docID < vectors.size(); docID++) { + float[] vector = vectors.vectorValue(docID); + scorer.setScoringVector(vector); + int bestCentroid = 0; + float bestScore = Float.MAX_VALUE; + if (numCentroids > 1) { + for (short c = 0; c < numCentroids; c++) { + float squareDist = scorer.score(c); + neighborsToCheck.insertWithOverflow(c, squareDist); + } + // pop the best + int sz = neighborsToCheck.size(); + int best = + neighborsToCheck.consumeNodesAndScoresMin( + ordScoreIterator.ords, ordScoreIterator.scores); + // TODO yikes.... 
+ ordScoreIterator.idx = sz; + bestScore = ordScoreIterator.getScore(best); + bestCentroid = ordScoreIterator.getOrd(best); + } + if (clusters[bestCentroid] == null) { + clusters[bestCentroid] = new IntArrayList(16); + } + clusters[bestCentroid].add(docID); + if (soarClusterCheckCount > 0) { + assignCentroidSOAR( + ordScoreIterator, + docID, + bestCentroid, + scorer.centroid(bestCentroid), + bestScore, + scratch, + scorer, + vectors, + clusters); + } + neighborsToCheck.clear(); + } + } + + static int prefilterCentroidAssignment( + int centroidOrd, + FloatVectorValues segmentCentroids, + CentroidAssignmentScorer scorer, + NeighborQueue neighborsToCheck, + int[] prefilteredCentroids) + throws IOException { + float[] segmentCentroid = segmentCentroids.vectorValue(centroidOrd); + scorer.setScoringVector(segmentCentroid); + neighborsToCheck.clear(); + for (short c = 0; c < scorer.size(); c++) { + float squareDist = scorer.score(c); + neighborsToCheck.insertWithOverflow(c, squareDist); + } + int size = neighborsToCheck.size(); + neighborsToCheck.consumeNodes(prefilteredCentroids); + return size; + } + + static void assignCentroidsMerge( + CentroidAssignmentScorer scorer, + FloatVectorValues vectors, + MergeState state, + String fieldName, + IntArrayList[] clusters) + throws IOException { + FixedBitSet assigned = new FixedBitSet(vectors.size() + 1); + short numCentroids = (short) scorer.size(); + // If soar > 0, then we actually need to apply the projection, otherwise, its just the second + // nearest centroid + // we at most will look at the EXT_SOAR_LIMIT_CHECK_RATIO nearest centroids if possible + int soarToCheck = (int) (numCentroids * EXT_SOAR_LIMIT_CHECK_RATIO); + int soarClusterCheckCount = Math.min(numCentroids - 1, soarToCheck); + // TODO is this the right to check? + // If cluster quality is higher, maybe we can reduce this... 
+ int prefilteredCentroidCount = + Math.max(soarClusterCheckCount + 1, numCentroids / state.knnVectorsReaders.length); + NeighborQueue prefilteredCentroidsToCheck = new NeighborQueue(prefilteredCentroidCount, true); + NeighborQueue neighborsToCheck = new NeighborQueue(soarClusterCheckCount + 1, true); + OrdScoreIterator ordScoreIterator = new OrdScoreIterator(soarClusterCheckCount + 1); + int[] prefilteredCentroids = new int[prefilteredCentroidCount]; + float[] scratch = new float[vectors.dimension()]; + // Can we do a pre-filter by finding the nearest centroids to the original vector centroids? + for (int idx = 0; idx < state.knnVectorsReaders.length; idx++) { + KnnVectorsReader reader = state.knnVectorsReaders[idx]; + IVFVectorsReader vectorsReader = getIVFReader(reader, fieldName); + // No reader, skip + if (vectorsReader == null) { + continue; + } + MergeState.DocMap docMap = state.docMaps[idx]; + var segmentCentroids = vectorsReader.getCentroids(state.fieldInfos[idx].fieldInfo(fieldName)); + for (int i = 0; i < segmentCentroids.size(); i++) { + IVFVectorsReader.CentroidInfo info = vectorsReader.centroidVectors(fieldName, i, docMap); + // Rare, but empty centroid, no point in doing comparisons + if (info.vectors().size == 0) { + continue; + } + prefilteredCentroidsToCheck.clear(); + int prefiltedCount = + prefilterCentroidAssignment( + i, segmentCentroids, scorer, prefilteredCentroidsToCheck, prefilteredCentroids); + int centroidVectorDocId = -1; + while ((centroidVectorDocId = info.vectors().nextVectorDocId()) != NO_MORE_DOCS) { + if (assigned.getAndSet(centroidVectorDocId)) { + continue; + } + neighborsToCheck.clear(); + float[] vector = info.vectors().vectorValue(); + scorer.setScoringVector(vector); + int bestCentroid; + float bestScore; + for (int c = 0; c < prefiltedCount; c++) { + float squareDist = scorer.score(prefilteredCentroids[c]); + neighborsToCheck.insertWithOverflow(prefilteredCentroids[c], squareDist); + } + int centroidCount = 
neighborsToCheck.size(); + int best = + neighborsToCheck.consumeNodesAndScoresMin( + ordScoreIterator.ords, ordScoreIterator.scores); + // yikes + ordScoreIterator.idx = centroidCount; + bestScore = ordScoreIterator.getScore(best); + bestCentroid = ordScoreIterator.getOrd(best); + if (clusters[bestCentroid] == null) { + clusters[bestCentroid] = new IntArrayList(16); + } + clusters[bestCentroid].add(info.vectors().docId()); + if (soarClusterCheckCount > 0) { + assignCentroidSOAR( + ordScoreIterator, + info.vectors().docId(), + bestCentroid, + scorer.centroid(bestCentroid), + bestScore, + scratch, + scorer, + vectors, + clusters); + } + } + } + } + + for (int vecOrd = 0; vecOrd < vectors.size(); vecOrd++) { + if (assigned.get(vecOrd)) { + continue; + } + float[] vector = vectors.vectorValue(vecOrd); + scorer.setScoringVector(vector); + int bestCentroid = 0; + float bestScore = Float.MAX_VALUE; + if (numCentroids > 1) { + for (short c = 0; c < numCentroids; c++) { + float squareDist = scorer.score(c); + neighborsToCheck.insertWithOverflow(c, squareDist); + } + int centroidCount = neighborsToCheck.size(); + int bestIdx = + neighborsToCheck.consumeNodesAndScoresMin( + ordScoreIterator.ords, ordScoreIterator.scores); + ordScoreIterator.idx = centroidCount; + bestCentroid = ordScoreIterator.getOrd(bestIdx); + bestScore = ordScoreIterator.getScore(bestIdx); + } + if (clusters[bestCentroid] == null) { + clusters[bestCentroid] = new IntArrayList(16); + } + int docID = vectors.ordToDoc(vecOrd); + clusters[bestCentroid].add(docID); + if (soarClusterCheckCount > 0) { + assignCentroidSOAR( + ordScoreIterator, + docID, + bestCentroid, + scorer.centroid(bestCentroid), + bestScore, + scratch, + scorer, + vectors, + clusters); + } + neighborsToCheck.clear(); + } + } + + static void assignCentroidSOAR( + OrdScoreIterator centroidsToCheck, + int docId, + int bestCentroidId, + float[] bestCentroid, + float bestScore, + float[] scratch, + CentroidAssignmentScorer scorer, + 
FloatVectorValues vectors, + IntArrayList[] clusters) + throws IOException { + float[] vector = vectors.vectorValue(docId); + ESVectorUtil.subtract(vector, bestCentroid, scratch); + int bestSecondaryCentroid = -1; + float minDist = Float.MAX_VALUE; + for (int i = 0; i < centroidsToCheck.size(); i++) { + float score = centroidsToCheck.getScore(i); + int centroidOrdinal = centroidsToCheck.getOrd(i); + if (centroidOrdinal == bestCentroidId) { + continue; + } + if (SOAR_LAMBDA > 0) { + float proj = ESVectorUtil.soarResidual(vector, scorer.centroid(centroidOrdinal), scratch); + score += SOAR_LAMBDA * proj * proj / bestScore; + } + if (score < minDist) { + bestSecondaryCentroid = centroidOrdinal; + minDist = score; + } + } + if (bestSecondaryCentroid != -1) { + clusters[bestSecondaryCentroid].add(docId); + } + } + + static class OrdScoreIterator { + private final int[] ords; + private final float[] scores; + private int idx = 0; + + OrdScoreIterator(int size) { + this.ords = new int[size]; + this.scores = new float[size]; + } + + void add(int ord, float score) { + ords[idx] = ord; + scores[idx] = score; + idx++; + } + + int getOrd(int idx) { + return ords[idx]; + } + + float getScore(int idx) { + return scores[idx]; + } + + void reset() { + idx = 0; + } + + int size() { + return idx; + } + } + + // TODO unify with OSQ format + static class BinarizedFloatVectorValues { + private OptimizedScalarQuantizer.QuantizationResult corrections; + private final byte[] binarized; + private final byte[] initQuantized; + private float[] centroid; + private final FloatVectorValues values; + private final OptimizedScalarQuantizer quantizer; + + private int lastOrd = -1; + + BinarizedFloatVectorValues(FloatVectorValues delegate, OptimizedScalarQuantizer quantizer) { + this.values = delegate; + this.quantizer = quantizer; + this.binarized = new byte[discretize(delegate.dimension(), 64) / 8]; + this.initQuantized = new byte[delegate.dimension()]; + } + + public 
OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int ord) { + if (ord != lastOrd) { + throw new IllegalStateException( + "attempt to retrieve corrective terms for different ord " + + ord + + " than the quantization was done for: " + + lastOrd); + } + return corrections; + } + + public byte[] vectorValue(int ord) throws IOException { + if (ord != lastOrd) { + binarize(ord); + lastOrd = ord; + } + return binarized; + } + + private void binarize(int ord) throws IOException { + corrections = + quantizer.scalarQuantize(values.vectorValue(ord), initQuantized, INDEX_BITS, centroid); + packAsBinary(initQuantized, binarized); + } + } + + static class OffHeapCentroidAssignmentScorer implements CentroidAssignmentScorer { + private final IndexInput centroidsInput; + private final int numCentroids; + private final int dimension; + private final float[] scratch; + private float[] q; + private final long centroidByteSize; + private int currOrd = -1; + + OffHeapCentroidAssignmentScorer(IndexInput centroidsInput, int numCentroids, FieldInfo info) { + this.centroidsInput = centroidsInput; + this.numCentroids = numCentroids; + this.dimension = info.getVectorDimension(); + this.scratch = new float[dimension]; + this.centroidByteSize = dimension + 3 * Float.BYTES + Short.BYTES; + } + + @Override + public int size() { + return numCentroids; + } + + @Override + public float[] centroid(int centroidOrdinal) throws IOException { + if (centroidOrdinal == currOrd) { + return scratch; + } + centroidsInput.seek( + numCentroids * centroidByteSize + (long) centroidOrdinal * dimension * Float.BYTES); + centroidsInput.readFloats(scratch, 0, dimension); + this.currOrd = centroidOrdinal; + return scratch; + } + + @Override + public void setScoringVector(float[] vector) { + q = vector; + } + + @Override + public float score(int centroidOrdinal) throws IOException { + return VectorUtil.squareDistance(centroid(centroidOrdinal), q); + } + } + + // TODO throw away rawCentroids + static class 
OnHeapCentroidAssignmentScorer implements CentroidAssignmentScorer { + private final float[][] centroids; + private float[] q; + + OnHeapCentroidAssignmentScorer(float[][] centroids) { + this.centroids = centroids; + } + + @Override + public int size() { + return centroids.length; + } + + @Override + public void setScoringVector(float[] vector) { + q = vector; + } + + @Override + public float[] centroid(int centroidOrdinal) throws IOException { + return centroids[centroidOrdinal]; + } + + @Override + public float score(int centroidOrdinal) throws IOException { + return VectorUtil.squareDistance(centroid(centroidOrdinal), q); + } + } + + static void writeQuantizedValue( + IndexOutput indexOutput, + byte[] binaryValue, + OptimizedScalarQuantizer.QuantizationResult corrections) + throws IOException { + indexOutput.writeBytes(binaryValue, binaryValue.length); + indexOutput.writeInt(Float.floatToIntBits(corrections.lowerInterval())); + indexOutput.writeInt(Float.floatToIntBits(corrections.upperInterval())); + indexOutput.writeInt(Float.floatToIntBits(corrections.additionalCorrection())); + assert corrections.quantizedComponentSum() >= 0 + && corrections.quantizedComponentSum() <= 0xffff; + indexOutput.writeShort((short) corrections.quantizedComponentSum()); + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormat.java new file mode 100644 index 0000000000000..f7cf9a7bcdba5 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormat.java @@ -0,0 +1,92 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.codec.vectors; + +import java.io.IOException; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; + +/** + * Codec format for Inverted File Vector indexes. This index expects to break the dimensional space + * into clusters and assign each vector to a cluster generating a posting list of vectors. Clusters + * are represented by centroids. + * + *

THe index is searcher by looking for the closest centroids to our vector query and then + * scoring the vectors in the posting list of the closest centroids. + */ +public class IVFVectorsFormat extends KnnVectorsFormat { + + public static final String IVF_VECTOR_COMPONENT = "IVF"; + public static final String NAME = "IVFVectorsFormat"; + // centroid ordinals -> centroid values, offsets + public static final String CENTROID_EXTENSION = "cenivf"; + // offsets contained in cen_ivf, [vector ordinals, actually just docIds](long varint), quantized + // vectors (OSQ bit) + public static final String CLUSTER_EXTENSION = "clivf"; + static final String IVF_META_EXTENSION = "mivf"; + + public static final int VERSION_START = 0; + public static final int VERSION_CURRENT = VERSION_START; + + private static final FlatVectorsFormat rawVectorFormat = + new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()); + + private static final int DEFAULT_VECTORS_PER_CLUSTER = 1000; + + private final int vectorPerCluster; + + public IVFVectorsFormat(int vectorPerCluster) { + super(NAME); + this.vectorPerCluster = vectorPerCluster; + } + + /** Constructs a format using the given graph construction parameters and scalar quantization. 
*/ + public IVFVectorsFormat() { + this(DEFAULT_VECTORS_PER_CLUSTER); + } + + @Override + public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new DefaultIVFVectorsWriter( + state, rawVectorFormat.fieldsWriter(state), vectorPerCluster); + } + + @Override + public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException { + return new DefaultIVFVectorsReader(state, rawVectorFormat.fieldsReader(state)); + } + + @Override + public int getMaxDimensions(String fieldName) { + return 1024; + } + + @Override + public String toString() { + return "IVFVectorFormat"; + } + + static IVFVectorsReader getIVFReader(KnnVectorsReader vectorsReader, String fieldName) { + if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) { + vectorsReader = candidateReader.getFieldReader(fieldName); + } + if (vectorsReader instanceof IVFVectorsReader reader) { + return reader; + } + return null; + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java new file mode 100644 index 0000000000000..d8a14f55894c4 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java @@ -0,0 +1,474 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + + +package org.elasticsearch.index.codec.vectors; + +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + +import java.io.IOException; +import java.util.function.IntPredicate; +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.index.ByteVectorValues; +import org.apache.lucene.index.CorruptIndexException; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FieldInfos; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.MergeState; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.internal.hppc.IntObjectHashMap; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.store.ChecksumIndexInput; +import org.apache.lucene.store.DataInput; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.BitSet; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.FixedBitSet; +import org.apache.lucene.util.IOUtils; +import org.apache.lucene.util.hnsw.NeighborQueue; + +/** + * @lucene.experimental + */ +public abstract class IVFVectorsReader extends KnnVectorsReader { + + private final IndexInput ivfCentroids, ivfClusters; + private final SegmentReadState state; + private final FieldInfos fieldInfos; + protected final IntObjectHashMap fields; + private final FlatVectorsReader rawVectorsReader; + + protected IVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader) + throws IOException { + this.state = state; + this.fieldInfos = state.fieldInfos; + 
this.rawVectorsReader = rawVectorsReader; + this.fields = new IntObjectHashMap<>(); + String meta = + IndexFileNames.segmentFileName( + state.segmentInfo.name, state.segmentSuffix, IVFVectorsFormat.IVF_META_EXTENSION); + + int versionMeta = -1; + boolean success = false; + try (ChecksumIndexInput ivfMeta = state.directory.openChecksumInput(meta)) { + Throwable priorE = null; + try { + versionMeta = + CodecUtil.checkIndexHeader( + ivfMeta, + IVFVectorsFormat.NAME, + IVFVectorsFormat.VERSION_START, + IVFVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix); + readFields(ivfMeta); + } catch (Throwable exception) { + priorE = exception; + } finally { + CodecUtil.checkFooter(ivfMeta, priorE); + } + ivfCentroids = + openDataInput( + state, + versionMeta, + IVFVectorsFormat.CENTROID_EXTENSION, + IVFVectorsFormat.NAME, + state.context); + ivfClusters = + openDataInput( + state, + versionMeta, + IVFVectorsFormat.CLUSTER_EXTENSION, + IVFVectorsFormat.NAME, + state.context); + success = true; + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(this); + } + } + } + + abstract CentroidQueryScorer getCentroidScorer( + FieldInfo fieldInfo, + int numCentroids, + IndexInput centroids, + float[] target, + IndexInput clusters) + throws IOException; + + protected abstract FloatVectorValues getCentroids( + IndexInput indexInput, int numCentroids, FieldInfo info) throws IOException; + + record CentroidInfo(CentroidFloatVectorValues vectors, float innerProduct) {} + + CentroidInfo centroidVectors(String fieldName, int centroidOrd, MergeState.DocMap docMap) + throws IOException { + FieldInfo info = state.fieldInfos.fieldInfo(fieldName); + FieldEntry entry = fields.get(info.number); + if (entry == null) { + return null; + } + if (entry.vectorEncoding() == VectorEncoding.BYTE) { + return null; + } + ivfClusters.seek(entry.postingListOffsets()[centroidOrd]); + int vectors = ivfClusters.readVInt(); + float innerProduct = 
Float.intBitsToFloat(ivfClusters.readInt()); + int[] vectorDocIds = new int[vectors]; + DocIdsWriter docIdsWriter = new DocIdsWriter(); + docIdsWriter.readInts(ivfClusters, vectors, vectorDocIds); + + // TODO this assumes that vectorDocIds are sorted!!! + int count = 0; + for (int i = 0; i < vectors; i++) { + int docId = vectorDocIds[i]; + if (docMap.get(docId) != -1) { + ++count; + } + } + // TODO: Do we need random access? If so, we should gather the ordinals here by + // iterating the valid docs in the docMap, keeping track of the valid ordinals, then they can + // be directly + // accessed + FloatVectorValues vectorValues = getFloatVectorValues(fieldName); + CentroidFloatVectorValues centroidFloatVectorValues = + new CentroidFloatVectorValues(vectorValues, vectorDocIds, docMap, count); + return new CentroidInfo(centroidFloatVectorValues, innerProduct); + } + + static class CentroidFloatVectorValues { + final FloatVectorValues vectorValues; + final int[] docIds; + final MergeState.DocMap docMap; + final int size; + int curOriginalDocId = -1; + int mappedDocID = -1; + KnnVectorValues.DocIndexIterator iterator; + + CentroidFloatVectorValues( + FloatVectorValues vectorValues, int[] docIds, MergeState.DocMap docMap, int size) { + this.vectorValues = vectorValues; + this.iterator = vectorValues.iterator(); + this.docIds = docIds; + this.docMap = docMap; + this.size = size; + } + + int docId() { + return mappedDocID; + } + + float[] vectorValue() throws IOException { + return vectorValues.vectorValue(iterator.index()); + } + + int nextVectorDocId() throws IOException { + while (curOriginalDocId < docIds.length - 1) { + curOriginalDocId++; + int doc = iterator.advance(docIds[curOriginalDocId]); + if (doc == NO_MORE_DOCS) { + return this.mappedDocID = NO_MORE_DOCS; + } + int mappedDoc = docMap.get(doc); + if (mappedDoc != -1) { + return this.mappedDocID = mappedDoc; + } + } + return this.mappedDocID = NO_MORE_DOCS; + } + } + + public FloatVectorValues 
getCentroids(FieldInfo fieldInfo) throws IOException { + FieldEntry entry = fields.get(fieldInfo.number); + if (entry == null) { + return null; + } + return getCentroids( + entry.centroidSlice(ivfCentroids), entry.postingListOffsets.length, fieldInfo); + } + + int centroidSize(String fieldName, int centroidOrdinal) throws IOException { + FieldInfo fieldInfo = state.fieldInfos.fieldInfo(fieldName); + FieldEntry entry = fields.get(fieldInfo.number); + ivfClusters.seek(entry.postingListOffsets[centroidOrdinal]); + return ivfClusters.readVInt(); + } + + private static IndexInput openDataInput( + SegmentReadState state, + int versionMeta, + String fileExtension, + String codecName, + IOContext context) + throws IOException { + final String fileName = + IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension); + final IndexInput in = state.directory.openInput(fileName, context); + boolean success = false; + try { + final int versionVectorData = + CodecUtil.checkIndexHeader( + in, + codecName, + IVFVectorsFormat.VERSION_START, + IVFVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix); + if (versionMeta != versionVectorData) { + throw new CorruptIndexException( + "Format versions mismatch: meta=" + + versionMeta + + ", " + + codecName + + "=" + + versionVectorData, + in); + } + CodecUtil.retrieveChecksum(in); + success = true; + return in; + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(in); + } + } + } + + private void readFields(ChecksumIndexInput meta) throws IOException { + for (int fieldNumber = meta.readInt(); fieldNumber != -1; fieldNumber = meta.readInt()) { + final FieldInfo info = fieldInfos.fieldInfo(fieldNumber); + if (info == null) { + throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta); + } + fields.put(info.number, readField(meta, info)); + } + } + + private FieldEntry readField(IndexInput input, FieldInfo info) throws IOException { + final 
VectorEncoding vectorEncoding = readVectorEncoding(input); + final VectorSimilarityFunction similarityFunction = readSimilarityFunction(input); + final long centroidOffset = input.readLong(); + final long centroidLength = input.readLong(); + final int numPostingLists = input.readVInt(); + final long[] postingListOffsets = new long[numPostingLists]; + for (int i = 0; i < numPostingLists; i++) { + postingListOffsets[i] = input.readLong(); + } + final float[] globalCentroid = new float[info.getVectorDimension()]; + float globalCentroidDp = 0; + if (numPostingLists > 0) { + input.readFloats(globalCentroid, 0, globalCentroid.length); + globalCentroidDp = Float.intBitsToFloat(input.readInt()); + } + if (similarityFunction != info.getVectorSimilarityFunction()) { + throw new IllegalStateException( + "Inconsistent vector similarity function for field=\"" + + info.name + + "\"; " + + similarityFunction + + " != " + + info.getVectorSimilarityFunction()); + } + return new FieldEntry( + similarityFunction, + vectorEncoding, + centroidOffset, + centroidLength, + postingListOffsets, + globalCentroid, + globalCentroidDp); + } + + private static VectorSimilarityFunction readSimilarityFunction(DataInput input) + throws IOException { + final int i = input.readInt(); + if (i < 0 || i >= SIMILARITY_FUNCTIONS.size()) { + throw new IllegalArgumentException("invalid distance function: " + i); + } + return SIMILARITY_FUNCTIONS.get(i); + } + + private static VectorEncoding readVectorEncoding(DataInput input) throws IOException { + final int encodingId = input.readInt(); + if (encodingId < 0 || encodingId >= VectorEncoding.values().length) { + throw new CorruptIndexException("Invalid vector encoding id: " + encodingId, input); + } + return VectorEncoding.values()[encodingId]; + } + + @Override + public final void checkIntegrity() throws IOException { + rawVectorsReader.checkIntegrity(); + CodecUtil.checksumEntireFile(ivfCentroids); + CodecUtil.checksumEntireFile(ivfClusters); + } + + 
@Override + public final FloatVectorValues getFloatVectorValues(String field) throws IOException { + return rawVectorsReader.getFloatVectorValues(field); + } + + @Override + public final ByteVectorValues getByteVectorValues(String field) throws IOException { + return rawVectorsReader.getByteVectorValues(field); + } + + protected float[] getGlobalCentroid(FieldInfo info) { + if (info == null || info.getVectorEncoding().equals(VectorEncoding.BYTE)) { + return null; + } + FieldEntry entry = fields.get(info.number); + if (entry == null) { + return null; + } + return entry.globalCentroid(); + } + + @Override + public final void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) + throws IOException { + final FieldInfo fieldInfo = state.fieldInfos.fieldInfo(field); + if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32) == false) { + rawVectorsReader.search(field, target, knnCollector, acceptDocs); + return; + } + int nProbe = -1; + if (knnCollector.getSearchStrategy() instanceof IVFKnnSearchStrategy ivfStrategy) { + nProbe = ivfStrategy.getNProbe(); + } + float percentFiltered = 1f; + if (acceptDocs instanceof BitSet bitSet) { + percentFiltered = + Math.max(0f, Math.min(1f, (float) bitSet.approximateCardinality() / bitSet.length())); + } + int numVectors = rawVectorsReader.getFloatVectorValues(field).size(); + BitSet visitedDocs = new FixedBitSet(state.segmentInfo.maxDoc() + 1); + // TODO can we make a conjunction between idSetIterator and the acceptDocs? 
+ IntPredicate needsScoring = + docId -> { + if (acceptDocs != null && acceptDocs.get(docId) == false) { + return false; + } + return visitedDocs.getAndSet(docId) == false; + }; + + FieldEntry entry = fields.get(fieldInfo.number); + CentroidQueryScorer centroidQueryScorer = + getCentroidScorer( + fieldInfo, + entry.postingListOffsets.length, + entry.centroidSlice(ivfCentroids), + target, + ivfClusters); + int centroidsToSearch = nProbe; + if (centroidsToSearch <= 0) { + centroidsToSearch = Math.max(((knnCollector.k() * 300) / 1_000), 1); + } + final NeighborQueue centroidQueue = + scorePostingLists(fieldInfo, knnCollector, centroidQueryScorer, nProbe); + PostingVisitor scorer = getPostingVisitor(fieldInfo, ivfClusters, target, needsScoring); + int centroidsVisited = 0; + long expectedDocs = 0; + long actualDocs = 0; + // initially we visit only the "centroids to search" + while (centroidQueue.size() > 0 && centroidsVisited < centroidsToSearch) { + ++centroidsVisited; + // todo do we actually need to know the score??? + int centroidOrdinal = centroidQueue.pop(); + // todo do we need direct access to the raw centroid??? 
+ expectedDocs += + scorer.resetPostingsScorer( + centroidOrdinal, centroidQueryScorer.centroid(centroidOrdinal)); + actualDocs += scorer.visit(knnCollector); + } + if (acceptDocs != null) { + float unfilteredRatioVisited = (float) expectedDocs / numVectors; + int filteredVectors = (int) Math.ceil(numVectors * percentFiltered); + float expectedScored = + Math.min(2 * filteredVectors * unfilteredRatioVisited, expectedDocs / 2f); + while (centroidQueue.size() > 0 + && (actualDocs < expectedScored || actualDocs < knnCollector.k())) { + int centroidOrdinal = centroidQueue.pop(); + scorer.resetPostingsScorer(centroidOrdinal, centroidQueryScorer.centroid(centroidOrdinal)); + actualDocs += scorer.visit(knnCollector); + } + } + } + + @Override + public final void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) + throws IOException { + final FieldInfo fieldInfo = state.fieldInfos.fieldInfo(field); + final ByteVectorValues values = rawVectorsReader.getByteVectorValues(field); + for (int i = 0; i < values.size(); i++) { + final float score = + fieldInfo.getVectorSimilarityFunction().compare(target, values.vectorValue(i)); + knnCollector.collect(values.ordToDoc(i), score); + if (knnCollector.earlyTerminated()) { + return; + } + } + } + + abstract NeighborQueue scorePostingLists( + FieldInfo fieldInfo, + KnnCollector knnCollector, + CentroidQueryScorer centroidQueryScorer, + int nProbe) + throws IOException; + + @Override + public void close() throws IOException { + IOUtils.close(rawVectorsReader, ivfCentroids, ivfClusters); + } + + protected record FieldEntry( + VectorSimilarityFunction similarityFunction, + VectorEncoding vectorEncoding, + long centroidOffset, + long centroidLength, + long[] postingListOffsets, + float[] globalCentroid, + float globalCentroidDp) { + IndexInput centroidSlice(IndexInput centroidFile) throws IOException { + return centroidFile.slice("centroids", centroidOffset, centroidLength); + } + } + + abstract PostingVisitor 
getPostingVisitor( + FieldInfo fieldInfo, IndexInput postingsLists, float[] target, IntPredicate needsScoring) + throws IOException; + + interface CentroidQueryScorer { + int size(); + + float[] centroid(int centroidOrdinal) throws IOException; + + float score(int centroidOrdinal) throws IOException; + } + + interface PostingVisitor { + // TODO maybe we can not specifically pass the centroid... + + /** returns the number of documents in the posting list */ + int resetPostingsScorer(int centroidOrdinal, float[] centroid) throws IOException; + + /** returns the number of scored documents */ + int visit(KnnCollector collector) throws IOException; + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java new file mode 100644 index 0000000000000..4011576ecd47f --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java @@ -0,0 +1,498 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ + + +package org.elasticsearch.index.codec.vectors; + +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.ArrayList; +import java.util.List; +import org.apache.lucene.codecs.CodecUtil; +import org.apache.lucene.codecs.KnnFieldVectorsWriter; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatFieldVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.IndexFileNames; +import org.apache.lucene.index.KnnVectorValues; +import org.apache.lucene.index.MergeState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.index.Sorter; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.util.IOUtils; +import org.apache.lucene.util.InfoStream; +import org.apache.lucene.util.VectorUtil; + +/** + * @lucene.experimental + */ +public abstract class IVFVectorsWriter extends KnnVectorsWriter { + + private final List fieldWriters = new ArrayList<>(); + private final IndexOutput ivfCentroids, ivfClusters; + private final IndexOutput ivfMeta; + private final FlatVectorsWriter rawVectorDelegate; + private final SegmentWriteState segmentWriteState; + + protected IVFVectorsWriter(SegmentWriteState state, FlatVectorsWriter rawVectorDelegate) + throws 
IOException { + this.segmentWriteState = state; + this.rawVectorDelegate = rawVectorDelegate; + final String metaFileName = + IndexFileNames.segmentFileName( + state.segmentInfo.name, state.segmentSuffix, IVFVectorsFormat.IVF_META_EXTENSION); + + final String ivfCentroidsFileName = + IndexFileNames.segmentFileName( + state.segmentInfo.name, state.segmentSuffix, IVFVectorsFormat.CENTROID_EXTENSION); + final String ivfClustersFileName = + IndexFileNames.segmentFileName( + state.segmentInfo.name, state.segmentSuffix, IVFVectorsFormat.CLUSTER_EXTENSION); + boolean success = false; + try { + ivfMeta = state.directory.createOutput(metaFileName, state.context); + CodecUtil.writeIndexHeader( + ivfMeta, + IVFVectorsFormat.NAME, + IVFVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix); + ivfCentroids = state.directory.createOutput(ivfCentroidsFileName, state.context); + CodecUtil.writeIndexHeader( + ivfCentroids, + IVFVectorsFormat.NAME, + IVFVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix); + ivfClusters = state.directory.createOutput(ivfClustersFileName, state.context); + CodecUtil.writeIndexHeader( + ivfClusters, + IVFVectorsFormat.NAME, + IVFVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix); + success = true; + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(this); + } + } + } + + @Override + public final KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException { + if (fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE) { + throw new IllegalArgumentException("IVF does not support cosine similarity"); + } + final FlatFieldVectorsWriter rawVectorDelegate = this.rawVectorDelegate.addField(fieldInfo); + if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { + @SuppressWarnings("unchecked") + final FlatFieldVectorsWriter floatWriter = + (FlatFieldVectorsWriter) rawVectorDelegate; + fieldWriters.add(new 
FieldWriter(fieldInfo, floatWriter)); + } + return rawVectorDelegate; + } + + protected abstract int calculateAndWriteCentroids( + FieldInfo fieldInfo, + FloatVectorValues floatVectorValues, + IndexOutput temporaryCentroidOutput, + MergeState mergeState, + float[] globalCentroid) + throws IOException; + + abstract long[] buildAndWritePostingsLists( + FieldInfo fieldInfo, + CentroidAssignmentScorer scorer, + FloatVectorValues floatVectorValues, + IndexOutput postingsOutput, + MergeState mergeState) + throws IOException; + + abstract CentroidAssignmentScorer calculateAndWriteCentroids( + FieldInfo fieldInfo, + FloatVectorValues floatVectorValues, + IndexOutput centroidOutput, + float[] globalCentroid) + throws IOException; + + abstract long[] buildAndWritePostingsLists( + FieldInfo fieldInfo, + InfoStream infoStream, + CentroidAssignmentScorer scorer, + FloatVectorValues floatVectorValues, + IndexOutput postingsOutput) + throws IOException; + + abstract CentroidAssignmentScorer createCentroidScorer( + IndexInput centroidsInput, int numCentroids, FieldInfo fieldInfo, float[] globalCentroid) + throws IOException; + + @Override + public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { + rawVectorDelegate.flush(maxDoc, sortMap); + for (FieldWriter fieldWriter : fieldWriters) { + float[] globalCentroid = new float[fieldWriter.fieldInfo.getVectorDimension()]; + ESVectorUtil.calculateCentroid(fieldWriter.delegate().getVectors(), globalCentroid); + // build a float vector values with random access + final FloatVectorValues floatVectorValues = + getFloatVectorValues(fieldWriter.fieldInfo, fieldWriter.delegate, maxDoc); + // build centroids + long centroidOffset = ivfCentroids.alignFilePointer(Float.BYTES); + final CentroidAssignmentScorer centroidAssignmentScorer = + calculateAndWriteCentroids( + fieldWriter.fieldInfo, floatVectorValues, ivfCentroids, globalCentroid); + long centroidLength = ivfCentroids.getFilePointer() - centroidOffset; + final 
long[] offsets = + buildAndWritePostingsLists( + fieldWriter.fieldInfo, + segmentWriteState.infoStream, + centroidAssignmentScorer, + floatVectorValues, + ivfClusters); + // write posting lists + writeMeta(fieldWriter.fieldInfo, centroidOffset, centroidLength, offsets, globalCentroid); + } + } + + private static FloatVectorValues getFloatVectorValues( + FieldInfo fieldInfo, FlatFieldVectorsWriter fieldVectorsWriter, int maxDoc) + throws IOException { + List vectors = fieldVectorsWriter.getVectors(); + if (vectors.size() == maxDoc) { + return FloatVectorValues.fromFloats(vectors, fieldInfo.getVectorDimension()); + } + final DocIdSetIterator iterator = fieldVectorsWriter.getDocsWithFieldSet().iterator(); + final int[] docIds = new int[vectors.size()]; + for (int i = 0; i < docIds.length; i++) { + docIds[i] = iterator.nextDoc(); + } + assert iterator.nextDoc() == NO_MORE_DOCS; + return new FloatVectorValues() { + @Override + public float[] vectorValue(int ord) { + return vectors.get(ord); + } + + @Override + public FloatVectorValues copy() { + return this; + } + + @Override + public int dimension() { + return fieldInfo.getVectorDimension(); + } + + @Override + public int size() { + return vectors.size(); + } + + @Override + public int ordToDoc(int ord) { + return docIds[ord]; + } + }; + } + + static IVFVectorsReader getIVFReader(KnnVectorsReader vectorsReader, String fieldName) { + if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) { + vectorsReader = candidateReader.getFieldReader(fieldName); + } + if (vectorsReader instanceof IVFVectorsReader reader) { + return reader; + } + return null; + } + + @Override + public final void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException { + rawVectorDelegate.mergeOneField(fieldInfo, mergeState); + if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { + final int numVectors; + String name = null; + boolean success = false; + // build a float vector values 
with random access. In order to do that we dump the vectors to + // a temporary file + // and write the docID follow by the vector + try (IndexOutput out = + mergeState.segmentInfo.dir.createTempOutput( + mergeState.segmentInfo.name, "ivf_", IOContext.DEFAULT)) { + name = out.getName(); + // TODO do this better, we shouldn't have to write to a temp file, we should be able to + // to just from the merged vector values. + numVectors = + writeFloatVectorValues( + fieldInfo, out, MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)); + success = true; + } finally { + if (success == false && name != null) { + IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, name); + } + } + float[] globalCentroid = new float[fieldInfo.getVectorDimension()]; + int vectorCount = 0; + for (var knnReaders : mergeState.knnVectorsReaders) { + IVFVectorsReader ivfReader = getIVFReader(knnReaders, fieldInfo.name); + if (ivfReader != null) { + int numVecs = ivfReader.getFloatVectorValues(fieldInfo.name).size(); + float[] readerGlobalCentroid = ivfReader.getGlobalCentroid(fieldInfo); + if (readerGlobalCentroid != null) { + vectorCount += numVecs; + for (int i = 0; i < globalCentroid.length; i++) { + globalCentroid[i] += readerGlobalCentroid[i] * numVecs; + } + } + } + } + if (vectorCount > 0) { + for (int i = 0; i < globalCentroid.length; i++) { + globalCentroid[i] /= vectorCount; + } + } + try (IndexInput in = mergeState.segmentInfo.dir.openInput(name, IOContext.DEFAULT)) { + final FloatVectorValues floatVectorValues = getFloatVectorValues(fieldInfo, in, numVectors); + success = false; + CentroidAssignmentScorer centroidAssignmentScorer; + long centroidOffset; + long centroidLength; + String centroidTempName = null; + int numCentroids; + IndexOutput centroidTemp = null; + try { + centroidTemp = + mergeState.segmentInfo.dir.createTempOutput( + mergeState.segmentInfo.name, "civf_", IOContext.DEFAULT); + centroidTempName = centroidTemp.getName(); + numCentroids = + 
calculateAndWriteCentroids( + fieldInfo, floatVectorValues, centroidTemp, mergeState, globalCentroid); + success = true; + } finally { + if (success == false && centroidTempName != null) { + IOUtils.closeWhileHandlingException(centroidTemp); + IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, centroidTempName); + } + } + try { + if (numCentroids == 0) { + centroidOffset = ivfCentroids.getFilePointer(); + writeMeta(fieldInfo, centroidOffset, 0, new long[0], null); + CodecUtil.writeFooter(centroidTemp); + IOUtils.close(centroidTemp); + return; + } + CodecUtil.writeFooter(centroidTemp); + IOUtils.close(centroidTemp); + centroidOffset = ivfCentroids.alignFilePointer(Float.BYTES); + try (IndexInput centroidInput = + mergeState.segmentInfo.dir.openInput(centroidTempName, IOContext.DEFAULT)) { + ivfCentroids.copyBytes( + centroidInput, centroidInput.length() - CodecUtil.footerLength()); + centroidLength = ivfCentroids.getFilePointer() - centroidOffset; + centroidAssignmentScorer = + createCentroidScorer(centroidInput, numCentroids, fieldInfo, globalCentroid); + assert centroidAssignmentScorer.size() == numCentroids; + // build a float vector values with random access + // build centroids + final long[] offsets = + buildAndWritePostingsLists( + fieldInfo, + centroidAssignmentScorer, + floatVectorValues, + ivfClusters, + mergeState); + // write posting lists + + // TODO handle this correctly by creating new centroid + if (vectorCount == 0 && offsets.length > 0) { + throw new IllegalStateException( + "No global centroid found for field: " + fieldInfo.name); + } + assert offsets.length == centroidAssignmentScorer.size(); + writeMeta(fieldInfo, centroidOffset, centroidLength, offsets, globalCentroid); + } + } finally { + IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, name); + IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, centroidTempName); + } + } finally { + 
IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, name); + } + } + } + + private static FloatVectorValues getFloatVectorValues( + FieldInfo fieldInfo, IndexInput randomAccessInput, int numVectors) { + final long length = (long) Float.BYTES * fieldInfo.getVectorDimension() + Integer.BYTES; + final float[] vector = new float[fieldInfo.getVectorDimension()]; + return new FloatVectorValues() { + @Override + public float[] vectorValue(int ord) throws IOException { + randomAccessInput.seek(ord * length + Integer.BYTES); + randomAccessInput.readFloats(vector, 0, vector.length); + return vector; + } + + @Override + public FloatVectorValues copy() { + return this; + } + + @Override + public int dimension() { + return fieldInfo.getVectorDimension(); + } + + @Override + public int size() { + return numVectors; + } + + @Override + public int ordToDoc(int ord) { + try { + randomAccessInput.seek(ord * length); + return randomAccessInput.readInt(); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + }; + } + + private static int writeFloatVectorValues( + FieldInfo fieldInfo, IndexOutput out, FloatVectorValues floatVectorValues) + throws IOException { + int numVectors = 0; + final ByteBuffer buffer = + ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES) + .order(ByteOrder.LITTLE_ENDIAN); + final KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator(); + for (int docV = iterator.nextDoc(); docV != NO_MORE_DOCS; docV = iterator.nextDoc()) { + numVectors++; + float[] vector = floatVectorValues.vectorValue(iterator.index()); + out.writeInt(iterator.docID()); + buffer.asFloatBuffer().put(vector); + out.writeBytes(buffer.array(), buffer.array().length); + } + return numVectors; + } + + private void writeMeta( + FieldInfo field, + long centroidOffset, + long centroidLength, + long[] offsets, + float[] globalCentroid) + throws IOException { + ivfMeta.writeInt(field.number); + 
ivfMeta.writeInt(field.getVectorEncoding().ordinal()); + ivfMeta.writeInt(distFuncToOrd(field.getVectorSimilarityFunction())); + ivfMeta.writeLong(centroidOffset); + ivfMeta.writeLong(centroidLength); + ivfMeta.writeVInt(offsets.length); + for (long offset : offsets) { + ivfMeta.writeLong(offset); + } + if (offsets.length > 0) { + final ByteBuffer buffer = + ByteBuffer.allocate(globalCentroid.length * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); + buffer.asFloatBuffer().put(globalCentroid); + ivfMeta.writeBytes(buffer.array(), buffer.array().length); + ivfMeta.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(globalCentroid, globalCentroid))); + } + } + + private static int distFuncToOrd(VectorSimilarityFunction func) { + for (int i = 0; i < SIMILARITY_FUNCTIONS.size(); i++) { + if (SIMILARITY_FUNCTIONS.get(i).equals(func)) { + return (byte) i; + } + } + throw new IllegalArgumentException("invalid distance function: " + func); + } + + @Override + public final void finish() throws IOException { + rawVectorDelegate.finish(); + if (ivfMeta != null) { + // write end of fields marker + ivfMeta.writeInt(-1); + CodecUtil.writeFooter(ivfMeta); + } + if (ivfCentroids != null) { + CodecUtil.writeFooter(ivfCentroids); + } + if (ivfClusters != null) { + CodecUtil.writeFooter(ivfClusters); + } + } + + @Override + public final void close() throws IOException { + IOUtils.close(rawVectorDelegate, ivfMeta, ivfCentroids, ivfClusters); + } + + @Override + public final long ramBytesUsed() { + return rawVectorDelegate.ramBytesUsed(); + } + + private record FieldWriter(FieldInfo fieldInfo, FlatFieldVectorsWriter delegate) {} + + interface CentroidAssignmentScorer { + CentroidAssignmentScorer EMPTY = + new CentroidAssignmentScorer() { + @Override + public int size() { + return 0; + } + + @Override + public float[] centroid(int centroidOrdinal) { + throw new IllegalStateException("No centroids"); + } + + @Override + public float score(int centroidOrdinal) { + throw new 
IllegalStateException("No centroids"); + } + + @Override + public void setScoringVector(float[] vector) { + throw new IllegalStateException("No centroids"); + } + }; + + int size(); + + float[] centroid(int centroidOrdinal) throws IOException; + + void setScoringVector(float[] vector); + + float score(int centroidOrdinal) throws IOException; + } +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/NeighborQueue.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/NeighborQueue.java new file mode 100644 index 0000000000000..ce6f8d07baeb6 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/NeighborQueue.java @@ -0,0 +1,162 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * Modifications copyright (C) 2025 Elasticsearch B.V. + */ + +package org.elasticsearch.index.codec.vectors; + +import org.apache.lucene.util.LongHeap; +import org.apache.lucene.util.NumericUtils; + +/** + * Copied from and modified from Apache Lucene. 
+ */ +class NeighborQueue { + + private enum Order { + MIN_HEAP { + @Override + long apply(long v) { + return v; + } + }, + MAX_HEAP { + @Override + long apply(long v) { + // This cannot be just `-v` since Long.MIN_VALUE doesn't have a positive counterpart. It + // needs a function that returns MAX_VALUE for MIN_VALUE and vice-versa. + return -1 - v; + } + }; + + abstract long apply(long v); + } + + private final LongHeap heap; + private final Order order; + + NeighborQueue(int initialSize, boolean maxHeap) { + this.heap = new LongHeap(initialSize); + this.order = maxHeap ? Order.MAX_HEAP : Order.MIN_HEAP; + } + + /** + * @return the number of elements in the heap + */ + public int size() { + return heap.size(); + } + + /** + * Adds a new graph arc, extending the storage as needed. + * + * @param newNode the neighbor node id + * @param newScore the score of the neighbor, relative to some other node + */ + public void add(int newNode, float newScore) { + heap.push(encode(newNode, newScore)); + } + + /** + * If the heap is not full (size is less than the initialSize provided to the constructor), adds a + * new node-and-score element. If the heap is full, compares the score against the current top + * score, and replaces the top element if newScore is better than (greater than unless the heap is + * reversed), the current top score. + * + * @param newNode the neighbor node id + * @param newScore the score of the neighbor, relative to some other node + */ + public boolean insertWithOverflow(int newNode, float newScore) { + return heap.insertWithOverflow(encode(newNode, newScore)); + } + + /** + * Encodes the node ID and its similarity score as long, preserving the Lucene tie-breaking rule + * that when two scores are equal, the smaller node ID must win. + * + *

<p>The most significant 32 bits represent the float score, encoded as a sortable int. + *

<p>The least significant 32 bits represent the node ID. + *

<p>The bits representing the node ID are complemented to guarantee the win for the smaller node + * ID. + *

<p>The AND with 0xFFFFFFFFL (a long whose least significant 32 bits are all 1) is necessary to obtain a long that + * has: + *

  • The most significant 32 bits set to 0 + *
  • The least significant 32 bits represent the node ID. + * + * @param node the node ID + * @param score the node score + * @return the encoded score, node ID + */ + private long encode(int node, float score) { + return order.apply( + (((long) NumericUtils.floatToSortableInt(score)) << 32) | (0xFFFFFFFFL & ~node)); + } + + private float decodeScore(long heapValue) { + return NumericUtils.sortableIntToFloat((int) (order.apply(heapValue) >> 32)); + } + + private int decodeNodeId(long heapValue) { + return (int) ~(order.apply(heapValue)); + } + + /** Removes the top element and returns its node id. */ + public int pop() { + return decodeNodeId(heap.pop()); + } + + public void consumeNodes(int[] dest) { + if (dest.length < size()) { + throw new IllegalArgumentException( + "Destination array is too small. Expected at least " + size() + " elements."); + } + for (int i = 0; i < size(); i++) { + dest[i] = decodeNodeId(heap.get(i + 1)); + } + } + + public int consumeNodesAndScoresMin(int[] dest, float[] scores) { + if (dest.length < size() || scores.length < size()) { + throw new IllegalArgumentException( + "Destination array is too small. 
Expected at least " + size() + " elements."); + } + float bestScore = Float.POSITIVE_INFINITY; + int bestIdx = 0; + for (int i = 0; i < size(); i++) { + long heapValue = heap.get(i + 1); + scores[i] = decodeScore(heapValue); + dest[i] = decodeNodeId(heapValue); + if (scores[i] < bestScore) { + bestScore = scores[i]; + bestIdx = i; + } + } + return bestIdx; + } + + public void clear() { + heap.clear(); + } + + @Override + public String toString() { + return "Neighbors[" + heap.size() + "]"; + } +} From c1e0744b7da5c8fdf77085c98bfe7066a876ebfe Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Tue, 29 Apr 2025 14:02:57 -0400 Subject: [PATCH 02/11] iter --- .../elasticsearch/simdvec/ESVectorUtil.java | 37 +- .../DefaultESVectorUtilSupport.java | 12 + .../vectorization/ESVectorUtilSupport.java | 3 + .../ESVectorizationProvider.java | 1 + .../MemorySegmentES91OSQVectorsScorer.java | 1 + .../PanamaESVectorUtilSupport.java | 43 + .../PanamaESVectorizationProvider.java | 1 + .../simdvec/ESVectorUtilTests.java | 16 + server/src/main/java/module-info.java | 3 +- .../vectors/DefaultIVFVectorsReader.java | 726 ++++----- .../vectors/DefaultIVFVectorsWriter.java | 1451 ++++++++--------- .../index/codec/vectors/IVFVectorsFormat.java | 127 +- .../index/codec/vectors/IVFVectorsReader.java | 676 ++++---- .../index/codec/vectors/IVFVectorsWriter.java | 875 +++++----- .../index/codec/vectors/NeighborQueue.java | 234 ++- .../org.apache.lucene.codecs.KnnVectorsFormat | 1 + .../codec/vectors/IVFVectorsFormatTests.java | 64 + 17 files changed, 2090 insertions(+), 2181 deletions(-) create mode 100644 server/src/test/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormatTests.java diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java index 9212d5c83bd6a..50b8e18c3d224 100644 --- 
a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java @@ -43,7 +43,6 @@ public class ESVectorUtil { private static final ESVectorUtilSupport IMPL = ESVectorizationProvider.getInstance().getVectorUtilSupport(); - public static ES91OSQVectorsScorer getES91OSQVectorsScorer(IndexInput input, int dimension) throws IOException { return ESVectorizationProvider.getInstance().newES91OSQVectorsScorer(input, dimension); } @@ -218,4 +217,40 @@ public static void centerAndCalculateOSQStatsDp(float[] target, float[] centroid assert stats.length == 6; IMPL.centerAndCalculateOSQStatsDp(target, centroid, centered, stats); } + + /** + * Calculates the difference between two vectors and stores the result in a third vector. + * @param v1 the first vector + * @param v2 the second vector + * @param result the result vector, must be the same length as the input vectors + */ + public static void subtract(float[] v1, float[] v2, float[] result) { + if (v1.length != v2.length) { + throw new IllegalArgumentException("vector dimensions differ: " + v1.length + "!=" + v2.length); + } + if (result.length != v1.length) { + throw new IllegalArgumentException("vector dimensions differ: " + result.length + "!=" + v1.length); + } + for (int i = 0; i < v1.length; i++) { + result[i] = v1[i] - v2[i]; + } + } + + /** + * calculates the spill-over score for a vector and a centroid, given its residual with + * its actually nearest centroid + * @param v1 the vector + * @param centroid the centroid + * @param originalResidual the residual with the actually nearest centroid + * @return the spill-over score (soar) + */ + public static float soarResidual(float[] v1, float[] centroid, float[] originalResidual) { + if (v1.length != centroid.length) { + throw new IllegalArgumentException("vector dimensions differ: " + v1.length + "!=" + centroid.length); + } + if (originalResidual.length != v1.length) { + throw new 
IllegalArgumentException("vector dimensions differ: " + originalResidual.length + "!=" + v1.length); + } + return IMPL.soarResidual(v1, centroid, originalResidual); + } } diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java index ce8fce7e68b7a..846472876f378 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java @@ -138,6 +138,18 @@ public void centerAndCalculateOSQStatsDp(float[] target, float[] centroid, float stats[5] = centroidDot; } + @Override + public float soarResidual(float[] v1, float[] centroid, float[] originalResidual) { + assert v1.length == centroid.length; + assert v1.length == originalResidual.length; + float proj = 0; + for (int i = 0; i < v1.length; i++) { + float djk = v1[i] - centroid[i]; + proj = fma(djk, originalResidual[i], proj); + } + return proj; + } + public static int ipByteBitImpl(byte[] q, byte[] d) { return ipByteBitImpl(q, d, 0); } diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java index b2615c55e64ec..8aa50e8c42805 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java @@ -28,4 +28,7 @@ public interface ESVectorUtilSupport { void centerAndCalculateOSQStatsEuclidean(float[] target, float[] centroid, float[] centered, float[] stats); void centerAndCalculateOSQStatsDp(float[] target, float[] centroid, float[] centered, float[] stats); + + float 
soarResidual(float[] v1, float[] centroid, float[] originalResidual); + } diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java index af5df094659e5..ea4180b595657 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorizationProvider.java @@ -13,6 +13,7 @@ import org.apache.lucene.util.Constants; import org.elasticsearch.logging.LogManager; import org.elasticsearch.logging.Logger; +import org.elasticsearch.simdvec.ES91OSQVectorsScorer; import java.io.IOException; import java.util.Locale; diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java index 0bf3b8da22d2c..46daa074c5e5e 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java @@ -20,6 +20,7 @@ import org.apache.lucene.store.IndexInput; import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; +import org.elasticsearch.simdvec.ES91OSQVectorsScorer; import java.io.IOException; import java.lang.foreign.MemorySegment; diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java index cf856c5322f06..1d8f59f855675 100644 --- 
a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java @@ -367,6 +367,49 @@ public float calculateOSQLoss(float[] target, float[] interval, float step, floa return (1f - lambda) * xe * xe / norm2 + lambda * e; } + @Override + public float soarResidual(float[] v1, float[] centroid, float[] originalResidual) { + assert v1.length == centroid.length; + assert v1.length == originalResidual.length; + float proj = 0; + int i = 0; + if (v1.length > 2 * FLOAT_SPECIES.length()) { + FloatVector projVec1 = FloatVector.zero(FLOAT_SPECIES); + FloatVector projVec2 = FloatVector.zero(FLOAT_SPECIES); + int unrolledLimit = FLOAT_SPECIES.loopBound(v1.length) - FLOAT_SPECIES.length(); + for (; i < unrolledLimit; i += 2 * FLOAT_SPECIES.length()) { + // one + FloatVector v1Vec0 = FloatVector.fromArray(FLOAT_SPECIES, v1, i); + FloatVector centroidVec0 = FloatVector.fromArray(FLOAT_SPECIES, centroid, i); + FloatVector originalResidualVec0 = FloatVector.fromArray(FLOAT_SPECIES, originalResidual, i); + FloatVector djkVec0 = v1Vec0.sub(centroidVec0); + projVec1 = fma(djkVec0, originalResidualVec0, projVec1); + + // two + FloatVector v1Vec1 = FloatVector.fromArray(FLOAT_SPECIES, v1, i + FLOAT_SPECIES.length()); + FloatVector centroidVec1 = FloatVector.fromArray(FLOAT_SPECIES, centroid, i + FLOAT_SPECIES.length()); + FloatVector originalResidualVec1 = FloatVector.fromArray(FLOAT_SPECIES, originalResidual, i + FLOAT_SPECIES.length()); + FloatVector djkVec1 = v1Vec1.sub(centroidVec1); + projVec2 = fma(djkVec1, originalResidualVec1, projVec2); + } + // vector tail + for (; i < FLOAT_SPECIES.loopBound(v1.length); i += FLOAT_SPECIES.length()) { + FloatVector v1Vec = FloatVector.fromArray(FLOAT_SPECIES, v1, i); + FloatVector centroidVec = FloatVector.fromArray(FLOAT_SPECIES, centroid, i); + FloatVector originalResidualVec = 
FloatVector.fromArray(FLOAT_SPECIES, originalResidual, i); + FloatVector djkVec = v1Vec.sub(centroidVec); + projVec1 = fma(djkVec, originalResidualVec, projVec1); + } + proj += projVec1.add(projVec2).reduceLanes(ADD); + } + // tail + for (; i < v1.length; i++) { + float djk = v1[i] - centroid[i]; + proj = fma(djk, originalResidual[i], proj); + } + return proj; + } + private static final VectorSpecies BYTE_SPECIES_128 = ByteVector.SPECIES_128; private static final VectorSpecies BYTE_SPECIES_256 = ByteVector.SPECIES_256; diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java index c409be4fb37d8..5ff8c19c90a56 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorizationProvider.java @@ -11,6 +11,7 @@ import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.MemorySegmentAccessInput; +import org.elasticsearch.simdvec.ES91OSQVectorsScorer; import java.io.IOException; import java.lang.foreign.MemorySegment; diff --git a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java index 0c99fad2d3d5c..abd4e3b0be045 100644 --- a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java +++ b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java @@ -268,6 +268,22 @@ public void testOsqGridPoints() { } } + public void testSoarOverspillScore() { + int size = random().nextInt(128, 512); + float deltaEps = 1e-5f * size; + var vector = new float[size]; + var centroid = new float[size]; + var preResidual = new float[size]; + for (int i = 0; i < size; ++i) { + vector[i] = 
random().nextFloat(); + centroid[i] = random().nextFloat(); + preResidual[i] = random().nextFloat(); + } + var expected = defaultedProvider.getVectorUtilSupport().soarResidual(vector, centroid, preResidual); + var result = defOrPanamaProvider.getVectorUtilSupport().soarResidual(vector, centroid, preResidual); + assertEquals(expected, result, deltaEps); + } + void testIpByteBinImpl(ToLongBiFunction ipByteBinFunc) { int iterations = atLeast(50); for (int i = 0; i < iterations; i++) { diff --git a/server/src/main/java/module-info.java b/server/src/main/java/module-info.java index 33f3d83393709..38e3be0d3b13f 100644 --- a/server/src/main/java/module-info.java +++ b/server/src/main/java/module-info.java @@ -455,7 +455,8 @@ org.elasticsearch.index.codec.vectors.es816.ES816BinaryQuantizedVectorsFormat, org.elasticsearch.index.codec.vectors.es816.ES816HnswBinaryQuantizedVectorsFormat, org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat, - org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat; + org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat, + org.elasticsearch.index.codec.vectors.IVFVectorsFormat; provides org.apache.lucene.codecs.Codec with diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java index 7524ae3558faa..f555ce2ba9113 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java @@ -9,16 +9,6 @@ package org.elasticsearch.index.codec.vectors; -import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.QUERY_BITS; -import static org.apache.lucene.index.VectorSimilarityFunction.COSINE; -import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; -import static 
org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT; -import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.discretize; -import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.transposeHalfByte; -import static org.elasticsearch.simdvec.ES91OSQVectorsScorer.BULK_SIZE; - -import java.io.IOException; -import java.util.function.IntPredicate; import org.apache.lucene.codecs.hnsw.FlatVectorsReader; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FloatVectorValues; @@ -33,398 +23,396 @@ import org.elasticsearch.simdvec.ES91OSQVectorsScorer; import org.elasticsearch.simdvec.ESVectorUtil; +import java.io.IOException; +import java.util.function.IntPredicate; + +import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.QUERY_BITS; +import static org.apache.lucene.index.VectorSimilarityFunction.COSINE; +import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; +import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT; +import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.discretize; +import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.transposeHalfByte; +import static org.elasticsearch.simdvec.ES91OSQVectorsScorer.BULK_SIZE; + /** * Default implementation of {@link IVFVectorsReader}. It scores the posting lists centroids using * brute force and then scores the top ones using the posting list. 
*/ public class DefaultIVFVectorsReader extends IVFVectorsReader { - private static final float FOUR_BIT_SCALE = 1f / ((1 << 4) - 1); - - public DefaultIVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader) - throws IOException { - super(state, rawVectorsReader); - } - - @Override - protected CentroidQueryScorer getCentroidScorer( - FieldInfo fieldInfo, - int numCentroids, - IndexInput centroids, - float[] targetQuery, - IndexInput clusters) - throws IOException { - FieldEntry fieldEntry = fields.get(fieldInfo.number); - float[] globalCentroid = fieldEntry.globalCentroid(); - float globalCentroidDp = fieldEntry.globalCentroidDp(); - OptimizedScalarQuantizer scalarQuantizer = - new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); - byte[] quantized = new byte[targetQuery.length]; - float[] targetScratch = ArrayUtil.copyArray(targetQuery); - OptimizedScalarQuantizer.QuantizationResult queryParams = - scalarQuantizer.scalarQuantize(targetScratch, quantized, (byte) 4, globalCentroid); - return new CentroidQueryScorer() { - int currentCentroid = -1; - private final byte[] quantizedCentroid = new byte[fieldInfo.getVectorDimension()]; - private final float[] centroid = new float[fieldInfo.getVectorDimension()]; - private final float[] centroidCorrectiveValues = new float[3]; - private int quantizedCentroidComponentSum; - private final long centroidByteSize = - fieldInfo.getVectorDimension() + 3 * Float.BYTES + Short.BYTES; - - @Override - public int size() { - return numCentroids; - } - - @Override - public float[] centroid(int centroidOrdinal) throws IOException { - readQuantizedCentroid(centroidOrdinal); - return centroid; - } - - private void readQuantizedCentroid(int centroidOrdinal) throws IOException { - if (centroidOrdinal == currentCentroid) { - return; - } - centroids.seek(centroidOrdinal * centroidByteSize); - quantizedCentroidComponentSum = - readQuantizedValue(centroids, quantizedCentroid, centroidCorrectiveValues); - 
centroids.seek( - numCentroids * centroidByteSize - + (long) Float.BYTES * quantizedCentroid.length * centroidOrdinal); - centroids.readFloats(centroid, 0, centroid.length); - currentCentroid = centroidOrdinal; - } - - @Override - public float score(int centroidOrdinal) throws IOException { - readQuantizedCentroid(centroidOrdinal); - return int4QuantizedScore( - quantized, - queryParams, - fieldInfo.getVectorDimension(), - quantizedCentroid, - centroidCorrectiveValues, - quantizedCentroidComponentSum, - globalCentroidDp, - fieldInfo.getVectorSimilarityFunction()); - } - }; - } - - @Override - protected FloatVectorValues getCentroids( - IndexInput indexInput, int numCentroids, FieldInfo info) { - FieldEntry entry = fields.get(info.number); - if (entry == null) { - return null; - } - return new OffHeapCentroidFloatVectorValues( - numCentroids, indexInput, info.getVectorDimension()); - } - - @Override - NeighborQueue scorePostingLists( - FieldInfo fieldInfo, - KnnCollector knnCollector, - CentroidQueryScorer centroidQueryScorer, - int nProbe) - throws IOException { - NeighborQueue neighborQueue = new NeighborQueue(centroidQueryScorer.size(), true); - // TODO Off heap scoring for quantized centroids? - for (int centroid = 0; centroid < centroidQueryScorer.size(); centroid++) { - neighborQueue.add(centroid, centroidQueryScorer.score(centroid)); - } - return neighborQueue; - } - - @Override - PostingVisitor getPostingVisitor( - FieldInfo fieldInfo, IndexInput indexInput, float[] target, IntPredicate needsScoring) - throws IOException { - FieldEntry entry = fields.get(fieldInfo.number); - return new MemorySegmentPostingsVisitor(target, indexInput, entry, fieldInfo, needsScoring); - } - - // TODO can we do this in off-heap blocks? 
- static float int4QuantizedScore( - byte[] quantizedQuery, - OptimizedScalarQuantizer.QuantizationResult queryCorrections, - int dims, - byte[] binaryCode, - float[] targetCorrections, - int targetComponentSum, - float centroidDp, - VectorSimilarityFunction similarityFunction) { - float qcDist = VectorUtil.int4DotProduct(quantizedQuery, binaryCode); - float ax = targetCorrections[0]; - // Here we assume `lx` is simply bit vectors, so the scaling isn't necessary - float lx = (targetCorrections[1] - ax) * FOUR_BIT_SCALE; - float ay = queryCorrections.lowerInterval(); - float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE; - float y1 = queryCorrections.quantizedComponentSum(); - float score = - ax * ay * dims + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * qcDist; - if (similarityFunction == EUCLIDEAN) { - score = queryCorrections.additionalCorrection() + targetCorrections[2] - 2 * score; - return Math.max(1 / (1f + score), 0); - } else { - // For cosine and max inner product, we need to apply the additional correction, which is - // assumed to be the non-centered dot-product between the vector and the centroid - score += queryCorrections.additionalCorrection() + targetCorrections[2] - centroidDp; - if (similarityFunction == MAXIMUM_INNER_PRODUCT) { - return VectorUtil.scaleMaxInnerProductScore(score); - } - return Math.max((1f + score) / 2f, 0); - } - } - - static class OffHeapCentroidFloatVectorValues extends FloatVectorValues { - private final int numCentroids; - private final IndexInput input; - private final int dimension; - private final float[] centroid; - private final long centroidByteSize; - private int ord = -1; - - OffHeapCentroidFloatVectorValues(int numCentroids, IndexInput input, int dimension) { - this.numCentroids = numCentroids; - this.input = input; - this.dimension = dimension; - this.centroid = new float[dimension]; - this.centroidByteSize = dimension + 3 * Float.BYTES + Short.BYTES; + private static final float 
FOUR_BIT_SCALE = 1f / ((1 << 4) - 1); + + public DefaultIVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader) throws IOException { + super(state, rawVectorsReader); } @Override - public float[] vectorValue(int ord) throws IOException { - if (ord < 0 || ord >= numCentroids) { - throw new IllegalArgumentException("ord must be in [0, " + numCentroids + "]"); - } - if (ord == this.ord) { - return centroid; - } - readQuantizedCentroid(ord); - return centroid; - } + CentroidQueryScorer getCentroidScorer( + FieldInfo fieldInfo, + int numCentroids, + IndexInput centroids, + float[] targetQuery, + IndexInput clusters + ) throws IOException { + FieldEntry fieldEntry = fields.get(fieldInfo.number); + float[] globalCentroid = fieldEntry.globalCentroid(); + float globalCentroidDp = fieldEntry.globalCentroidDp(); + OptimizedScalarQuantizer scalarQuantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); + byte[] quantized = new byte[targetQuery.length]; + float[] targetScratch = ArrayUtil.copyArray(targetQuery); + OptimizedScalarQuantizer.QuantizationResult queryParams = scalarQuantizer.scalarQuantize( + targetScratch, + quantized, + (byte) 4, + globalCentroid + ); + return new CentroidQueryScorer() { + int currentCentroid = -1; + private final byte[] quantizedCentroid = new byte[fieldInfo.getVectorDimension()]; + private final float[] centroid = new float[fieldInfo.getVectorDimension()]; + private final float[] centroidCorrectiveValues = new float[3]; + private int quantizedCentroidComponentSum; + private final long centroidByteSize = fieldInfo.getVectorDimension() + 3 * Float.BYTES + Short.BYTES; + + @Override + public int size() { + return numCentroids; + } - private void readQuantizedCentroid(int centroidOrdinal) throws IOException { - if (centroidOrdinal == ord) { - return; - } - input.seek( - numCentroids * centroidByteSize + (long) Float.BYTES * dimension * centroidOrdinal); - input.readFloats(centroid, 0, centroid.length); - 
ord = centroidOrdinal; + @Override + public float[] centroid(int centroidOrdinal) throws IOException { + readQuantizedCentroid(centroidOrdinal); + return centroid; + } + + private void readQuantizedCentroid(int centroidOrdinal) throws IOException { + if (centroidOrdinal == currentCentroid) { + return; + } + centroids.seek(centroidOrdinal * centroidByteSize); + quantizedCentroidComponentSum = readQuantizedValue(centroids, quantizedCentroid, centroidCorrectiveValues); + centroids.seek(numCentroids * centroidByteSize + (long) Float.BYTES * quantizedCentroid.length * centroidOrdinal); + centroids.readFloats(centroid, 0, centroid.length); + currentCentroid = centroidOrdinal; + } + + @Override + public float score(int centroidOrdinal) throws IOException { + readQuantizedCentroid(centroidOrdinal); + return int4QuantizedScore( + quantized, + queryParams, + fieldInfo.getVectorDimension(), + quantizedCentroid, + centroidCorrectiveValues, + quantizedCentroidComponentSum, + globalCentroidDp, + fieldInfo.getVectorSimilarityFunction() + ); + } + }; } @Override - public int dimension() { - return dimension; + protected FloatVectorValues getCentroids(IndexInput indexInput, int numCentroids, FieldInfo info) { + FieldEntry entry = fields.get(info.number); + if (entry == null) { + return null; + } + return new OffHeapCentroidFloatVectorValues(numCentroids, indexInput, info.getVectorDimension()); } @Override - public int size() { - return numCentroids; + NeighborQueue scorePostingLists(FieldInfo fieldInfo, KnnCollector knnCollector, CentroidQueryScorer centroidQueryScorer, int nProbe) + throws IOException { + NeighborQueue neighborQueue = new NeighborQueue(centroidQueryScorer.size(), true); + // TODO Off heap scoring for quantized centroids? 
+ for (int centroid = 0; centroid < centroidQueryScorer.size(); centroid++) { + neighborQueue.add(centroid, centroidQueryScorer.score(centroid)); + } + return neighborQueue; } @Override - public FloatVectorValues copy() throws IOException { - return new OffHeapCentroidFloatVectorValues(numCentroids, input.clone(), dimension); - } - } - - private static class MemorySegmentPostingsVisitor implements PostingVisitor { - final long quantizedByteLength; - final IndexInput indexInput; - final float[] target; - final FieldEntry entry; - final FieldInfo fieldInfo; - final IntPredicate needsScoring; - private final ES91OSQVectorsScorer osqVectorsScorer; - final float[] scores = new float[BULK_SIZE]; - final float[] correctionsLower = new float[BULK_SIZE]; - final float[] correctionsUpper = new float[BULK_SIZE]; - final int[] correctionsSum = new int[BULK_SIZE]; - final float[] correctionsAdd = new float[BULK_SIZE]; - - int[] docIdsScratch = new int[0]; - int vectors; - boolean quantized = false; - float centroidDp; - float[] centroid; - long slicePos; - OptimizedScalarQuantizer.QuantizationResult queryCorrections; - DocIdsWriter docIdsWriter = new DocIdsWriter(); - - final float[] scratch; - final byte[] quantizationScratch; - final byte[] quantizedQueryScratch; - final OptimizedScalarQuantizer quantizer; - final float[] correctiveValues = new float[3]; - final long quantizedVectorByteSize; - - MemorySegmentPostingsVisitor( - float[] target, - IndexInput indexInput, - FieldEntry entry, - FieldInfo fieldInfo, - IntPredicate needsScoring) + PostingVisitor getPostingVisitor(FieldInfo fieldInfo, IndexInput indexInput, float[] target, IntPredicate needsScoring) throws IOException { - this.target = target; - this.indexInput = indexInput; - this.entry = entry; - this.fieldInfo = fieldInfo; - this.needsScoring = needsScoring; - - scratch = new float[target.length]; - quantizationScratch = new byte[target.length]; - final int discretizedDimensions = 
discretize(fieldInfo.getVectorDimension(), 64); - quantizedQueryScratch = new byte[QUERY_BITS * discretizedDimensions / 8]; - quantizedByteLength = discretizedDimensions / 8 + (Float.BYTES * 3) + Short.BYTES; - quantizedVectorByteSize = (discretizedDimensions / 8); - quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); - osqVectorsScorer = - ESVectorUtil.getES91OSQVectorsScorer(indexInput, fieldInfo.getVectorDimension()); + FieldEntry entry = fields.get(fieldInfo.number); + return new MemorySegmentPostingsVisitor(target, indexInput, entry, fieldInfo, needsScoring); } - @Override - public int resetPostingsScorer(int centroidOrdinal, float[] centroid) throws IOException { - quantized = false; - indexInput.seek(entry.postingListOffsets()[centroidOrdinal]); - vectors = indexInput.readVInt(); - centroidDp = Float.intBitsToFloat(indexInput.readInt()); - this.centroid = centroid; - // read the doc ids - docIdsScratch = vectors > docIdsScratch.length ? new int[vectors] : docIdsScratch; - docIdsWriter.readInts(indexInput, vectors, docIdsScratch); - slicePos = indexInput.getFilePointer(); - return vectors; + // TODO can we do this in off-heap blocks? 
+ static float int4QuantizedScore( + byte[] quantizedQuery, + OptimizedScalarQuantizer.QuantizationResult queryCorrections, + int dims, + byte[] binaryCode, + float[] targetCorrections, + int targetComponentSum, + float centroidDp, + VectorSimilarityFunction similarityFunction + ) { + float qcDist = VectorUtil.int4DotProduct(quantizedQuery, binaryCode); + float ax = targetCorrections[0]; + // Here we assume `lx` is simply bit vectors, so the scaling isn't necessary + float lx = (targetCorrections[1] - ax) * FOUR_BIT_SCALE; + float ay = queryCorrections.lowerInterval(); + float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE; + float y1 = queryCorrections.quantizedComponentSum(); + float score = ax * ay * dims + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * qcDist; + if (similarityFunction == EUCLIDEAN) { + score = queryCorrections.additionalCorrection() + targetCorrections[2] - 2 * score; + return Math.max(1 / (1f + score), 0); + } else { + // For cosine and max inner product, we need to apply the additional correction, which is + // assumed to be the non-centered dot-product between the vector and the centroid + score += queryCorrections.additionalCorrection() + targetCorrections[2] - centroidDp; + if (similarityFunction == MAXIMUM_INNER_PRODUCT) { + return VectorUtil.scaleMaxInnerProductScore(score); + } + return Math.max((1f + score) / 2f, 0); + } } - void scoreIndividually(int offset) throws IOException { - // score individually, first the quantized byte chunk - for (int j = 0; j < BULK_SIZE; j++) { - int doc = docIdsScratch[j + offset]; - if (doc != -1) { - indexInput.seek( - slicePos + (offset * quantizedByteLength) + (j * quantizedVectorByteSize)); - float qcDist = osqVectorsScorer.quantizeScore(quantizedQueryScratch); - scores[j] = qcDist; + static class OffHeapCentroidFloatVectorValues extends FloatVectorValues { + private final int numCentroids; + private final IndexInput input; + private final int dimension; + private 
final float[] centroid; + private final long centroidByteSize; + private int ord = -1; + + OffHeapCentroidFloatVectorValues(int numCentroids, IndexInput input, int dimension) { + this.numCentroids = numCentroids; + this.input = input; + this.dimension = dimension; + this.centroid = new float[dimension]; + this.centroidByteSize = dimension + 3 * Float.BYTES + Short.BYTES; + } + + @Override + public float[] vectorValue(int ord) throws IOException { + if (ord < 0 || ord >= numCentroids) { + throw new IllegalArgumentException("ord must be in [0, " + numCentroids + "]"); + } + if (ord == this.ord) { + return centroid; + } + readQuantizedCentroid(ord); + return centroid; + } + + private void readQuantizedCentroid(int centroidOrdinal) throws IOException { + if (centroidOrdinal == ord) { + return; + } + input.seek(numCentroids * centroidByteSize + (long) Float.BYTES * dimension * centroidOrdinal); + input.readFloats(centroid, 0, centroid.length); + ord = centroidOrdinal; } - } - // read in all corrections - indexInput.seek( - slicePos + (offset * quantizedByteLength) + (BULK_SIZE * quantizedVectorByteSize)); - indexInput.readFloats(correctionsLower, 0, BULK_SIZE); - indexInput.readFloats(correctionsUpper, 0, BULK_SIZE); - for (int j = 0; j < BULK_SIZE; j++) { - correctionsSum[j] = Short.toUnsignedInt(indexInput.readShort()); - } - indexInput.readFloats(correctionsAdd, 0, BULK_SIZE); - // Now apply corrections - for (int j = 0; j < BULK_SIZE; j++) { - int doc = docIdsScratch[offset + j]; - if (doc != -1) { - scores[j] = - osqVectorsScorer.score( - queryCorrections, - fieldInfo.getVectorSimilarityFunction(), - centroidDp, - correctionsLower[j], - correctionsUpper[j], - correctionsSum[j], - correctionsAdd[j], - scores[j]); + + @Override + public int dimension() { + return dimension; + } + + @Override + public int size() { + return numCentroids; + } + + @Override + public FloatVectorValues copy() throws IOException { + return new OffHeapCentroidFloatVectorValues(numCentroids, 
input.clone(), dimension); } - } } - @Override - public int visit(KnnCollector knnCollector) throws IOException { - // block processing - int scoredDocs = 0; - int limit = vectors - BULK_SIZE + 1; - int i = 0; - for (; i < limit; i += BULK_SIZE) { - int docsToScore = BULK_SIZE; - for (int j = 0; j < BULK_SIZE; j++) { - int doc = docIdsScratch[i + j]; - if (needsScoring.test(doc) == false) { - docIdsScratch[i + j] = -1; - docsToScore--; - } + private static class MemorySegmentPostingsVisitor implements PostingVisitor { + final long quantizedByteLength; + final IndexInput indexInput; + final float[] target; + final FieldEntry entry; + final FieldInfo fieldInfo; + final IntPredicate needsScoring; + private final ES91OSQVectorsScorer osqVectorsScorer; + final float[] scores = new float[BULK_SIZE]; + final float[] correctionsLower = new float[BULK_SIZE]; + final float[] correctionsUpper = new float[BULK_SIZE]; + final int[] correctionsSum = new int[BULK_SIZE]; + final float[] correctionsAdd = new float[BULK_SIZE]; + + int[] docIdsScratch = new int[0]; + int vectors; + boolean quantized = false; + float centroidDp; + float[] centroid; + long slicePos; + OptimizedScalarQuantizer.QuantizationResult queryCorrections; + DocIdsWriter docIdsWriter = new DocIdsWriter(); + + final float[] scratch; + final byte[] quantizationScratch; + final byte[] quantizedQueryScratch; + final OptimizedScalarQuantizer quantizer; + final float[] correctiveValues = new float[3]; + final long quantizedVectorByteSize; + + MemorySegmentPostingsVisitor( + float[] target, + IndexInput indexInput, + FieldEntry entry, + FieldInfo fieldInfo, + IntPredicate needsScoring + ) throws IOException { + this.target = target; + this.indexInput = indexInput; + this.entry = entry; + this.fieldInfo = fieldInfo; + this.needsScoring = needsScoring; + + scratch = new float[target.length]; + quantizationScratch = new byte[target.length]; + final int discretizedDimensions = discretize(fieldInfo.getVectorDimension(), 64); 
+ quantizedQueryScratch = new byte[QUERY_BITS * discretizedDimensions / 8]; + quantizedByteLength = discretizedDimensions / 8 + (Float.BYTES * 3) + Short.BYTES; + quantizedVectorByteSize = (discretizedDimensions / 8); + quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); + osqVectorsScorer = ESVectorUtil.getES91OSQVectorsScorer(indexInput, fieldInfo.getVectorDimension()); } - if (docsToScore == 0) { - continue; + + @Override + public int resetPostingsScorer(int centroidOrdinal, float[] centroid) throws IOException { + quantized = false; + indexInput.seek(entry.postingListOffsets()[centroidOrdinal]); + vectors = indexInput.readVInt(); + centroidDp = Float.intBitsToFloat(indexInput.readInt()); + this.centroid = centroid; + // read the doc ids + docIdsScratch = vectors > docIdsScratch.length ? new int[vectors] : docIdsScratch; + docIdsWriter.readInts(indexInput, vectors, docIdsScratch); + slicePos = indexInput.getFilePointer(); + return vectors; } - quantizeQueryIfNecessary(); - indexInput.seek(slicePos + i * quantizedByteLength); - if (docsToScore < BULK_SIZE / 2) { - scoreIndividually(i); - } else { - osqVectorsScorer.scoreBulk( - quantizedQueryScratch, - queryCorrections, - fieldInfo.getVectorSimilarityFunction(), - centroidDp, - scores); + + void scoreIndividually(int offset) throws IOException { + // score individually, first the quantized byte chunk + for (int j = 0; j < BULK_SIZE; j++) { + int doc = docIdsScratch[j + offset]; + if (doc != -1) { + indexInput.seek(slicePos + (offset * quantizedByteLength) + (j * quantizedVectorByteSize)); + float qcDist = osqVectorsScorer.quantizeScore(quantizedQueryScratch); + scores[j] = qcDist; + } + } + // read in all corrections + indexInput.seek(slicePos + (offset * quantizedByteLength) + (BULK_SIZE * quantizedVectorByteSize)); + indexInput.readFloats(correctionsLower, 0, BULK_SIZE); + indexInput.readFloats(correctionsUpper, 0, BULK_SIZE); + for (int j = 0; j < BULK_SIZE; j++) { + 
correctionsSum[j] = Short.toUnsignedInt(indexInput.readShort()); + } + indexInput.readFloats(correctionsAdd, 0, BULK_SIZE); + // Now apply corrections + for (int j = 0; j < BULK_SIZE; j++) { + int doc = docIdsScratch[offset + j]; + if (doc != -1) { + scores[j] = osqVectorsScorer.score( + queryCorrections, + fieldInfo.getVectorSimilarityFunction(), + centroidDp, + correctionsLower[j], + correctionsUpper[j], + correctionsSum[j], + correctionsAdd[j], + scores[j] + ); + } + } } - for (int j = 0; j < BULK_SIZE; j++) { - int doc = docIdsScratch[i + j]; - if (doc != -1) { - scoredDocs++; - knnCollector.collect(doc, scores[j]); - } + + @Override + public int visit(KnnCollector knnCollector) throws IOException { + // block processing + int scoredDocs = 0; + int limit = vectors - BULK_SIZE + 1; + int i = 0; + for (; i < limit; i += BULK_SIZE) { + int docsToScore = BULK_SIZE; + for (int j = 0; j < BULK_SIZE; j++) { + int doc = docIdsScratch[i + j]; + if (needsScoring.test(doc) == false) { + docIdsScratch[i + j] = -1; + docsToScore--; + } + } + if (docsToScore == 0) { + continue; + } + quantizeQueryIfNecessary(); + indexInput.seek(slicePos + i * quantizedByteLength); + if (docsToScore < BULK_SIZE / 2) { + scoreIndividually(i); + } else { + osqVectorsScorer.scoreBulk( + quantizedQueryScratch, + queryCorrections, + fieldInfo.getVectorSimilarityFunction(), + centroidDp, + scores + ); + } + for (int j = 0; j < BULK_SIZE; j++) { + int doc = docIdsScratch[i + j]; + if (doc != -1) { + scoredDocs++; + knnCollector.collect(doc, scores[j]); + } + } + } + // process tail + for (; i < vectors; i++) { + int doc = docIdsScratch[i]; + if (needsScoring.test(doc)) { + quantizeQueryIfNecessary(); + indexInput.seek(slicePos + i * quantizedByteLength); + float qcDist = osqVectorsScorer.quantizeScore(quantizedQueryScratch); + indexInput.readFloats(correctiveValues, 0, 3); + final int quantizedComponentSum = Short.toUnsignedInt(indexInput.readShort()); + float score = osqVectorsScorer.score( + 
queryCorrections, + fieldInfo.getVectorSimilarityFunction(), + centroidDp, + correctiveValues[0], + correctiveValues[1], + quantizedComponentSum, + correctiveValues[2], + qcDist + ); + scoredDocs++; + knnCollector.collect(doc, score); + } + } + knnCollector.incVisitedCount(scoredDocs); + return scoredDocs; } - } - // process tail - for (; i < vectors; i++) { - int doc = docIdsScratch[i]; - if (needsScoring.test(doc)) { - quantizeQueryIfNecessary(); - indexInput.seek(slicePos + i * quantizedByteLength); - float qcDist = osqVectorsScorer.quantizeScore(quantizedQueryScratch); - indexInput.readFloats(correctiveValues, 0, 3); - final int quantizedComponentSum = Short.toUnsignedInt(indexInput.readShort()); - float score = - osqVectorsScorer.score( - queryCorrections, - fieldInfo.getVectorSimilarityFunction(), - centroidDp, - correctiveValues[0], - correctiveValues[1], - quantizedComponentSum, - correctiveValues[2], - qcDist); - scoredDocs++; - knnCollector.collect(doc, score); + + private void quantizeQueryIfNecessary() { + if (quantized == false) { + System.arraycopy(target, 0, scratch, 0, target.length); + if (fieldInfo.getVectorSimilarityFunction() == COSINE) { + VectorUtil.l2normalize(scratch); + } + queryCorrections = quantizer.scalarQuantize(scratch, quantizationScratch, (byte) 4, centroid); + transposeHalfByte(quantizationScratch, quantizedQueryScratch); + quantized = true; + } } - } - knnCollector.incVisitedCount(scoredDocs); - return scoredDocs; } - private void quantizeQueryIfNecessary() { - if (quantized == false) { - System.arraycopy(target, 0, scratch, 0, target.length); - if (fieldInfo.getVectorSimilarityFunction() == COSINE) { - VectorUtil.l2normalize(scratch); - } - queryCorrections = - quantizer.scalarQuantize(scratch, quantizationScratch, (byte) 4, centroid); - transposeHalfByte(quantizationScratch, quantizedQueryScratch); - quantized = true; - } + static int readQuantizedValue(IndexInput indexInput, byte[] binaryValue, float[] corrections) throws 
IOException { + assert corrections.length == 3; + indexInput.readBytes(binaryValue, 0, binaryValue.length); + corrections[0] = Float.intBitsToFloat(indexInput.readInt()); + corrections[1] = Float.intBitsToFloat(indexInput.readInt()); + corrections[2] = Float.intBitsToFloat(indexInput.readInt()); + return Short.toUnsignedInt(indexInput.readShort()); } - } - - static int readQuantizedValue(IndexInput indexInput, byte[] binaryValue, float[] corrections) - throws IOException { - assert corrections.length == 3; - indexInput.readBytes(binaryValue, 0, binaryValue.length); - corrections[0] = Float.intBitsToFloat(indexInput.readInt()); - corrections[1] = Float.intBitsToFloat(indexInput.readInt()); - corrections[2] = Float.intBitsToFloat(indexInput.readInt()); - return Short.toUnsignedInt(indexInput.readShort()); - } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java index 2a5b66544fbb7..45fdaa9c3dec0 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java @@ -9,18 +9,6 @@ package org.elasticsearch.index.codec.vectors; -import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.INDEX_BITS; -import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; -import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.discretize; -import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.packAsBinary; -import static org.elasticsearch.index.codec.vectors.IVFVectorsFormat.IVF_VECTOR_COMPONENT; - -import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.util.ArrayList; -import java.util.List; -import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; 
import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.FloatVectorValues; @@ -30,13 +18,23 @@ import org.apache.lucene.internal.hppc.IntArrayList; import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.IndexOutput; -import org.apache.lucene.util.FixedBitSet; import org.apache.lucene.util.InfoStream; import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; import org.elasticsearch.simdvec.ES91OSQVectorsScorer; import org.elasticsearch.simdvec.ESVectorUtil; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.ArrayList; +import java.util.List; + +import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.INDEX_BITS; +import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.discretize; +import static org.apache.lucene.util.quantization.OptimizedScalarQuantizer.packAsBinary; +import static org.elasticsearch.index.codec.vectors.IVFVectorsFormat.IVF_VECTOR_COMPONENT; + /** * Default implementation of {@link IVFVectorsWriter}. 
It uses {@link KMeans} algorithm to * partition the vector space, and then stores the centroids an posting list in a sequential @@ -44,44 +42,34 @@ */ public class DefaultIVFVectorsWriter extends IVFVectorsWriter { - static final float SOAR_LAMBDA = 1.0f; - // What percentage of the centroids do we do a second check on for SOAR assignment - static final float EXT_SOAR_LIMIT_CHECK_RATIO = 0.10f; - - private final int vectorPerCluster; - - private final OptimizedScalarQuantizer.QuantizationResult[] corrections = - new OptimizedScalarQuantizer.QuantizationResult[ES91OSQVectorsScorer.BULK_SIZE]; - - public DefaultIVFVectorsWriter( - SegmentWriteState state, FlatVectorsWriter rawVectorDelegate, int vectorPerCluster) - throws IOException { - super(state, rawVectorDelegate); - this.vectorPerCluster = vectorPerCluster; - } - - @Override - CentroidAssignmentScorer calculateAndWriteCentroids( - FieldInfo fieldInfo, - FloatVectorValues floatVectorValues, - IndexOutput centroidOutput, - float[] globalCentroid) - throws IOException { - if (floatVectorValues.size() == 0) { - return CentroidAssignmentScorer.EMPTY; - } - // calculate the centroids - int maxNumClusters = ((floatVectorValues.size() - 1) / vectorPerCluster) + 1; - int desiredClusters = - (int) - Math.max( - maxNumClusters / 16.0, - Math.max(Math.sqrt(floatVectorValues.size()), maxNumClusters)); - if (floatVectorValues.size() / desiredClusters > vectorPerCluster) { - desiredClusters = ((floatVectorValues.size() - 1) / vectorPerCluster) + 1; + static final float SOAR_LAMBDA = 1.0f; + // What percentage of the centroids do we do a second check on for SOAR assignment + static final float EXT_SOAR_LIMIT_CHECK_RATIO = 0.10f; + + private final int vectorPerCluster; + + public DefaultIVFVectorsWriter(SegmentWriteState state, FlatVectorsWriter rawVectorDelegate, int vectorPerCluster) throws IOException { + super(state, rawVectorDelegate); + this.vectorPerCluster = vectorPerCluster; } - final KMeans.Results kMeans = - 
KMeans.cluster( + + @Override + CentroidAssignmentScorer calculateAndWriteCentroids( + FieldInfo fieldInfo, + FloatVectorValues floatVectorValues, + IndexOutput centroidOutput, + float[] globalCentroid + ) throws IOException { + if (floatVectorValues.size() == 0) { + return CentroidAssignmentScorer.EMPTY; + } + // calculate the centroids + int maxNumClusters = ((floatVectorValues.size() - 1) / vectorPerCluster) + 1; + int desiredClusters = (int) Math.max(maxNumClusters / 16.0, Math.max(Math.sqrt(floatVectorValues.size()), maxNumClusters)); + if (floatVectorValues.size() / desiredClusters > vectorPerCluster) { + desiredClusters = ((floatVectorValues.size() - 1) / vectorPerCluster) + 1; + } + final KMeans.Results kMeans = KMeans.cluster( floatVectorValues, desiredClusters, false, @@ -91,291 +79,282 @@ CentroidAssignmentScorer calculateAndWriteCentroids( fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE, 1, 15, - desiredClusters * 256); - float[][] centroids = kMeans.centroids(); - // write them - writeCentroids(centroids, fieldInfo, globalCentroid, centroidOutput); - return new OnHeapCentroidAssignmentScorer(centroids); - } - - @Override - long[] buildAndWritePostingsLists( - FieldInfo fieldInfo, - InfoStream infoStream, - CentroidAssignmentScorer randomCentroidScorer, - FloatVectorValues floatVectorValues, - IndexOutput postingsOutput) - throws IOException { - IntArrayList[] clusters = new IntArrayList[randomCentroidScorer.size()]; - for (int i = 0; i < randomCentroidScorer.size(); i++) { - clusters[i] = new IntArrayList(floatVectorValues.size() / randomCentroidScorer.size() / 4); - } - assignCentroids(randomCentroidScorer, floatVectorValues, clusters); - if (infoStream.isEnabled(IVF_VECTOR_COMPONENT)) { - printClusterQualityStatistics(clusters, infoStream); - } - // write the posting lists - final long[] offsets = new long[randomCentroidScorer.size()]; - OptimizedScalarQuantizer quantizer = - new 
OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); - BinarizedFloatVectorValues binarizedByteVectorValues = - new BinarizedFloatVectorValues(floatVectorValues, quantizer); - DocIdsWriter docIdsWriter = new DocIdsWriter(); - for (int i = 0; i < randomCentroidScorer.size(); i++) { - float[] centroid = randomCentroidScorer.centroid(i); - binarizedByteVectorValues.centroid = centroid; - // TODO sort by distance to the centroid - IntArrayList cluster = clusters[i]; - // TODO align??? - offsets[i] = postingsOutput.getFilePointer(); - int size = cluster.size(); - postingsOutput.writeVInt(size); - postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid))); - // TODO we might want to consider putting the docIds in a separate file - // to aid with only having to fetch vectors from slower storage when they are required - // keeping them in the same file indicates we pull the entire file into cache - docIdsWriter.writeDocIds( - j -> floatVectorValues.ordToDoc(cluster.get(j)), cluster.size(), postingsOutput); - writePostingList(cluster, postingsOutput, binarizedByteVectorValues); - } - return offsets; - } - - private void writePostingList( - IntArrayList cluster, - IndexOutput postingsOutput, - BinarizedFloatVectorValues binarizedByteVectorValues) - throws IOException { - int limit = cluster.size() - ES91OSQVectorsScorer.BULK_SIZE + 1; - int cidx = 0; - // Write vectors in bulks of ES91OSQVectorsScorer.BULK_SIZE. 
- for (; cidx < limit; cidx += ES91OSQVectorsScorer.BULK_SIZE) { - for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) { - int ord = cluster.get(cidx + j); - byte[] binaryValue = binarizedByteVectorValues.vectorValue(ord); - // write vector - postingsOutput.writeBytes(binaryValue, 0, binaryValue.length); - corrections[j] = binarizedByteVectorValues.getCorrectiveTerms(ord); - } - // write corrections - for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) { - postingsOutput.writeInt(Float.floatToIntBits(corrections[j].lowerInterval())); - } - for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) { - postingsOutput.writeInt(Float.floatToIntBits(corrections[j].upperInterval())); - } - for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) { - int targetComponentSum = corrections[j].quantizedComponentSum(); - assert targetComponentSum >= 0 && targetComponentSum <= 0xffff; - postingsOutput.writeShort((short) targetComponentSum); - } - for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) { - postingsOutput.writeInt(Float.floatToIntBits(corrections[j].additionalCorrection())); - } - } - // write tail - for (; cidx < cluster.size(); cidx++) { - int ord = cluster.get(cidx); - // write vector - byte[] binaryValue = binarizedByteVectorValues.vectorValue(ord); - OptimizedScalarQuantizer.QuantizationResult corrections = - binarizedByteVectorValues.getCorrectiveTerms(ord); - writeQuantizedValue(postingsOutput, binaryValue, corrections); - binarizedByteVectorValues.getCorrectiveTerms(ord); - postingsOutput.writeBytes(binaryValue, 0, binaryValue.length); - postingsOutput.writeInt(Float.floatToIntBits(corrections.lowerInterval())); - postingsOutput.writeInt(Float.floatToIntBits(corrections.upperInterval())); - postingsOutput.writeInt(Float.floatToIntBits(corrections.additionalCorrection())); - assert corrections.quantizedComponentSum() >= 0 - && corrections.quantizedComponentSum() <= 0xffff; - postingsOutput.writeShort((short) corrections.quantizedComponentSum()); 
- } - } - - @Override - CentroidAssignmentScorer createCentroidScorer( - IndexInput centroidsInput, int numCentroids, FieldInfo fieldInfo, float[] globalCentroid) - throws IOException { - return new OffHeapCentroidAssignmentScorer(centroidsInput, numCentroids, fieldInfo); - } - - static void writeCentroids( - float[][] centroids, FieldInfo fieldInfo, float[] globalCentroid, IndexOutput centroidOutput) - throws IOException { - final OptimizedScalarQuantizer osq = - new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); - byte[] quantizedScratch = new byte[fieldInfo.getVectorDimension()]; - float[] centroidScratch = new float[fieldInfo.getVectorDimension()]; - // TODO do we want to store these distances as well for future use? - float[] distances = new float[centroids.length]; - for (int i = 0; i < centroids.length; i++) { - distances[i] = VectorUtil.squareDistance(centroids[i], globalCentroid); - } - // sort the centroids by distance to globalCentroid, nearest (smallest distance), to furthest - // (largest) - for (int i = 0; i < centroids.length; i++) { - for (int j = i + 1; j < centroids.length; j++) { - if (distances[i] > distances[j]) { - float[] tmp = centroids[i]; - centroids[i] = centroids[j]; - centroids[j] = tmp; - float tmpDistance = distances[i]; - distances[i] = distances[j]; - distances[j] = tmpDistance; - } - } - } - for (float[] centroid : centroids) { - System.arraycopy(centroid, 0, centroidScratch, 0, centroid.length); - OptimizedScalarQuantizer.QuantizationResult result = - osq.scalarQuantize(centroidScratch, quantizedScratch, (byte) 4, globalCentroid); - writeQuantizedValue(centroidOutput, quantizedScratch, result); - } - final ByteBuffer buffer = - ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES) - .order(ByteOrder.LITTLE_ENDIAN); - for (float[] centroid : centroids) { - buffer.asFloatBuffer().put(centroid); - centroidOutput.writeBytes(buffer.array(), buffer.array().length); - } - } - - record SegmentCentroid(int 
segment, int centroid, int centroidSize) {} - - @Override - protected int calculateAndWriteCentroids( - FieldInfo fieldInfo, - FloatVectorValues floatVectorValues, - IndexOutput temporaryCentroidOutput, - MergeState mergeState, - float[] globalCentroid) - throws IOException { - if (floatVectorValues.size() == 0) { - return 0; - } - int desiredClusters = ((floatVectorValues.size() - 1) / vectorPerCluster) + 1; - // init centroids from merge state - List centroidList = new ArrayList<>(); - List segmentCentroids = new ArrayList<>(desiredClusters); - - int segmentIdx = 0; - long startTime = System.nanoTime(); - for (var reader : mergeState.knnVectorsReaders) { - IVFVectorsReader ivfVectorsReader = IVFVectorsFormat.getIVFReader(reader, fieldInfo.name); - if (ivfVectorsReader == null) { - continue; - } - - FloatVectorValues centroid = ivfVectorsReader.getCentroids(fieldInfo); - centroidList.add(centroid); - for (int i = 0; i < centroid.size(); i++) { - int size = ivfVectorsReader.centroidSize(fieldInfo.name, i); - segmentCentroids.add(new SegmentCentroid(segmentIdx, i, size)); - } - segmentIdx++; + desiredClusters * 256 + ); + float[][] centroids = kMeans.centroids(); + // write them + writeCentroids(centroids, fieldInfo, globalCentroid, centroidOutput); + return new OnHeapCentroidAssignmentScorer(centroids); } - // sort centroid list by floatvector size - FloatVectorValues baseSegment = centroidList.get(0); - for (var l : centroidList) { - if (l.size() > baseSegment.size()) { - baseSegment = l; - } - } - float[] scratch = new float[fieldInfo.getVectorDimension()]; - float minimumDistance = Float.MAX_VALUE; - for (int j = 0; j < baseSegment.size(); j++) { - System.arraycopy(baseSegment.vectorValue(j), 0, scratch, 0, baseSegment.dimension()); - for (int k = j + 1; k < baseSegment.size(); k++) { - float d = VectorUtil.squareDistance(scratch, baseSegment.vectorValue(k)); - if (d < minimumDistance) { - minimumDistance = d; - } - } + @Override + long[] 
buildAndWritePostingsLists( + FieldInfo fieldInfo, + InfoStream infoStream, + CentroidAssignmentScorer randomCentroidScorer, + FloatVectorValues floatVectorValues, + IndexOutput postingsOutput + ) throws IOException { + IntArrayList[] clusters = new IntArrayList[randomCentroidScorer.size()]; + for (int i = 0; i < randomCentroidScorer.size(); i++) { + clusters[i] = new IntArrayList(floatVectorValues.size() / randomCentroidScorer.size() / 4); + } + assignCentroids(randomCentroidScorer, floatVectorValues, clusters); + if (infoStream.isEnabled(IVF_VECTOR_COMPONENT)) { + printClusterQualityStatistics(clusters, infoStream); + } + // write the posting lists + final long[] offsets = new long[randomCentroidScorer.size()]; + OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); + BinarizedFloatVectorValues binarizedByteVectorValues = new BinarizedFloatVectorValues(floatVectorValues, quantizer); + DocIdsWriter docIdsWriter = new DocIdsWriter(); + for (int i = 0; i < randomCentroidScorer.size(); i++) { + float[] centroid = randomCentroidScorer.centroid(i); + binarizedByteVectorValues.centroid = centroid; + // TODO sort by distance to the centroid + IntArrayList cluster = clusters[i]; + // TODO align??? 
+ offsets[i] = postingsOutput.getFilePointer(); + int size = cluster.size(); + postingsOutput.writeVInt(size); + postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid))); + // TODO we might want to consider putting the docIds in a separate file + // to aid with only having to fetch vectors from slower storage when they are required + // keeping them in the same file indicates we pull the entire file into cache + docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster.get(j)), cluster.size(), postingsOutput); + writePostingList(cluster, postingsOutput, binarizedByteVectorValues); + } + return offsets; + } + + private void writePostingList(IntArrayList cluster, IndexOutput postingsOutput, BinarizedFloatVectorValues binarizedByteVectorValues) + throws IOException { + int limit = cluster.size() - ES91OSQVectorsScorer.BULK_SIZE + 1; + int cidx = 0; + OptimizedScalarQuantizer.QuantizationResult[] corrections = + new OptimizedScalarQuantizer.QuantizationResult[ES91OSQVectorsScorer.BULK_SIZE]; + // Write vectors in bulks of ES91OSQVectorsScorer.BULK_SIZE. 
+ for (; cidx < limit; cidx += ES91OSQVectorsScorer.BULK_SIZE) { + for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) { + int ord = cluster.get(cidx + j); + byte[] binaryValue = binarizedByteVectorValues.vectorValue(ord); + // write vector + postingsOutput.writeBytes(binaryValue, 0, binaryValue.length); + corrections[j] = binarizedByteVectorValues.getCorrectiveTerms(ord); + } + // write corrections + for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) { + postingsOutput.writeInt(Float.floatToIntBits(corrections[j].lowerInterval())); + } + for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) { + postingsOutput.writeInt(Float.floatToIntBits(corrections[j].upperInterval())); + } + for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) { + int targetComponentSum = corrections[j].quantizedComponentSum(); + assert targetComponentSum >= 0 && targetComponentSum <= 0xffff; + postingsOutput.writeShort((short) targetComponentSum); + } + for (int j = 0; j < ES91OSQVectorsScorer.BULK_SIZE; j++) { + postingsOutput.writeInt(Float.floatToIntBits(corrections[j].additionalCorrection())); + } + } + // write tail + for (; cidx < cluster.size(); cidx++) { + int ord = cluster.get(cidx); + // write vector + byte[] binaryValue = binarizedByteVectorValues.vectorValue(ord); + OptimizedScalarQuantizer.QuantizationResult correction = binarizedByteVectorValues.getCorrectiveTerms(ord); + writeQuantizedValue(postingsOutput, binaryValue, correction); + binarizedByteVectorValues.getCorrectiveTerms(ord); + postingsOutput.writeBytes(binaryValue, 0, binaryValue.length); + postingsOutput.writeInt(Float.floatToIntBits(correction.lowerInterval())); + postingsOutput.writeInt(Float.floatToIntBits(correction.upperInterval())); + postingsOutput.writeInt(Float.floatToIntBits(correction.additionalCorrection())); + assert correction.quantizedComponentSum() >= 0 && correction.quantizedComponentSum() <= 0xffff; + postingsOutput.writeShort((short) correction.quantizedComponentSum()); + } } - if 
(mergeState.infoStream.isEnabled(IVF_VECTOR_COMPONENT)) { - mergeState.infoStream.message( - IVF_VECTOR_COMPONENT, - "Agglomerative cluster min distance: " - + minimumDistance - + " From biggest segment: " - + baseSegment.size()); + + @Override + CentroidAssignmentScorer createCentroidScorer( + IndexInput centroidsInput, + int numCentroids, + FieldInfo fieldInfo, + float[] globalCentroid + ) { + return new OffHeapCentroidAssignmentScorer(centroidsInput, numCentroids, fieldInfo); + } + + static void writeCentroids(float[][] centroids, FieldInfo fieldInfo, float[] globalCentroid, IndexOutput centroidOutput) + throws IOException { + final OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); + byte[] quantizedScratch = new byte[fieldInfo.getVectorDimension()]; + float[] centroidScratch = new float[fieldInfo.getVectorDimension()]; + // TODO do we want to store these distances as well for future use? + float[] distances = new float[centroids.length]; + for (int i = 0; i < centroids.length; i++) { + distances[i] = VectorUtil.squareDistance(centroids[i], globalCentroid); + } + // sort the centroids by distance to globalCentroid, nearest (smallest distance), to furthest + // (largest) + for (int i = 0; i < centroids.length; i++) { + for (int j = i + 1; j < centroids.length; j++) { + if (distances[i] > distances[j]) { + float[] tmp = centroids[i]; + centroids[i] = centroids[j]; + centroids[j] = tmp; + float tmpDistance = distances[i]; + distances[i] = distances[j]; + distances[j] = tmpDistance; + } + } + } + for (float[] centroid : centroids) { + System.arraycopy(centroid, 0, centroidScratch, 0, centroid.length); + OptimizedScalarQuantizer.QuantizationResult result = osq.scalarQuantize( + centroidScratch, + quantizedScratch, + (byte) 4, + globalCentroid + ); + writeQuantizedValue(centroidOutput, quantizedScratch, result); + } + final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * 
Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); + for (float[] centroid : centroids) { + buffer.asFloatBuffer().put(centroid); + centroidOutput.writeBytes(buffer.array(), buffer.array().length); + } } - int[] labels = new int[segmentCentroids.size()]; - // loop over segments - int clusterIdx = 0; - // keep track of all inter-centroid distances, - // using less than centroid * centroid space (e.g. not keeping track of duplicates) - for (int i = 0; i < segmentCentroids.size(); i++) { - if (labels[i] == 0) { - clusterIdx += 1; - labels[i] = clusterIdx; - } - SegmentCentroid segmentCentroid = segmentCentroids.get(i); - System.arraycopy( - centroidList.get(segmentCentroid.segment()).vectorValue(segmentCentroid.centroid), - 0, - scratch, - 0, - baseSegment.dimension()); - for (int j = i + 1; j < segmentCentroids.size(); j++) { - float d = - VectorUtil.squareDistance( + + record SegmentCentroid(int segment, int centroid, int centroidSize) {} + + @Override + protected int calculateAndWriteCentroids( + FieldInfo fieldInfo, + FloatVectorValues floatVectorValues, + IndexOutput temporaryCentroidOutput, + MergeState mergeState, + float[] globalCentroid + ) throws IOException { + if (floatVectorValues.size() == 0) { + return 0; + } + int desiredClusters = ((floatVectorValues.size() - 1) / vectorPerCluster) + 1; + // init centroids from merge state + List centroidList = new ArrayList<>(); + List segmentCentroids = new ArrayList<>(desiredClusters); + + int segmentIdx = 0; + long startTime = System.nanoTime(); + for (var reader : mergeState.knnVectorsReaders) { + IVFVectorsReader ivfVectorsReader = IVFVectorsFormat.getIVFReader(reader, fieldInfo.name); + if (ivfVectorsReader == null) { + continue; + } + + FloatVectorValues centroid = ivfVectorsReader.getCentroids(fieldInfo); + if (centroid == null) { + continue; + } + centroidList.add(centroid); + for (int i = 0; i < centroid.size(); i++) { + int size = ivfVectorsReader.centroidSize(fieldInfo.name, i); + segmentCentroids.add(new 
SegmentCentroid(segmentIdx, i, size)); + } + segmentIdx++; + } + + // sort centroid list by floatvector size + FloatVectorValues baseSegment = centroidList.get(0); + for (var l : centroidList) { + if (l.size() > baseSegment.size()) { + baseSegment = l; + } + } + float[] scratch = new float[fieldInfo.getVectorDimension()]; + float minimumDistance = Float.MAX_VALUE; + for (int j = 0; j < baseSegment.size(); j++) { + System.arraycopy(baseSegment.vectorValue(j), 0, scratch, 0, baseSegment.dimension()); + for (int k = j + 1; k < baseSegment.size(); k++) { + float d = VectorUtil.squareDistance(scratch, baseSegment.vectorValue(k)); + if (d < minimumDistance) { + minimumDistance = d; + } + } + } + if (mergeState.infoStream.isEnabled(IVF_VECTOR_COMPONENT)) { + mergeState.infoStream.message( + IVF_VECTOR_COMPONENT, + "Agglomerative cluster min distance: " + minimumDistance + " From biggest segment: " + baseSegment.size() + ); + } + int[] labels = new int[segmentCentroids.size()]; + // loop over segments + int clusterIdx = 0; + // keep track of all inter-centroid distances, + // using less than centroid * centroid space (e.g. 
not keeping track of duplicates) + for (int i = 0; i < segmentCentroids.size(); i++) { + if (labels[i] == 0) { + clusterIdx += 1; + labels[i] = clusterIdx; + } + SegmentCentroid segmentCentroid = segmentCentroids.get(i); + System.arraycopy( + centroidList.get(segmentCentroid.segment()).vectorValue(segmentCentroid.centroid), + 0, scratch, - centroidList - .get(segmentCentroids.get(j).segment()) - .vectorValue(segmentCentroids.get(j).centroid())); - if (d < minimumDistance / 2) { - if (labels[j] == 0) { - labels[j] = labels[i]; - } else { - for (int k = 0; k < labels.length; k++) { - if (labels[k] == labels[j]) { - labels[k] = labels[i]; - } + 0, + baseSegment.dimension() + ); + for (int j = i + 1; j < segmentCentroids.size(); j++) { + float d = VectorUtil.squareDistance( + scratch, + centroidList.get(segmentCentroids.get(j).segment()).vectorValue(segmentCentroids.get(j).centroid()) + ); + if (d < minimumDistance / 2) { + if (labels[j] == 0) { + labels[j] = labels[i]; + } else { + for (int k = 0; k < labels.length; k++) { + if (labels[k] == labels[j]) { + labels[k] = labels[i]; + } + } + } + } } - } } - } - } - float[][] initCentroids = new float[clusterIdx][fieldInfo.getVectorDimension()]; - int[] sum = new int[clusterIdx]; - for (int i = 0; i < segmentCentroids.size(); i++) { - SegmentCentroid segmentCentroid = segmentCentroids.get(i); - int label = labels[i]; - FloatVectorValues segment = centroidList.get(segmentCentroid.segment()); - float[] vector = segment.vectorValue(segmentCentroid.centroid); - for (int j = 0; j < vector.length; j++) { - initCentroids[label - 1][j] += (vector[j] * segmentCentroid.centroidSize); - } - sum[label - 1] += segmentCentroid.centroidSize; - } - for (int i = 0; i < initCentroids.length; i++) { - for (int j = 0; j < initCentroids[i].length; j++) { - initCentroids[i][j] /= sum[i]; - } - } - if (mergeState.infoStream.isEnabled(IVF_VECTOR_COMPONENT)) { - mergeState.infoStream.message( - IVF_VECTOR_COMPONENT, - "Agglomerative cluster time 
ms: " + ((System.nanoTime() - startTime) / 1000000.0)); - mergeState.infoStream.message( - IVF_VECTOR_COMPONENT, - "Gathered initCentroids:" + initCentroids.length + " for desired: " + desiredClusters); - } + float[][] initCentroids = new float[clusterIdx][fieldInfo.getVectorDimension()]; + int[] sum = new int[clusterIdx]; + for (int i = 0; i < segmentCentroids.size(); i++) { + SegmentCentroid segmentCentroid = segmentCentroids.get(i); + int label = labels[i]; + FloatVectorValues segment = centroidList.get(segmentCentroid.segment()); + float[] vector = segment.vectorValue(segmentCentroid.centroid); + for (int j = 0; j < vector.length; j++) { + initCentroids[label - 1][j] += (vector[j] * segmentCentroid.centroidSize); + } + sum[label - 1] += segmentCentroid.centroidSize; + } + for (int i = 0; i < initCentroids.length; i++) { + for (int j = 0; j < initCentroids[i].length; j++) { + initCentroids[i][j] /= sum[i]; + } + } + if (mergeState.infoStream.isEnabled(IVF_VECTOR_COMPONENT)) { + mergeState.infoStream.message( + IVF_VECTOR_COMPONENT, + "Agglomerative cluster time ms: " + ((System.nanoTime() - startTime) / 1000000.0) + ); + mergeState.infoStream.message( + IVF_VECTOR_COMPONENT, + "Gathered initCentroids:" + initCentroids.length + " for desired: " + desiredClusters + ); + } - // FIXME: still split to get to desired cluster count? - // FIXME: need a way to maintain the original mapping ... update KMeans to allow maintaining - // that mapping - // FIXME: go update the assignCentroids code to respect that mapping from prior centroid to next - // centroid (via the scorer?) 
- // FIXME: run a custom version of kmeans that adjusts the centroids that were split related to - // only the sets of vectors that were previously associated with the prior centroids - // FIXME: compare this kmeans outcome with a lot of iterations with the outcome of the process - // detailed above; ideally a large run of kmeans is approximated by the above algorithm - long nanoTime = System.nanoTime(); - final KMeans.Results kMeans = - KMeans.cluster( + // FIXME: run a custom version of KMeans that is just better... + long nanoTime = System.nanoTime(); + final KMeans.Results kMeans = KMeans.cluster( floatVectorValues, desiredClusters, false, @@ -385,507 +364,383 @@ protected int calculateAndWriteCentroids( fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE, 1, 5, - desiredClusters * 64); - if (mergeState.infoStream.isEnabled(IVF_VECTOR_COMPONENT)) { - mergeState.infoStream.message( - IVF_VECTOR_COMPONENT, "KMeans time ms: " + ((System.nanoTime() - nanoTime) / 1000000.0)); - } - float[][] centroids = kMeans.centroids(); - - // write them - writeCentroids(centroids, fieldInfo, globalCentroid, temporaryCentroidOutput); - return centroids.length; - } - - @Override - long[] buildAndWritePostingsLists( - FieldInfo fieldInfo, - CentroidAssignmentScorer centroidAssignmentScorer, - FloatVectorValues floatVectorValues, - IndexOutput postingsOutput, - MergeState mergeState) - throws IOException { - IntArrayList[] clusters = new IntArrayList[centroidAssignmentScorer.size()]; - for (int i = 0; i < centroidAssignmentScorer.size(); i++) { - clusters[i] = - new IntArrayList(floatVectorValues.size() / centroidAssignmentScorer.size() / 4); - } - long nanoTime = System.nanoTime(); - assignCentroidsMerge( - centroidAssignmentScorer, floatVectorValues, mergeState, fieldInfo.name, clusters); - if (mergeState.infoStream.isEnabled(IVF_VECTOR_COMPONENT)) { - mergeState.infoStream.message( - IVF_VECTOR_COMPONENT, - "assignCentroids time ms: " + ((System.nanoTime() 
- nanoTime) / 1000000.0)); - } + desiredClusters * 64 + ); + if (mergeState.infoStream.isEnabled(IVF_VECTOR_COMPONENT)) { + mergeState.infoStream.message(IVF_VECTOR_COMPONENT, "KMeans time ms: " + ((System.nanoTime() - nanoTime) / 1000000.0)); + } + float[][] centroids = kMeans.centroids(); - if (mergeState.infoStream.isEnabled(IVF_VECTOR_COMPONENT)) { - printClusterQualityStatistics(clusters, mergeState.infoStream); - } - // write the posting lists - final long[] offsets = new long[centroidAssignmentScorer.size()]; - OptimizedScalarQuantizer quantizer = - new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); - BinarizedFloatVectorValues binarizedByteVectorValues = - new BinarizedFloatVectorValues(floatVectorValues, quantizer); - DocIdsWriter docIdsWriter = new DocIdsWriter(); - for (int i = 0; i < centroidAssignmentScorer.size(); i++) { - float[] centroid = centroidAssignmentScorer.centroid(i); - binarizedByteVectorValues.centroid = centroid; - // TODO: sort by distance to the centroid - IntArrayList cluster = clusters[i]; - // TODO align??? 
- offsets[i] = postingsOutput.getFilePointer(); - int size = cluster.size(); - postingsOutput.writeVInt(size); - postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid))); - // TODO we might want to consider putting the docIds in a separate file - // to aid with only having to fetch vectors from slower storage when they are required - // keeping them in the same file indicates we pull the entire file into cache - docIdsWriter.writeDocIds( - j -> floatVectorValues.ordToDoc(cluster.get(j)), size, postingsOutput); - writePostingList(cluster, postingsOutput, binarizedByteVectorValues); - } - return offsets; - } - - private static void printClusterQualityStatistics( - IntArrayList[] clusters, InfoStream infoStream) { - float min = Float.MAX_VALUE; - float max = Float.MIN_VALUE; - float mean = 0; - float m2 = 0; - // iteratively compute the variance & mean - int count = 0; - for (IntArrayList cluster : clusters) { - count += 1; - if (cluster == null) { - continue; - } - float delta = cluster.size() - mean; - mean += delta / count; - m2 += delta * (cluster.size() - mean); - min = Math.min(min, cluster.size()); - max = Math.max(max, cluster.size()); - } - float variance = m2 / (clusters.length - 1); - infoStream.message( - IVF_VECTOR_COMPONENT, - "Centroid count: " - + clusters.length - + " min: " - + min - + " max: " - + max - + " mean: " - + mean - + " stdDev: " - + Math.sqrt(variance) - + " variance: " - + variance); - } - - static void assignCentroids( - CentroidAssignmentScorer scorer, FloatVectorValues vectors, IntArrayList[] clusters) - throws IOException { - short numCentroids = (short) scorer.size(); - // If soar > 0, then we actually need to apply the projection, otherwise, its just the second - // nearest centroid - // we at most will look at the EXT_SOAR_LIMIT_CHECK_RATIO nearest centroids if possible - int soarToCheck = (int) (numCentroids * EXT_SOAR_LIMIT_CHECK_RATIO); - int soarClusterCheckCount = Math.min(numCentroids - 1, 
soarToCheck); - // if lambda is `0`, that just means overspill to the second nearest, so we will only check the - // second nearest - if (SOAR_LAMBDA == 0) { - soarClusterCheckCount = Math.min(1, soarClusterCheckCount); - } - NeighborQueue neighborsToCheck = new NeighborQueue(soarClusterCheckCount + 1, true); - OrdScoreIterator ordScoreIterator = new OrdScoreIterator(soarClusterCheckCount + 1); - float[] scratch = new float[vectors.dimension()]; - for (int docID = 0; docID < vectors.size(); docID++) { - float[] vector = vectors.vectorValue(docID); - scorer.setScoringVector(vector); - int bestCentroid = 0; - float bestScore = Float.MAX_VALUE; - if (numCentroids > 1) { - for (short c = 0; c < numCentroids; c++) { - float squareDist = scorer.score(c); - neighborsToCheck.insertWithOverflow(c, squareDist); - } - // pop the best - int sz = neighborsToCheck.size(); - int best = - neighborsToCheck.consumeNodesAndScoresMin( - ordScoreIterator.ords, ordScoreIterator.scores); - // TODO yikes.... - ordScoreIterator.idx = sz; - bestScore = ordScoreIterator.getScore(best); - bestCentroid = ordScoreIterator.getOrd(best); - } - if (clusters[bestCentroid] == null) { - clusters[bestCentroid] = new IntArrayList(16); - } - clusters[bestCentroid].add(docID); - if (soarClusterCheckCount > 0) { - assignCentroidSOAR( - ordScoreIterator, - docID, - bestCentroid, - scorer.centroid(bestCentroid), - bestScore, - scratch, - scorer, - vectors, - clusters); - } - neighborsToCheck.clear(); - } - } - - static int prefilterCentroidAssignment( - int centroidOrd, - FloatVectorValues segmentCentroids, - CentroidAssignmentScorer scorer, - NeighborQueue neighborsToCheck, - int[] prefilteredCentroids) - throws IOException { - float[] segmentCentroid = segmentCentroids.vectorValue(centroidOrd); - scorer.setScoringVector(segmentCentroid); - neighborsToCheck.clear(); - for (short c = 0; c < scorer.size(); c++) { - float squareDist = scorer.score(c); - neighborsToCheck.insertWithOverflow(c, squareDist); + // 
write them + writeCentroids(centroids, fieldInfo, globalCentroid, temporaryCentroidOutput); + return centroids.length; } - int size = neighborsToCheck.size(); - neighborsToCheck.consumeNodes(prefilteredCentroids); - return size; - } - - static void assignCentroidsMerge( - CentroidAssignmentScorer scorer, - FloatVectorValues vectors, - MergeState state, - String fieldName, - IntArrayList[] clusters) - throws IOException { - FixedBitSet assigned = new FixedBitSet(vectors.size() + 1); - short numCentroids = (short) scorer.size(); - // If soar > 0, then we actually need to apply the projection, otherwise, its just the second - // nearest centroid - // we at most will look at the EXT_SOAR_LIMIT_CHECK_RATIO nearest centroids if possible - int soarToCheck = (int) (numCentroids * EXT_SOAR_LIMIT_CHECK_RATIO); - int soarClusterCheckCount = Math.min(numCentroids - 1, soarToCheck); - // TODO is this the right to check? - // If cluster quality is higher, maybe we can reduce this... - int prefilteredCentroidCount = - Math.max(soarClusterCheckCount + 1, numCentroids / state.knnVectorsReaders.length); - NeighborQueue prefilteredCentroidsToCheck = new NeighborQueue(prefilteredCentroidCount, true); - NeighborQueue neighborsToCheck = new NeighborQueue(soarClusterCheckCount + 1, true); - OrdScoreIterator ordScoreIterator = new OrdScoreIterator(soarClusterCheckCount + 1); - int[] prefilteredCentroids = new int[prefilteredCentroidCount]; - float[] scratch = new float[vectors.dimension()]; - // Can we do a pre-filter by finding the nearest centroids to the original vector centroids? 
- for (int idx = 0; idx < state.knnVectorsReaders.length; idx++) { - KnnVectorsReader reader = state.knnVectorsReaders[idx]; - IVFVectorsReader vectorsReader = getIVFReader(reader, fieldName); - // No reader, skip - if (vectorsReader == null) { - continue; - } - MergeState.DocMap docMap = state.docMaps[idx]; - var segmentCentroids = vectorsReader.getCentroids(state.fieldInfos[idx].fieldInfo(fieldName)); - for (int i = 0; i < segmentCentroids.size(); i++) { - IVFVectorsReader.CentroidInfo info = vectorsReader.centroidVectors(fieldName, i, docMap); - // Rare, but empty centroid, no point in doing comparisons - if (info.vectors().size == 0) { - continue; - } - prefilteredCentroidsToCheck.clear(); - int prefiltedCount = - prefilterCentroidAssignment( - i, segmentCentroids, scorer, prefilteredCentroidsToCheck, prefilteredCentroids); - int centroidVectorDocId = -1; - while ((centroidVectorDocId = info.vectors().nextVectorDocId()) != NO_MORE_DOCS) { - if (assigned.getAndSet(centroidVectorDocId)) { - continue; - } - neighborsToCheck.clear(); - float[] vector = info.vectors().vectorValue(); - scorer.setScoringVector(vector); - int bestCentroid; - float bestScore; - for (int c = 0; c < prefiltedCount; c++) { - float squareDist = scorer.score(prefilteredCentroids[c]); - neighborsToCheck.insertWithOverflow(prefilteredCentroids[c], squareDist); - } - int centroidCount = neighborsToCheck.size(); - int best = - neighborsToCheck.consumeNodesAndScoresMin( - ordScoreIterator.ords, ordScoreIterator.scores); - // yikes - ordScoreIterator.idx = centroidCount; - bestScore = ordScoreIterator.getScore(best); - bestCentroid = ordScoreIterator.getOrd(best); - if (clusters[bestCentroid] == null) { - clusters[bestCentroid] = new IntArrayList(16); - } - clusters[bestCentroid].add(info.vectors().docId()); - if (soarClusterCheckCount > 0) { - assignCentroidSOAR( - ordScoreIterator, - info.vectors().docId(), - bestCentroid, - scorer.centroid(bestCentroid), - bestScore, - scratch, - scorer, - 
vectors, - clusters); - } + + @Override + long[] buildAndWritePostingsLists( + FieldInfo fieldInfo, + CentroidAssignmentScorer centroidAssignmentScorer, + FloatVectorValues floatVectorValues, + IndexOutput postingsOutput, + MergeState mergeState + ) throws IOException { + IntArrayList[] clusters = new IntArrayList[centroidAssignmentScorer.size()]; + for (int i = 0; i < centroidAssignmentScorer.size(); i++) { + clusters[i] = new IntArrayList(floatVectorValues.size() / centroidAssignmentScorer.size() / 4); + } + long nanoTime = System.nanoTime(); + assignCentroidsMerge(centroidAssignmentScorer, floatVectorValues, clusters); + if (mergeState.infoStream.isEnabled(IVF_VECTOR_COMPONENT)) { + mergeState.infoStream.message(IVF_VECTOR_COMPONENT, "assignCentroids time ms: " + ((System.nanoTime() - nanoTime) / 1000000.0)); } - } - } - for (int vecOrd = 0; vecOrd < vectors.size(); vecOrd++) { - if (assigned.get(vecOrd)) { - continue; - } - float[] vector = vectors.vectorValue(vecOrd); - scorer.setScoringVector(vector); - int bestCentroid = 0; - float bestScore = Float.MAX_VALUE; - if (numCentroids > 1) { - for (short c = 0; c < numCentroids; c++) { - float squareDist = scorer.score(c); - neighborsToCheck.insertWithOverflow(c, squareDist); - } - int centroidCount = neighborsToCheck.size(); - int bestIdx = - neighborsToCheck.consumeNodesAndScoresMin( - ordScoreIterator.ords, ordScoreIterator.scores); - ordScoreIterator.idx = centroidCount; - bestCentroid = ordScoreIterator.getOrd(bestIdx); - bestScore = ordScoreIterator.getScore(bestIdx); - } - if (clusters[bestCentroid] == null) { - clusters[bestCentroid] = new IntArrayList(16); - } - int docID = vectors.ordToDoc(vecOrd); - clusters[bestCentroid].add(docID); - if (soarClusterCheckCount > 0) { - assignCentroidSOAR( - ordScoreIterator, - docID, - bestCentroid, - scorer.centroid(bestCentroid), - bestScore, - scratch, - scorer, - vectors, - clusters); - } - neighborsToCheck.clear(); + if 
(mergeState.infoStream.isEnabled(IVF_VECTOR_COMPONENT)) { + printClusterQualityStatistics(clusters, mergeState.infoStream); + } + // write the posting lists + final long[] offsets = new long[centroidAssignmentScorer.size()]; + OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); + BinarizedFloatVectorValues binarizedByteVectorValues = new BinarizedFloatVectorValues(floatVectorValues, quantizer); + DocIdsWriter docIdsWriter = new DocIdsWriter(); + for (int i = 0; i < centroidAssignmentScorer.size(); i++) { + float[] centroid = centroidAssignmentScorer.centroid(i); + binarizedByteVectorValues.centroid = centroid; + // TODO: sort by distance to the centroid + IntArrayList cluster = clusters[i]; + // TODO align??? + offsets[i] = postingsOutput.getFilePointer(); + int size = cluster.size(); + postingsOutput.writeVInt(size); + postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid))); + // TODO we might want to consider putting the docIds in a separate file + // to aid with only having to fetch vectors from slower storage when they are required + // keeping them in the same file indicates we pull the entire file into cache + docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster.get(j)), size, postingsOutput); + writePostingList(cluster, postingsOutput, binarizedByteVectorValues); + } + return offsets; + } + + private static void printClusterQualityStatistics(IntArrayList[] clusters, InfoStream infoStream) { + float min = Float.MAX_VALUE; + float max = Float.MIN_VALUE; + float mean = 0; + float m2 = 0; + // iteratively compute the variance & mean + int count = 0; + for (IntArrayList cluster : clusters) { + count += 1; + if (cluster == null) { + continue; + } + float delta = cluster.size() - mean; + mean += delta / count; + m2 += delta * (cluster.size() - mean); + min = Math.min(min, cluster.size()); + max = Math.max(max, cluster.size()); + } + float variance = m2 / 
(clusters.length - 1); + infoStream.message( + IVF_VECTOR_COMPONENT, + "Centroid count: " + + clusters.length + + " min: " + + min + + " max: " + + max + + " mean: " + + mean + + " stdDev: " + + Math.sqrt(variance) + + " variance: " + + variance + ); + } + + static void assignCentroids(CentroidAssignmentScorer scorer, FloatVectorValues vectors, IntArrayList[] clusters) throws IOException { + short numCentroids = (short) scorer.size(); + // If soar > 0, then we actually need to apply the projection, otherwise, its just the second + // nearest centroid + // we at most will look at the EXT_SOAR_LIMIT_CHECK_RATIO nearest centroids if possible + int soarToCheck = (int) (numCentroids * EXT_SOAR_LIMIT_CHECK_RATIO); + int soarClusterCheckCount = Math.min(numCentroids - 1, soarToCheck); + // if lambda is `0`, that just means overspill to the second nearest, so we will only check the + // second nearest + NeighborQueue neighborsToCheck = new NeighborQueue(soarClusterCheckCount + 1, true); + OrdScoreIterator ordScoreIterator = new OrdScoreIterator(soarClusterCheckCount + 1); + float[] scratch = new float[vectors.dimension()]; + for (int docID = 0; docID < vectors.size(); docID++) { + float[] vector = vectors.vectorValue(docID); + scorer.setScoringVector(vector); + int bestCentroid = 0; + float bestScore = Float.MAX_VALUE; + if (numCentroids > 1) { + for (short c = 0; c < numCentroids; c++) { + float squareDist = scorer.score(c); + neighborsToCheck.insertWithOverflow(c, squareDist); + } + // pop the best + int sz = neighborsToCheck.size(); + int best = neighborsToCheck.consumeNodesAndScoresMin(ordScoreIterator.ords, ordScoreIterator.scores); + // TODO yikes.... 
+ ordScoreIterator.idx = sz; + bestScore = ordScoreIterator.getScore(best); + bestCentroid = ordScoreIterator.getOrd(best); + } + if (clusters[bestCentroid] == null) { + clusters[bestCentroid] = new IntArrayList(16); + } + clusters[bestCentroid].add(docID); + if (soarClusterCheckCount > 0) { + assignCentroidSOAR( + ordScoreIterator, + docID, + bestCentroid, + scorer.centroid(bestCentroid), + bestScore, + scratch, + scorer, + vectors, + clusters + ); + } + neighborsToCheck.clear(); + } } - } - - static void assignCentroidSOAR( - OrdScoreIterator centroidsToCheck, - int docId, - int bestCentroidId, - float[] bestCentroid, - float bestScore, - float[] scratch, - CentroidAssignmentScorer scorer, - FloatVectorValues vectors, - IntArrayList[] clusters) - throws IOException { - float[] vector = vectors.vectorValue(docId); - ESVectorUtil.subtract(vector, bestCentroid, scratch); - int bestSecondaryCentroid = -1; - float minDist = Float.MAX_VALUE; - for (int i = 0; i < centroidsToCheck.size(); i++) { - float score = centroidsToCheck.getScore(i); - int centroidOrdinal = centroidsToCheck.getOrd(i); - if (centroidOrdinal == bestCentroidId) { - continue; - } - if (SOAR_LAMBDA > 0) { - float proj = ESVectorUtil.soarResidual(vector, scorer.centroid(centroidOrdinal), scratch); - score += SOAR_LAMBDA * proj * proj / bestScore; - } - if (score < minDist) { - bestSecondaryCentroid = centroidOrdinal; - minDist = score; - } + + static void assignCentroidsMerge( + CentroidAssignmentScorer scorer, + FloatVectorValues vectors, + IntArrayList[] clusters + ) throws IOException { + int numCentroids = scorer.size(); + // If soar > 0, then we actually need to apply the projection, otherwise, its just the second + // nearest centroid + // we at most will look at the EXT_SOAR_LIMIT_CHECK_RATIO nearest centroids if possible + int soarToCheck = (int) (numCentroids * EXT_SOAR_LIMIT_CHECK_RATIO); + int soarClusterCheckCount = Math.min(numCentroids - 1, soarToCheck); + // TODO is this the right to 
check? + // If cluster quality is higher, maybe we can reduce this... + NeighborQueue neighborsToCheck = new NeighborQueue(soarClusterCheckCount + 1, true); + OrdScoreIterator ordScoreIterator = new OrdScoreIterator(soarClusterCheckCount + 1); + float[] scratch = new float[vectors.dimension()]; + // Can we do a pre-filter by finding the nearest centroids to the original vector centroids? + // We need to be careful on vecOrd vs. doc as we need random access to the raw vector for posting list writing + for (int vecOrd = 0; vecOrd < vectors.size(); vecOrd++) { + float[] vector = vectors.vectorValue(vecOrd); + scorer.setScoringVector(vector); + int bestCentroid = 0; + float bestScore = Float.MAX_VALUE; + if (numCentroids > 1) { + for (short c = 0; c < numCentroids; c++) { + float squareDist = scorer.score(c); + neighborsToCheck.insertWithOverflow(c, squareDist); + } + int centroidCount = neighborsToCheck.size(); + int bestIdx = neighborsToCheck.consumeNodesAndScoresMin(ordScoreIterator.ords, ordScoreIterator.scores); + ordScoreIterator.idx = centroidCount; + bestCentroid = ordScoreIterator.getOrd(bestIdx); + bestScore = ordScoreIterator.getScore(bestIdx); + } + if (clusters[bestCentroid] == null) { + clusters[bestCentroid] = new IntArrayList(16); + } + clusters[bestCentroid].add(vecOrd); + if (soarClusterCheckCount > 0) { + assignCentroidSOAR( + ordScoreIterator, + vecOrd, + bestCentroid, + scorer.centroid(bestCentroid), + bestScore, + scratch, + scorer, + vectors, + clusters + ); + } + neighborsToCheck.clear(); + } } - if (bestSecondaryCentroid != -1) { - clusters[bestSecondaryCentroid].add(docId); + + static void assignCentroidSOAR( + OrdScoreIterator centroidsToCheck, + int vecOrd, + int bestCentroidId, + float[] bestCentroid, + float bestScore, + float[] scratch, + CentroidAssignmentScorer scorer, + FloatVectorValues vectors, + IntArrayList[] clusters + ) throws IOException { + float[] vector = vectors.vectorValue(vecOrd); + ESVectorUtil.subtract(vector, 
bestCentroid, scratch); + int bestSecondaryCentroid = -1; + float minDist = Float.MAX_VALUE; + for (int i = 0; i < centroidsToCheck.size(); i++) { + float score = centroidsToCheck.getScore(i); + int centroidOrdinal = centroidsToCheck.getOrd(i); + if (centroidOrdinal == bestCentroidId) { + continue; + } + float proj = ESVectorUtil.soarResidual(vector, scorer.centroid(centroidOrdinal), scratch); + score += SOAR_LAMBDA * proj * proj / bestScore; + if (score < minDist) { + bestSecondaryCentroid = centroidOrdinal; + minDist = score; + } + } + if (bestSecondaryCentroid != -1) { + clusters[bestSecondaryCentroid].add(vecOrd); + } } - } - static class OrdScoreIterator { - private final int[] ords; - private final float[] scores; - private int idx = 0; + static class OrdScoreIterator { + private final int[] ords; + private final float[] scores; + private int idx = 0; - OrdScoreIterator(int size) { - this.ords = new int[size]; - this.scores = new float[size]; - } + OrdScoreIterator(int size) { + this.ords = new int[size]; + this.scores = new float[size]; + } - void add(int ord, float score) { - ords[idx] = ord; - scores[idx] = score; - idx++; - } + int getOrd(int idx) { + return ords[idx]; + } - int getOrd(int idx) { - return ords[idx]; - } + float getScore(int idx) { + return scores[idx]; + } - float getScore(int idx) { - return scores[idx]; + int size() { + return idx; + } } - void reset() { - idx = 0; - } + // TODO unify with OSQ format + static class BinarizedFloatVectorValues { + private OptimizedScalarQuantizer.QuantizationResult corrections; + private final byte[] binarized; + private final byte[] initQuantized; + private float[] centroid; + private final FloatVectorValues values; + private final OptimizedScalarQuantizer quantizer; - int size() { - return idx; - } - } - - // TODO unify with OSQ format - static class BinarizedFloatVectorValues { - private OptimizedScalarQuantizer.QuantizationResult corrections; - private final byte[] binarized; - private final byte[] 
initQuantized; - private float[] centroid; - private final FloatVectorValues values; - private final OptimizedScalarQuantizer quantizer; - - private int lastOrd = -1; - - BinarizedFloatVectorValues(FloatVectorValues delegate, OptimizedScalarQuantizer quantizer) { - this.values = delegate; - this.quantizer = quantizer; - this.binarized = new byte[discretize(delegate.dimension(), 64) / 8]; - this.initQuantized = new byte[delegate.dimension()]; - } + private int lastOrd = -1; - public OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int ord) { - if (ord != lastOrd) { - throw new IllegalStateException( - "attempt to retrieve corrective terms for different ord " - + ord - + " than the quantization was done for: " - + lastOrd); - } - return corrections; - } + BinarizedFloatVectorValues(FloatVectorValues delegate, OptimizedScalarQuantizer quantizer) { + this.values = delegate; + this.quantizer = quantizer; + this.binarized = new byte[discretize(delegate.dimension(), 64) / 8]; + this.initQuantized = new byte[delegate.dimension()]; + } - public byte[] vectorValue(int ord) throws IOException { - if (ord != lastOrd) { - binarize(ord); - lastOrd = ord; - } - return binarized; - } + public OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int ord) { + if (ord != lastOrd) { + throw new IllegalStateException( + "attempt to retrieve corrective terms for different ord " + ord + " than the quantization was done for: " + lastOrd + ); + } + return corrections; + } - private void binarize(int ord) throws IOException { - corrections = - quantizer.scalarQuantize(values.vectorValue(ord), initQuantized, INDEX_BITS, centroid); - packAsBinary(initQuantized, binarized); - } - } - - static class OffHeapCentroidAssignmentScorer implements CentroidAssignmentScorer { - private final IndexInput centroidsInput; - private final int numCentroids; - private final int dimension; - private final float[] scratch; - private float[] q; - private final long centroidByteSize; - 
private int currOrd = -1; - - OffHeapCentroidAssignmentScorer(IndexInput centroidsInput, int numCentroids, FieldInfo info) { - this.centroidsInput = centroidsInput; - this.numCentroids = numCentroids; - this.dimension = info.getVectorDimension(); - this.scratch = new float[dimension]; - this.centroidByteSize = dimension + 3 * Float.BYTES + Short.BYTES; - } + public byte[] vectorValue(int ord) throws IOException { + if (ord != lastOrd) { + binarize(ord); + lastOrd = ord; + } + return binarized; + } - @Override - public int size() { - return numCentroids; + private void binarize(int ord) throws IOException { + corrections = quantizer.scalarQuantize(values.vectorValue(ord), initQuantized, INDEX_BITS, centroid); + packAsBinary(initQuantized, binarized); + } } - @Override - public float[] centroid(int centroidOrdinal) throws IOException { - if (centroidOrdinal == currOrd) { - return scratch; - } - centroidsInput.seek( - numCentroids * centroidByteSize + (long) centroidOrdinal * dimension * Float.BYTES); - centroidsInput.readFloats(scratch, 0, dimension); - this.currOrd = centroidOrdinal; - return scratch; - } + static class OffHeapCentroidAssignmentScorer implements CentroidAssignmentScorer { + private final IndexInput centroidsInput; + private final int numCentroids; + private final int dimension; + private final float[] scratch; + private float[] q; + private final long centroidByteSize; + private int currOrd = -1; + + OffHeapCentroidAssignmentScorer(IndexInput centroidsInput, int numCentroids, FieldInfo info) { + this.centroidsInput = centroidsInput; + this.numCentroids = numCentroids; + this.dimension = info.getVectorDimension(); + this.scratch = new float[dimension]; + this.centroidByteSize = dimension + 3 * Float.BYTES + Short.BYTES; + } - @Override - public void setScoringVector(float[] vector) { - q = vector; - } + @Override + public int size() { + return numCentroids; + } - @Override - public float score(int centroidOrdinal) throws IOException { - return 
VectorUtil.squareDistance(centroid(centroidOrdinal), q); - } - } + @Override + public float[] centroid(int centroidOrdinal) throws IOException { + if (centroidOrdinal == currOrd) { + return scratch; + } + centroidsInput.seek(numCentroids * centroidByteSize + (long) centroidOrdinal * dimension * Float.BYTES); + centroidsInput.readFloats(scratch, 0, dimension); + this.currOrd = centroidOrdinal; + return scratch; + } - // TODO throw away rawCentroids - static class OnHeapCentroidAssignmentScorer implements CentroidAssignmentScorer { - private final float[][] centroids; - private float[] q; + @Override + public void setScoringVector(float[] vector) { + q = vector; + } - OnHeapCentroidAssignmentScorer(float[][] centroids) { - this.centroids = centroids; + @Override + public float score(int centroidOrdinal) throws IOException { + return VectorUtil.squareDistance(centroid(centroidOrdinal), q); + } } - @Override - public int size() { - return centroids.length; - } + // TODO throw away rawCentroids + static class OnHeapCentroidAssignmentScorer implements CentroidAssignmentScorer { + private final float[][] centroids; + private float[] q; - @Override - public void setScoringVector(float[] vector) { - q = vector; - } + OnHeapCentroidAssignmentScorer(float[][] centroids) { + this.centroids = centroids; + } - @Override - public float[] centroid(int centroidOrdinal) throws IOException { - return centroids[centroidOrdinal]; + @Override + public int size() { + return centroids.length; + } + + @Override + public void setScoringVector(float[] vector) { + q = vector; + } + + @Override + public float[] centroid(int centroidOrdinal) throws IOException { + return centroids[centroidOrdinal]; + } + + @Override + public float score(int centroidOrdinal) throws IOException { + return VectorUtil.squareDistance(centroid(centroidOrdinal), q); + } } - @Override - public float score(int centroidOrdinal) throws IOException { - return VectorUtil.squareDistance(centroid(centroidOrdinal), q); + 
static void writeQuantizedValue(IndexOutput indexOutput, byte[] binaryValue, OptimizedScalarQuantizer.QuantizationResult corrections) + throws IOException { + indexOutput.writeBytes(binaryValue, binaryValue.length); + indexOutput.writeInt(Float.floatToIntBits(corrections.lowerInterval())); + indexOutput.writeInt(Float.floatToIntBits(corrections.upperInterval())); + indexOutput.writeInt(Float.floatToIntBits(corrections.additionalCorrection())); + assert corrections.quantizedComponentSum() >= 0 && corrections.quantizedComponentSum() <= 0xffff; + indexOutput.writeShort((short) corrections.quantizedComponentSum()); } - } - - static void writeQuantizedValue( - IndexOutput indexOutput, - byte[] binaryValue, - OptimizedScalarQuantizer.QuantizationResult corrections) - throws IOException { - indexOutput.writeBytes(binaryValue, binaryValue.length); - indexOutput.writeInt(Float.floatToIntBits(corrections.lowerInterval())); - indexOutput.writeInt(Float.floatToIntBits(corrections.upperInterval())); - indexOutput.writeInt(Float.floatToIntBits(corrections.additionalCorrection())); - assert corrections.quantizedComponentSum() >= 0 - && corrections.quantizedComponentSum() <= 0xffff; - indexOutput.writeShort((short) corrections.quantizedComponentSum()); - } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormat.java index f7cf9a7bcdba5..230d7743726eb 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormat.java @@ -9,7 +9,6 @@ package org.elasticsearch.index.codec.vectors; -import java.io.IOException; import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.KnnVectorsWriter; @@ -19,74 +18,86 @@ import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; import 
org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; +import org.elasticsearch.common.util.FeatureFlag; + +import java.io.IOException; /** * Codec format for Inverted File Vector indexes. This index expects to break the dimensional space * into clusters and assign each vector to a cluster generating a posting list of vectors. Clusters * are represented by centroids. + *

    + * Each posting list is individual quantized and stored in-line allowing for fast block scoring. * *

    THe index is searcher by looking for the closest centroids to our vector query and then * scoring the vectors in the posting list of the closest centroids. */ public class IVFVectorsFormat extends KnnVectorsFormat { - public static final String IVF_VECTOR_COMPONENT = "IVF"; - public static final String NAME = "IVFVectorsFormat"; - // centroid ordinals -> centroid values, offsets - public static final String CENTROID_EXTENSION = "cenivf"; - // offsets contained in cen_ivf, [vector ordinals, actually just docIds](long varint), quantized - // vectors (OSQ bit) - public static final String CLUSTER_EXTENSION = "clivf"; - static final String IVF_META_EXTENSION = "mivf"; - - public static final int VERSION_START = 0; - public static final int VERSION_CURRENT = VERSION_START; - - private static final FlatVectorsFormat rawVectorFormat = - new Lucene99FlatVectorsFormat(FlatVectorScorerUtil.getLucene99FlatVectorsScorer()); - - private static final int DEFAULT_VECTORS_PER_CLUSTER = 1000; - - private final int vectorPerCluster; - - public IVFVectorsFormat(int vectorPerCluster) { - super(NAME); - this.vectorPerCluster = vectorPerCluster; - } - - /** Constructs a format using the given graph construction parameters and scalar quantization. 
*/ - public IVFVectorsFormat() { - this(DEFAULT_VECTORS_PER_CLUSTER); - } - - @Override - public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { - return new DefaultIVFVectorsWriter( - state, rawVectorFormat.fieldsWriter(state), vectorPerCluster); - } - - @Override - public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException { - return new DefaultIVFVectorsReader(state, rawVectorFormat.fieldsReader(state)); - } - - @Override - public int getMaxDimensions(String fieldName) { - return 1024; - } - - @Override - public String toString() { - return "IVFVectorFormat"; - } - - static IVFVectorsReader getIVFReader(KnnVectorsReader vectorsReader, String fieldName) { - if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) { - vectorsReader = candidateReader.getFieldReader(fieldName); + static final FeatureFlag IVF_FORMAT_FEATURE_FLAG = new FeatureFlag("ivf_format"); + public static final String IVF_VECTOR_COMPONENT = "IVF"; + public static final String NAME = "IVFVectorsFormat"; + // centroid ordinals -> centroid values, offsets + public static final String CENTROID_EXTENSION = "cenivf"; + // offsets contained in cen_ivf, [vector ordinals, actually just docIds](long varint), quantized + // vectors (OSQ bit) + public static final String CLUSTER_EXTENSION = "clivf"; + static final String IVF_META_EXTENSION = "mivf"; + + public static final int VERSION_START = 0; + public static final int VERSION_CURRENT = VERSION_START; + + private static final FlatVectorsFormat rawVectorFormat = new Lucene99FlatVectorsFormat( + FlatVectorScorerUtil.getLucene99FlatVectorsScorer() + ); + + private static final int DEFAULT_VECTORS_PER_CLUSTER = 1000; + + private final int vectorPerCluster; + + public IVFVectorsFormat(int vectorPerCluster) { + super(NAME); + if (IVF_FORMAT_FEATURE_FLAG.isEnabled() == false) { + throw new IllegalStateException("IVF format is not enabled"); + } + this.vectorPerCluster = vectorPerCluster; 
+ } + + /** Constructs a format using the given graph construction parameters and scalar quantization. */ + public IVFVectorsFormat() { + this(DEFAULT_VECTORS_PER_CLUSTER); + if (IVF_FORMAT_FEATURE_FLAG.isEnabled() == false) { + throw new IllegalStateException("IVF format is not enabled"); + } } - if (vectorsReader instanceof IVFVectorsReader reader) { - return reader; + + @Override + public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new DefaultIVFVectorsWriter(state, rawVectorFormat.fieldsWriter(state), vectorPerCluster); + } + + @Override + public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException { + return new DefaultIVFVectorsReader(state, rawVectorFormat.fieldsReader(state)); + } + + @Override + public int getMaxDimensions(String fieldName) { + return 1024; + } + + @Override + public String toString() { + return "IVFVectorFormat"; + } + + static IVFVectorsReader getIVFReader(KnnVectorsReader vectorsReader, String fieldName) { + if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) { + vectorsReader = candidateReader.getFieldReader(fieldName); + } + if (vectorsReader instanceof IVFVectorsReader reader) { + return reader; + } + return null; } - return null; - } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java index d8a14f55894c4..6d4e6602cbac1 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java @@ -7,14 +7,8 @@ * License v3.0 only", or the "Server Side Public License, v 1". 
*/ - package org.elasticsearch.index.codec.vectors; -import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS; -import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; - -import java.io.IOException; -import java.util.function.IntPredicate; import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.KnnVectorsReader; import org.apache.lucene.codecs.hnsw.FlatVectorsReader; @@ -24,8 +18,6 @@ import org.apache.lucene.index.FieldInfos; import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.IndexFileNames; -import org.apache.lucene.index.KnnVectorValues; -import org.apache.lucene.index.MergeState; import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.VectorEncoding; import org.apache.lucene.index.VectorSimilarityFunction; @@ -41,434 +33,322 @@ import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.hnsw.NeighborQueue; +import java.io.IOException; +import java.util.function.IntPredicate; + +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS; + /** - * @lucene.experimental + * Reader for IVF vectors. This reader is used to read the IVF vectors from the index. 
*/ public abstract class IVFVectorsReader extends KnnVectorsReader { - private final IndexInput ivfCentroids, ivfClusters; - private final SegmentReadState state; - private final FieldInfos fieldInfos; - protected final IntObjectHashMap fields; - private final FlatVectorsReader rawVectorsReader; - - protected IVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader) - throws IOException { - this.state = state; - this.fieldInfos = state.fieldInfos; - this.rawVectorsReader = rawVectorsReader; - this.fields = new IntObjectHashMap<>(); - String meta = - IndexFileNames.segmentFileName( - state.segmentInfo.name, state.segmentSuffix, IVFVectorsFormat.IVF_META_EXTENSION); - - int versionMeta = -1; - boolean success = false; - try (ChecksumIndexInput ivfMeta = state.directory.openChecksumInput(meta)) { - Throwable priorE = null; - try { - versionMeta = - CodecUtil.checkIndexHeader( - ivfMeta, - IVFVectorsFormat.NAME, - IVFVectorsFormat.VERSION_START, - IVFVectorsFormat.VERSION_CURRENT, - state.segmentInfo.getId(), - state.segmentSuffix); - readFields(ivfMeta); - } catch (Throwable exception) { - priorE = exception; - } finally { - CodecUtil.checkFooter(ivfMeta, priorE); - } - ivfCentroids = - openDataInput( - state, - versionMeta, - IVFVectorsFormat.CENTROID_EXTENSION, - IVFVectorsFormat.NAME, - state.context); - ivfClusters = - openDataInput( - state, - versionMeta, - IVFVectorsFormat.CLUSTER_EXTENSION, - IVFVectorsFormat.NAME, - state.context); - success = true; - } finally { - if (success == false) { - IOUtils.closeWhileHandlingException(this); - } - } - } - - abstract CentroidQueryScorer getCentroidScorer( - FieldInfo fieldInfo, - int numCentroids, - IndexInput centroids, - float[] target, - IndexInput clusters) - throws IOException; - - protected abstract FloatVectorValues getCentroids( - IndexInput indexInput, int numCentroids, FieldInfo info) throws IOException; - - record CentroidInfo(CentroidFloatVectorValues vectors, float innerProduct) {} - - 
CentroidInfo centroidVectors(String fieldName, int centroidOrd, MergeState.DocMap docMap) - throws IOException { - FieldInfo info = state.fieldInfos.fieldInfo(fieldName); - FieldEntry entry = fields.get(info.number); - if (entry == null) { - return null; - } - if (entry.vectorEncoding() == VectorEncoding.BYTE) { - return null; + private final IndexInput ivfCentroids, ivfClusters; + private final SegmentReadState state; + private final FieldInfos fieldInfos; + protected final IntObjectHashMap fields; + private final FlatVectorsReader rawVectorsReader; + + protected IVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader) throws IOException { + this.state = state; + this.fieldInfos = state.fieldInfos; + this.rawVectorsReader = rawVectorsReader; + this.fields = new IntObjectHashMap<>(); + String meta = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, IVFVectorsFormat.IVF_META_EXTENSION); + + int versionMeta = -1; + boolean success = false; + try (ChecksumIndexInput ivfMeta = state.directory.openChecksumInput(meta)) { + Throwable priorE = null; + try { + versionMeta = CodecUtil.checkIndexHeader( + ivfMeta, + IVFVectorsFormat.NAME, + IVFVectorsFormat.VERSION_START, + IVFVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix + ); + readFields(ivfMeta); + } catch (Throwable exception) { + priorE = exception; + } finally { + CodecUtil.checkFooter(ivfMeta, priorE); + } + ivfCentroids = openDataInput(state, versionMeta, IVFVectorsFormat.CENTROID_EXTENSION, IVFVectorsFormat.NAME, state.context); + ivfClusters = openDataInput(state, versionMeta, IVFVectorsFormat.CLUSTER_EXTENSION, IVFVectorsFormat.NAME, state.context); + success = true; + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(this); + } + } } - ivfClusters.seek(entry.postingListOffsets()[centroidOrd]); - int vectors = ivfClusters.readVInt(); - float innerProduct = Float.intBitsToFloat(ivfClusters.readInt()); - int[] 
vectorDocIds = new int[vectors]; - DocIdsWriter docIdsWriter = new DocIdsWriter(); - docIdsWriter.readInts(ivfClusters, vectors, vectorDocIds); - - // TODO this assumes that vectorDocIds are sorted!!! - int count = 0; - for (int i = 0; i < vectors; i++) { - int docId = vectorDocIds[i]; - if (docMap.get(docId) != -1) { - ++count; - } + + abstract CentroidQueryScorer getCentroidScorer( + FieldInfo fieldInfo, + int numCentroids, + IndexInput centroids, + float[] target, + IndexInput clusters + ) throws IOException; + + protected abstract FloatVectorValues getCentroids(IndexInput indexInput, int numCentroids, FieldInfo info) throws IOException; + + public FloatVectorValues getCentroids(FieldInfo fieldInfo) throws IOException { + FieldEntry entry = fields.get(fieldInfo.number); + if (entry == null) { + return null; + } + return getCentroids(entry.centroidSlice(ivfCentroids), entry.postingListOffsets.length, fieldInfo); } - // TODO: Do we need random access? If so, we should gather the ordinals here by - // iterating the valid docs in the docMap, keeping track of the valid ordinals, then they can - // be directly - // accessed - FloatVectorValues vectorValues = getFloatVectorValues(fieldName); - CentroidFloatVectorValues centroidFloatVectorValues = - new CentroidFloatVectorValues(vectorValues, vectorDocIds, docMap, count); - return new CentroidInfo(centroidFloatVectorValues, innerProduct); - } - - static class CentroidFloatVectorValues { - final FloatVectorValues vectorValues; - final int[] docIds; - final MergeState.DocMap docMap; - final int size; - int curOriginalDocId = -1; - int mappedDocID = -1; - KnnVectorValues.DocIndexIterator iterator; - - CentroidFloatVectorValues( - FloatVectorValues vectorValues, int[] docIds, MergeState.DocMap docMap, int size) { - this.vectorValues = vectorValues; - this.iterator = vectorValues.iterator(); - this.docIds = docIds; - this.docMap = docMap; - this.size = size; + + int centroidSize(String fieldName, int centroidOrdinal) throws 
IOException { + FieldInfo fieldInfo = state.fieldInfos.fieldInfo(fieldName); + FieldEntry entry = fields.get(fieldInfo.number); + ivfClusters.seek(entry.postingListOffsets[centroidOrdinal]); + return ivfClusters.readVInt(); } - int docId() { - return mappedDocID; + private static IndexInput openDataInput( + SegmentReadState state, + int versionMeta, + String fileExtension, + String codecName, + IOContext context + ) throws IOException { + final String fileName = IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension); + final IndexInput in = state.directory.openInput(fileName, context); + boolean success = false; + try { + final int versionVectorData = CodecUtil.checkIndexHeader( + in, + codecName, + IVFVectorsFormat.VERSION_START, + IVFVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix + ); + if (versionMeta != versionVectorData) { + throw new CorruptIndexException( + "Format versions mismatch: meta=" + versionMeta + ", " + codecName + "=" + versionVectorData, + in + ); + } + CodecUtil.retrieveChecksum(in); + success = true; + return in; + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(in); + } + } } - float[] vectorValue() throws IOException { - return vectorValues.vectorValue(iterator.index()); + private void readFields(ChecksumIndexInput meta) throws IOException { + for (int fieldNumber = meta.readInt(); fieldNumber != -1; fieldNumber = meta.readInt()) { + final FieldInfo info = fieldInfos.fieldInfo(fieldNumber); + if (info == null) { + throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta); + } + fields.put(info.number, readField(meta, info)); + } } - int nextVectorDocId() throws IOException { - while (curOriginalDocId < docIds.length - 1) { - curOriginalDocId++; - int doc = iterator.advance(docIds[curOriginalDocId]); - if (doc == NO_MORE_DOCS) { - return this.mappedDocID = NO_MORE_DOCS; + private FieldEntry readField(IndexInput input, 
FieldInfo info) throws IOException { + final VectorEncoding vectorEncoding = readVectorEncoding(input); + final VectorSimilarityFunction similarityFunction = readSimilarityFunction(input); + final long centroidOffset = input.readLong(); + final long centroidLength = input.readLong(); + final int numPostingLists = input.readVInt(); + final long[] postingListOffsets = new long[numPostingLists]; + for (int i = 0; i < numPostingLists; i++) { + postingListOffsets[i] = input.readLong(); + } + final float[] globalCentroid = new float[info.getVectorDimension()]; + float globalCentroidDp = 0; + if (numPostingLists > 0) { + input.readFloats(globalCentroid, 0, globalCentroid.length); + globalCentroidDp = Float.intBitsToFloat(input.readInt()); } - int mappedDoc = docMap.get(doc); - if (mappedDoc != -1) { - return this.mappedDocID = mappedDoc; + if (similarityFunction != info.getVectorSimilarityFunction()) { + throw new IllegalStateException( + "Inconsistent vector similarity function for field=\"" + + info.name + + "\"; " + + similarityFunction + + " != " + + info.getVectorSimilarityFunction() + ); } - } - return this.mappedDocID = NO_MORE_DOCS; + return new FieldEntry( + similarityFunction, + vectorEncoding, + centroidOffset, + centroidLength, + postingListOffsets, + globalCentroid, + globalCentroidDp + ); } - } - public FloatVectorValues getCentroids(FieldInfo fieldInfo) throws IOException { - FieldEntry entry = fields.get(fieldInfo.number); - if (entry == null) { - return null; - } - return getCentroids( - entry.centroidSlice(ivfCentroids), entry.postingListOffsets.length, fieldInfo); - } - - int centroidSize(String fieldName, int centroidOrdinal) throws IOException { - FieldInfo fieldInfo = state.fieldInfos.fieldInfo(fieldName); - FieldEntry entry = fields.get(fieldInfo.number); - ivfClusters.seek(entry.postingListOffsets[centroidOrdinal]); - return ivfClusters.readVInt(); - } - - private static IndexInput openDataInput( - SegmentReadState state, - int versionMeta, - 
String fileExtension, - String codecName, - IOContext context) - throws IOException { - final String fileName = - IndexFileNames.segmentFileName(state.segmentInfo.name, state.segmentSuffix, fileExtension); - final IndexInput in = state.directory.openInput(fileName, context); - boolean success = false; - try { - final int versionVectorData = - CodecUtil.checkIndexHeader( - in, - codecName, - IVFVectorsFormat.VERSION_START, - IVFVectorsFormat.VERSION_CURRENT, - state.segmentInfo.getId(), - state.segmentSuffix); - if (versionMeta != versionVectorData) { - throw new CorruptIndexException( - "Format versions mismatch: meta=" - + versionMeta - + ", " - + codecName - + "=" - + versionVectorData, - in); - } - CodecUtil.retrieveChecksum(in); - success = true; - return in; - } finally { - if (success == false) { - IOUtils.closeWhileHandlingException(in); - } - } - } - - private void readFields(ChecksumIndexInput meta) throws IOException { - for (int fieldNumber = meta.readInt(); fieldNumber != -1; fieldNumber = meta.readInt()) { - final FieldInfo info = fieldInfos.fieldInfo(fieldNumber); - if (info == null) { - throw new CorruptIndexException("Invalid field number: " + fieldNumber, meta); - } - fields.put(info.number, readField(meta, info)); - } - } - - private FieldEntry readField(IndexInput input, FieldInfo info) throws IOException { - final VectorEncoding vectorEncoding = readVectorEncoding(input); - final VectorSimilarityFunction similarityFunction = readSimilarityFunction(input); - final long centroidOffset = input.readLong(); - final long centroidLength = input.readLong(); - final int numPostingLists = input.readVInt(); - final long[] postingListOffsets = new long[numPostingLists]; - for (int i = 0; i < numPostingLists; i++) { - postingListOffsets[i] = input.readLong(); - } - final float[] globalCentroid = new float[info.getVectorDimension()]; - float globalCentroidDp = 0; - if (numPostingLists > 0) { - input.readFloats(globalCentroid, 0, globalCentroid.length); - 
globalCentroidDp = Float.intBitsToFloat(input.readInt()); - } - if (similarityFunction != info.getVectorSimilarityFunction()) { - throw new IllegalStateException( - "Inconsistent vector similarity function for field=\"" - + info.name - + "\"; " - + similarityFunction - + " != " - + info.getVectorSimilarityFunction()); - } - return new FieldEntry( - similarityFunction, - vectorEncoding, - centroidOffset, - centroidLength, - postingListOffsets, - globalCentroid, - globalCentroidDp); - } - - private static VectorSimilarityFunction readSimilarityFunction(DataInput input) - throws IOException { - final int i = input.readInt(); - if (i < 0 || i >= SIMILARITY_FUNCTIONS.size()) { - throw new IllegalArgumentException("invalid distance function: " + i); + private static VectorSimilarityFunction readSimilarityFunction(DataInput input) throws IOException { + final int i = input.readInt(); + if (i < 0 || i >= SIMILARITY_FUNCTIONS.size()) { + throw new IllegalArgumentException("invalid distance function: " + i); + } + return SIMILARITY_FUNCTIONS.get(i); } - return SIMILARITY_FUNCTIONS.get(i); - } - private static VectorEncoding readVectorEncoding(DataInput input) throws IOException { - final int encodingId = input.readInt(); - if (encodingId < 0 || encodingId >= VectorEncoding.values().length) { - throw new CorruptIndexException("Invalid vector encoding id: " + encodingId, input); - } - return VectorEncoding.values()[encodingId]; - } - - @Override - public final void checkIntegrity() throws IOException { - rawVectorsReader.checkIntegrity(); - CodecUtil.checksumEntireFile(ivfCentroids); - CodecUtil.checksumEntireFile(ivfClusters); - } - - @Override - public final FloatVectorValues getFloatVectorValues(String field) throws IOException { - return rawVectorsReader.getFloatVectorValues(field); - } - - @Override - public final ByteVectorValues getByteVectorValues(String field) throws IOException { - return rawVectorsReader.getByteVectorValues(field); - } - - protected float[] 
getGlobalCentroid(FieldInfo info) { - if (info == null || info.getVectorEncoding().equals(VectorEncoding.BYTE)) { - return null; + private static VectorEncoding readVectorEncoding(DataInput input) throws IOException { + final int encodingId = input.readInt(); + if (encodingId < 0 || encodingId >= VectorEncoding.values().length) { + throw new CorruptIndexException("Invalid vector encoding id: " + encodingId, input); + } + return VectorEncoding.values()[encodingId]; } - FieldEntry entry = fields.get(info.number); - if (entry == null) { - return null; + + @Override + public final void checkIntegrity() throws IOException { + rawVectorsReader.checkIntegrity(); + CodecUtil.checksumEntireFile(ivfCentroids); + CodecUtil.checksumEntireFile(ivfClusters); } - return entry.globalCentroid(); - } - - @Override - public final void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) - throws IOException { - final FieldInfo fieldInfo = state.fieldInfos.fieldInfo(field); - if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32) == false) { - rawVectorsReader.search(field, target, knnCollector, acceptDocs); - return; + + @Override + public final FloatVectorValues getFloatVectorValues(String field) throws IOException { + return rawVectorsReader.getFloatVectorValues(field); } - int nProbe = -1; - if (knnCollector.getSearchStrategy() instanceof IVFKnnSearchStrategy ivfStrategy) { - nProbe = ivfStrategy.getNProbe(); + + @Override + public final ByteVectorValues getByteVectorValues(String field) throws IOException { + return rawVectorsReader.getByteVectorValues(field); } - float percentFiltered = 1f; - if (acceptDocs instanceof BitSet bitSet) { - percentFiltered = - Math.max(0f, Math.min(1f, (float) bitSet.approximateCardinality() / bitSet.length())); + + protected float[] getGlobalCentroid(FieldInfo info) { + if (info == null || info.getVectorEncoding().equals(VectorEncoding.BYTE)) { + return null; + } + FieldEntry entry = fields.get(info.number); + 
if (entry == null) { + return null; + } + return entry.globalCentroid(); } - int numVectors = rawVectorsReader.getFloatVectorValues(field).size(); - BitSet visitedDocs = new FixedBitSet(state.segmentInfo.maxDoc() + 1); - // TODO can we make a conjunction between idSetIterator and the acceptDocs? - IntPredicate needsScoring = - docId -> { - if (acceptDocs != null && acceptDocs.get(docId) == false) { - return false; - } - return visitedDocs.getAndSet(docId) == false; + + @Override + public final void search(String field, float[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { + final FieldInfo fieldInfo = state.fieldInfos.fieldInfo(field); + if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32) == false) { + rawVectorsReader.search(field, target, knnCollector, acceptDocs); + return; + } + // TODO add new ivf search strategy + int nProbe = 10; + float percentFiltered = 1f; + if (acceptDocs instanceof BitSet bitSet) { + percentFiltered = Math.max(0f, Math.min(1f, (float) bitSet.approximateCardinality() / bitSet.length())); + } + int numVectors = rawVectorsReader.getFloatVectorValues(field).size(); + BitSet visitedDocs = new FixedBitSet(state.segmentInfo.maxDoc() + 1); + IntPredicate needsScoring = docId -> { + if (acceptDocs != null && acceptDocs.get(docId) == false) { + return false; + } + return visitedDocs.getAndSet(docId) == false; }; - FieldEntry entry = fields.get(fieldInfo.number); - CentroidQueryScorer centroidQueryScorer = - getCentroidScorer( + FieldEntry entry = fields.get(fieldInfo.number); + CentroidQueryScorer centroidQueryScorer = getCentroidScorer( fieldInfo, entry.postingListOffsets.length, entry.centroidSlice(ivfCentroids), target, - ivfClusters); - int centroidsToSearch = nProbe; - if (centroidsToSearch <= 0) { - centroidsToSearch = Math.max(((knnCollector.k() * 300) / 1_000), 1); - } - final NeighborQueue centroidQueue = - scorePostingLists(fieldInfo, knnCollector, centroidQueryScorer, nProbe); - PostingVisitor 
scorer = getPostingVisitor(fieldInfo, ivfClusters, target, needsScoring); - int centroidsVisited = 0; - long expectedDocs = 0; - long actualDocs = 0; - // initially we visit only the "centroids to search" - while (centroidQueue.size() > 0 && centroidsVisited < centroidsToSearch) { - ++centroidsVisited; - // todo do we actually need to know the score??? - int centroidOrdinal = centroidQueue.pop(); - // todo do we need direct access to the raw centroid??? - expectedDocs += - scorer.resetPostingsScorer( - centroidOrdinal, centroidQueryScorer.centroid(centroidOrdinal)); - actualDocs += scorer.visit(knnCollector); + ivfClusters + ); + final NeighborQueue centroidQueue = scorePostingLists(fieldInfo, knnCollector, centroidQueryScorer, nProbe); + PostingVisitor scorer = getPostingVisitor(fieldInfo, ivfClusters, target, needsScoring); + int centroidsVisited = 0; + long expectedDocs = 0; + long actualDocs = 0; + // initially we visit only the "centroids to search" + while (centroidQueue.size() > 0 && centroidsVisited < nProbe) { + ++centroidsVisited; + // todo do we actually need to know the score??? + int centroidOrdinal = centroidQueue.pop(); + // todo do we need direct access to the raw centroid??? 
+ expectedDocs += scorer.resetPostingsScorer(centroidOrdinal, centroidQueryScorer.centroid(centroidOrdinal)); + actualDocs += scorer.visit(knnCollector); + } + if (acceptDocs != null) { + float unfilteredRatioVisited = (float) expectedDocs / numVectors; + int filteredVectors = (int) Math.ceil(numVectors * percentFiltered); + float expectedScored = Math.min(2 * filteredVectors * unfilteredRatioVisited, expectedDocs / 2f); + while (centroidQueue.size() > 0 && (actualDocs < expectedScored || actualDocs < knnCollector.k())) { + int centroidOrdinal = centroidQueue.pop(); + scorer.resetPostingsScorer(centroidOrdinal, centroidQueryScorer.centroid(centroidOrdinal)); + actualDocs += scorer.visit(knnCollector); + } + } } - if (acceptDocs != null) { - float unfilteredRatioVisited = (float) expectedDocs / numVectors; - int filteredVectors = (int) Math.ceil(numVectors * percentFiltered); - float expectedScored = - Math.min(2 * filteredVectors * unfilteredRatioVisited, expectedDocs / 2f); - while (centroidQueue.size() > 0 - && (actualDocs < expectedScored || actualDocs < knnCollector.k())) { - int centroidOrdinal = centroidQueue.pop(); - scorer.resetPostingsScorer(centroidOrdinal, centroidQueryScorer.centroid(centroidOrdinal)); - actualDocs += scorer.visit(knnCollector); - } + + @Override + public final void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) throws IOException { + final FieldInfo fieldInfo = state.fieldInfos.fieldInfo(field); + final ByteVectorValues values = rawVectorsReader.getByteVectorValues(field); + for (int i = 0; i < values.size(); i++) { + final float score = fieldInfo.getVectorSimilarityFunction().compare(target, values.vectorValue(i)); + knnCollector.collect(values.ordToDoc(i), score); + if (knnCollector.earlyTerminated()) { + return; + } + } } - } - - @Override - public final void search(String field, byte[] target, KnnCollector knnCollector, Bits acceptDocs) - throws IOException { - final FieldInfo fieldInfo = 
state.fieldInfos.fieldInfo(field); - final ByteVectorValues values = rawVectorsReader.getByteVectorValues(field); - for (int i = 0; i < values.size(); i++) { - final float score = - fieldInfo.getVectorSimilarityFunction().compare(target, values.vectorValue(i)); - knnCollector.collect(values.ordToDoc(i), score); - if (knnCollector.earlyTerminated()) { - return; - } + + abstract NeighborQueue scorePostingLists( + FieldInfo fieldInfo, + KnnCollector knnCollector, + CentroidQueryScorer centroidQueryScorer, + int nProbe + ) throws IOException; + + @Override + public void close() throws IOException { + IOUtils.close(rawVectorsReader, ivfCentroids, ivfClusters); } - } - - abstract NeighborQueue scorePostingLists( - FieldInfo fieldInfo, - KnnCollector knnCollector, - CentroidQueryScorer centroidQueryScorer, - int nProbe) - throws IOException; - - @Override - public void close() throws IOException { - IOUtils.close(rawVectorsReader, ivfCentroids, ivfClusters); - } - - protected record FieldEntry( - VectorSimilarityFunction similarityFunction, - VectorEncoding vectorEncoding, - long centroidOffset, - long centroidLength, - long[] postingListOffsets, - float[] globalCentroid, - float globalCentroidDp) { - IndexInput centroidSlice(IndexInput centroidFile) throws IOException { - return centroidFile.slice("centroids", centroidOffset, centroidLength); + + protected record FieldEntry( + VectorSimilarityFunction similarityFunction, + VectorEncoding vectorEncoding, + long centroidOffset, + long centroidLength, + long[] postingListOffsets, + float[] globalCentroid, + float globalCentroidDp + ) { + IndexInput centroidSlice(IndexInput centroidFile) throws IOException { + return centroidFile.slice("centroids", centroidOffset, centroidLength); + } } - } - abstract PostingVisitor getPostingVisitor( - FieldInfo fieldInfo, IndexInput postingsLists, float[] target, IntPredicate needsScoring) - throws IOException; + abstract PostingVisitor getPostingVisitor(FieldInfo fieldInfo, IndexInput 
postingsLists, float[] target, IntPredicate needsScoring) + throws IOException; - interface CentroidQueryScorer { - int size(); + interface CentroidQueryScorer { + int size(); - float[] centroid(int centroidOrdinal) throws IOException; + float[] centroid(int centroidOrdinal) throws IOException; - float score(int centroidOrdinal) throws IOException; - } + float score(int centroidOrdinal) throws IOException; + } - interface PostingVisitor { - // TODO maybe we can not specifically pass the centroid... + interface PostingVisitor { + // TODO maybe we can not specifically pass the centroid... - /** returns the number of documents in the posting list */ - int resetPostingsScorer(int centroidOrdinal, float[] centroid) throws IOException; + /** returns the number of documents in the posting list */ + int resetPostingsScorer(int centroidOrdinal, float[] centroid) throws IOException; - /** returns the number of scored documents */ - int visit(KnnCollector collector) throws IOException; - } + /** returns the number of scored documents */ + int visit(KnnCollector collector) throws IOException; + } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java index 4011576ecd47f..07e0a389c47d1 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java @@ -7,18 +7,8 @@ * License v3.0 only", or the "Server Side Public License, v 1". 
*/ - package org.elasticsearch.index.codec.vectors; -import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS; -import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; - -import java.io.IOException; -import java.io.UncheckedIOException; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.util.ArrayList; -import java.util.List; import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.codecs.KnnFieldVectorsWriter; import org.apache.lucene.codecs.KnnVectorsReader; @@ -43,456 +33,479 @@ import org.apache.lucene.util.InfoStream; import org.apache.lucene.util.VectorUtil; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.ArrayList; +import java.util.List; + +import static org.apache.lucene.codecs.lucene99.Lucene99HnswVectorsReader.SIMILARITY_FUNCTIONS; +import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; + /** - * @lucene.experimental + * Base class for IVF vectors writer. 
*/ public abstract class IVFVectorsWriter extends KnnVectorsWriter { - private final List fieldWriters = new ArrayList<>(); - private final IndexOutput ivfCentroids, ivfClusters; - private final IndexOutput ivfMeta; - private final FlatVectorsWriter rawVectorDelegate; - private final SegmentWriteState segmentWriteState; - - protected IVFVectorsWriter(SegmentWriteState state, FlatVectorsWriter rawVectorDelegate) - throws IOException { - this.segmentWriteState = state; - this.rawVectorDelegate = rawVectorDelegate; - final String metaFileName = - IndexFileNames.segmentFileName( - state.segmentInfo.name, state.segmentSuffix, IVFVectorsFormat.IVF_META_EXTENSION); - - final String ivfCentroidsFileName = - IndexFileNames.segmentFileName( - state.segmentInfo.name, state.segmentSuffix, IVFVectorsFormat.CENTROID_EXTENSION); - final String ivfClustersFileName = - IndexFileNames.segmentFileName( - state.segmentInfo.name, state.segmentSuffix, IVFVectorsFormat.CLUSTER_EXTENSION); - boolean success = false; - try { - ivfMeta = state.directory.createOutput(metaFileName, state.context); - CodecUtil.writeIndexHeader( - ivfMeta, - IVFVectorsFormat.NAME, - IVFVectorsFormat.VERSION_CURRENT, - state.segmentInfo.getId(), - state.segmentSuffix); - ivfCentroids = state.directory.createOutput(ivfCentroidsFileName, state.context); - CodecUtil.writeIndexHeader( - ivfCentroids, - IVFVectorsFormat.NAME, - IVFVectorsFormat.VERSION_CURRENT, - state.segmentInfo.getId(), - state.segmentSuffix); - ivfClusters = state.directory.createOutput(ivfClustersFileName, state.context); - CodecUtil.writeIndexHeader( - ivfClusters, - IVFVectorsFormat.NAME, - IVFVectorsFormat.VERSION_CURRENT, - state.segmentInfo.getId(), - state.segmentSuffix); - success = true; - } finally { - if (success == false) { - IOUtils.closeWhileHandlingException(this); - } + private final List fieldWriters = new ArrayList<>(); + private final IndexOutput ivfCentroids, ivfClusters; + private final IndexOutput ivfMeta; + private final 
FlatVectorsWriter rawVectorDelegate; + private final SegmentWriteState segmentWriteState; + + protected IVFVectorsWriter(SegmentWriteState state, FlatVectorsWriter rawVectorDelegate) throws IOException { + this.segmentWriteState = state; + this.rawVectorDelegate = rawVectorDelegate; + final String metaFileName = IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + IVFVectorsFormat.IVF_META_EXTENSION + ); + + final String ivfCentroidsFileName = IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + IVFVectorsFormat.CENTROID_EXTENSION + ); + final String ivfClustersFileName = IndexFileNames.segmentFileName( + state.segmentInfo.name, + state.segmentSuffix, + IVFVectorsFormat.CLUSTER_EXTENSION + ); + boolean success = false; + try { + ivfMeta = state.directory.createOutput(metaFileName, state.context); + CodecUtil.writeIndexHeader( + ivfMeta, + IVFVectorsFormat.NAME, + IVFVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix + ); + ivfCentroids = state.directory.createOutput(ivfCentroidsFileName, state.context); + CodecUtil.writeIndexHeader( + ivfCentroids, + IVFVectorsFormat.NAME, + IVFVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix + ); + ivfClusters = state.directory.createOutput(ivfClustersFileName, state.context); + CodecUtil.writeIndexHeader( + ivfClusters, + IVFVectorsFormat.NAME, + IVFVectorsFormat.VERSION_CURRENT, + state.segmentInfo.getId(), + state.segmentSuffix + ); + success = true; + } finally { + if (success == false) { + IOUtils.closeWhileHandlingException(this); + } + } } - } - @Override - public final KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException { - if (fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE) { - throw new IllegalArgumentException("IVF does not support cosine similarity"); - } - final FlatFieldVectorsWriter rawVectorDelegate = this.rawVectorDelegate.addField(fieldInfo); - 
if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { - @SuppressWarnings("unchecked") - final FlatFieldVectorsWriter floatWriter = - (FlatFieldVectorsWriter) rawVectorDelegate; - fieldWriters.add(new FieldWriter(fieldInfo, floatWriter)); - } - return rawVectorDelegate; - } - - protected abstract int calculateAndWriteCentroids( - FieldInfo fieldInfo, - FloatVectorValues floatVectorValues, - IndexOutput temporaryCentroidOutput, - MergeState mergeState, - float[] globalCentroid) - throws IOException; - - abstract long[] buildAndWritePostingsLists( - FieldInfo fieldInfo, - CentroidAssignmentScorer scorer, - FloatVectorValues floatVectorValues, - IndexOutput postingsOutput, - MergeState mergeState) - throws IOException; - - abstract CentroidAssignmentScorer calculateAndWriteCentroids( - FieldInfo fieldInfo, - FloatVectorValues floatVectorValues, - IndexOutput centroidOutput, - float[] globalCentroid) - throws IOException; - - abstract long[] buildAndWritePostingsLists( - FieldInfo fieldInfo, - InfoStream infoStream, - CentroidAssignmentScorer scorer, - FloatVectorValues floatVectorValues, - IndexOutput postingsOutput) - throws IOException; - - abstract CentroidAssignmentScorer createCentroidScorer( - IndexInput centroidsInput, int numCentroids, FieldInfo fieldInfo, float[] globalCentroid) - throws IOException; - - @Override - public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { - rawVectorDelegate.flush(maxDoc, sortMap); - for (FieldWriter fieldWriter : fieldWriters) { - float[] globalCentroid = new float[fieldWriter.fieldInfo.getVectorDimension()]; - ESVectorUtil.calculateCentroid(fieldWriter.delegate().getVectors(), globalCentroid); - // build a float vector values with random access - final FloatVectorValues floatVectorValues = - getFloatVectorValues(fieldWriter.fieldInfo, fieldWriter.delegate, maxDoc); - // build centroids - long centroidOffset = ivfCentroids.alignFilePointer(Float.BYTES); - final CentroidAssignmentScorer 
centroidAssignmentScorer = - calculateAndWriteCentroids( - fieldWriter.fieldInfo, floatVectorValues, ivfCentroids, globalCentroid); - long centroidLength = ivfCentroids.getFilePointer() - centroidOffset; - final long[] offsets = - buildAndWritePostingsLists( - fieldWriter.fieldInfo, - segmentWriteState.infoStream, - centroidAssignmentScorer, - floatVectorValues, - ivfClusters); - // write posting lists - writeMeta(fieldWriter.fieldInfo, centroidOffset, centroidLength, offsets, globalCentroid); - } - } - - private static FloatVectorValues getFloatVectorValues( - FieldInfo fieldInfo, FlatFieldVectorsWriter fieldVectorsWriter, int maxDoc) - throws IOException { - List vectors = fieldVectorsWriter.getVectors(); - if (vectors.size() == maxDoc) { - return FloatVectorValues.fromFloats(vectors, fieldInfo.getVectorDimension()); - } - final DocIdSetIterator iterator = fieldVectorsWriter.getDocsWithFieldSet().iterator(); - final int[] docIds = new int[vectors.size()]; - for (int i = 0; i < docIds.length; i++) { - docIds[i] = iterator.nextDoc(); - } - assert iterator.nextDoc() == NO_MORE_DOCS; - return new FloatVectorValues() { - @Override - public float[] vectorValue(int ord) { - return vectors.get(ord); - } - - @Override - public FloatVectorValues copy() { - return this; - } - - @Override - public int dimension() { - return fieldInfo.getVectorDimension(); - } - - @Override - public int size() { - return vectors.size(); - } - - @Override - public int ordToDoc(int ord) { - return docIds[ord]; - } - }; - } - - static IVFVectorsReader getIVFReader(KnnVectorsReader vectorsReader, String fieldName) { - if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) { - vectorsReader = candidateReader.getFieldReader(fieldName); - } - if (vectorsReader instanceof IVFVectorsReader reader) { - return reader; - } - return null; - } - - @Override - public final void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException { - 
rawVectorDelegate.mergeOneField(fieldInfo, mergeState); - if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { - final int numVectors; - String name = null; - boolean success = false; - // build a float vector values with random access. In order to do that we dump the vectors to - // a temporary file - // and write the docID follow by the vector - try (IndexOutput out = - mergeState.segmentInfo.dir.createTempOutput( - mergeState.segmentInfo.name, "ivf_", IOContext.DEFAULT)) { - name = out.getName(); - // TODO do this better, we shouldn't have to write to a temp file, we should be able to - // to just from the merged vector values. - numVectors = - writeFloatVectorValues( - fieldInfo, out, MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)); - success = true; - } finally { - if (success == false && name != null) { - IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, name); + @Override + public final KnnFieldVectorsWriter addField(FieldInfo fieldInfo) throws IOException { + if (fieldInfo.getVectorSimilarityFunction() == VectorSimilarityFunction.COSINE) { + throw new IllegalArgumentException("IVF does not support cosine similarity"); + } + final FlatFieldVectorsWriter rawVectorDelegate = this.rawVectorDelegate.addField(fieldInfo); + if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { + @SuppressWarnings("unchecked") + final FlatFieldVectorsWriter floatWriter = (FlatFieldVectorsWriter) rawVectorDelegate; + fieldWriters.add(new FieldWriter(fieldInfo, floatWriter)); } - } - float[] globalCentroid = new float[fieldInfo.getVectorDimension()]; - int vectorCount = 0; - for (var knnReaders : mergeState.knnVectorsReaders) { - IVFVectorsReader ivfReader = getIVFReader(knnReaders, fieldInfo.name); - if (ivfReader != null) { - int numVecs = ivfReader.getFloatVectorValues(fieldInfo.name).size(); - float[] readerGlobalCentroid = ivfReader.getGlobalCentroid(fieldInfo); - if (readerGlobalCentroid != null) { - vectorCount += 
numVecs; + return rawVectorDelegate; + } + + protected abstract int calculateAndWriteCentroids( + FieldInfo fieldInfo, + FloatVectorValues floatVectorValues, + IndexOutput temporaryCentroidOutput, + MergeState mergeState, + float[] globalCentroid + ) throws IOException; + + abstract long[] buildAndWritePostingsLists( + FieldInfo fieldInfo, + CentroidAssignmentScorer scorer, + FloatVectorValues floatVectorValues, + IndexOutput postingsOutput, + MergeState mergeState + ) throws IOException; + + abstract CentroidAssignmentScorer calculateAndWriteCentroids( + FieldInfo fieldInfo, + FloatVectorValues floatVectorValues, + IndexOutput centroidOutput, + float[] globalCentroid + ) throws IOException; + + abstract long[] buildAndWritePostingsLists( + FieldInfo fieldInfo, + InfoStream infoStream, + CentroidAssignmentScorer scorer, + FloatVectorValues floatVectorValues, + IndexOutput postingsOutput + ) throws IOException; + + abstract CentroidAssignmentScorer createCentroidScorer( + IndexInput centroidsInput, + int numCentroids, + FieldInfo fieldInfo, + float[] globalCentroid + ) throws IOException; + + @Override + public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { + rawVectorDelegate.flush(maxDoc, sortMap); + for (FieldWriter fieldWriter : fieldWriters) { + float[] globalCentroid = new float[fieldWriter.fieldInfo.getVectorDimension()]; + // calculate global centroid + for (var vector : fieldWriter.delegate.getVectors()) { + for (int i = 0; i < globalCentroid.length; i++) { + globalCentroid[i] += vector[i]; + } + } for (int i = 0; i < globalCentroid.length; i++) { - globalCentroid[i] += readerGlobalCentroid[i] * numVecs; + globalCentroid[i] /= fieldWriter.delegate.getVectors().size(); } - } - } - } - if (vectorCount > 0) { - for (int i = 0; i < globalCentroid.length; i++) { - globalCentroid[i] /= vectorCount; - } - } - try (IndexInput in = mergeState.segmentInfo.dir.openInput(name, IOContext.DEFAULT)) { - final FloatVectorValues floatVectorValues = 
getFloatVectorValues(fieldInfo, in, numVectors); - success = false; - CentroidAssignmentScorer centroidAssignmentScorer; - long centroidOffset; - long centroidLength; - String centroidTempName = null; - int numCentroids; - IndexOutput centroidTemp = null; - try { - centroidTemp = - mergeState.segmentInfo.dir.createTempOutput( - mergeState.segmentInfo.name, "civf_", IOContext.DEFAULT); - centroidTempName = centroidTemp.getName(); - numCentroids = - calculateAndWriteCentroids( - fieldInfo, floatVectorValues, centroidTemp, mergeState, globalCentroid); - success = true; - } finally { - if (success == false && centroidTempName != null) { - IOUtils.closeWhileHandlingException(centroidTemp); - IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, centroidTempName); - } - } - try { - if (numCentroids == 0) { - centroidOffset = ivfCentroids.getFilePointer(); - writeMeta(fieldInfo, centroidOffset, 0, new long[0], null); - CodecUtil.writeFooter(centroidTemp); - IOUtils.close(centroidTemp); - return; - } - CodecUtil.writeFooter(centroidTemp); - IOUtils.close(centroidTemp); - centroidOffset = ivfCentroids.alignFilePointer(Float.BYTES); - try (IndexInput centroidInput = - mergeState.segmentInfo.dir.openInput(centroidTempName, IOContext.DEFAULT)) { - ivfCentroids.copyBytes( - centroidInput, centroidInput.length() - CodecUtil.footerLength()); - centroidLength = ivfCentroids.getFilePointer() - centroidOffset; - centroidAssignmentScorer = - createCentroidScorer(centroidInput, numCentroids, fieldInfo, globalCentroid); - assert centroidAssignmentScorer.size() == numCentroids; // build a float vector values with random access + final FloatVectorValues floatVectorValues = getFloatVectorValues(fieldWriter.fieldInfo, fieldWriter.delegate, maxDoc); // build centroids - final long[] offsets = - buildAndWritePostingsLists( - fieldInfo, - centroidAssignmentScorer, - floatVectorValues, - ivfClusters, - mergeState); + long centroidOffset = ivfCentroids.alignFilePointer(Float.BYTES); 
+ final CentroidAssignmentScorer centroidAssignmentScorer = calculateAndWriteCentroids( + fieldWriter.fieldInfo, + floatVectorValues, + ivfCentroids, + globalCentroid + ); + long centroidLength = ivfCentroids.getFilePointer() - centroidOffset; + final long[] offsets = buildAndWritePostingsLists( + fieldWriter.fieldInfo, + segmentWriteState.infoStream, + centroidAssignmentScorer, + floatVectorValues, + ivfClusters + ); // write posting lists + writeMeta(fieldWriter.fieldInfo, centroidOffset, centroidLength, offsets, globalCentroid); + } + } + + private static FloatVectorValues getFloatVectorValues( + FieldInfo fieldInfo, + FlatFieldVectorsWriter fieldVectorsWriter, + int maxDoc + ) throws IOException { + List vectors = fieldVectorsWriter.getVectors(); + if (vectors.size() == maxDoc) { + return FloatVectorValues.fromFloats(vectors, fieldInfo.getVectorDimension()); + } + final DocIdSetIterator iterator = fieldVectorsWriter.getDocsWithFieldSet().iterator(); + final int[] docIds = new int[vectors.size()]; + for (int i = 0; i < docIds.length; i++) { + docIds[i] = iterator.nextDoc(); + } + assert iterator.nextDoc() == NO_MORE_DOCS; + return new FloatVectorValues() { + @Override + public float[] vectorValue(int ord) { + return vectors.get(ord); + } - // TODO handle this correctly by creating new centroid - if (vectorCount == 0 && offsets.length > 0) { - throw new IllegalStateException( - "No global centroid found for field: " + fieldInfo.name); + @Override + public FloatVectorValues copy() { + return this; } - assert offsets.length == centroidAssignmentScorer.size(); - writeMeta(fieldInfo, centroidOffset, centroidLength, offsets, globalCentroid); - } - } finally { - IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, name); - IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, centroidTempName); + + @Override + public int dimension() { + return fieldInfo.getVectorDimension(); + } + + @Override + public int size() { + return vectors.size(); + } + 
+ @Override + public int ordToDoc(int ord) { + return docIds[ord]; + } + }; + } + + static IVFVectorsReader getIVFReader(KnnVectorsReader vectorsReader, String fieldName) { + if (vectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader candidateReader) { + vectorsReader = candidateReader.getFieldReader(fieldName); + } + if (vectorsReader instanceof IVFVectorsReader reader) { + return reader; } - } finally { - IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, name); - } + return null; } - } - - private static FloatVectorValues getFloatVectorValues( - FieldInfo fieldInfo, IndexInput randomAccessInput, int numVectors) { - final long length = (long) Float.BYTES * fieldInfo.getVectorDimension() + Integer.BYTES; - final float[] vector = new float[fieldInfo.getVectorDimension()]; - return new FloatVectorValues() { - @Override - public float[] vectorValue(int ord) throws IOException { - randomAccessInput.seek(ord * length + Integer.BYTES); - randomAccessInput.readFloats(vector, 0, vector.length); - return vector; - } - - @Override - public FloatVectorValues copy() { - return this; - } - - @Override - public int dimension() { - return fieldInfo.getVectorDimension(); - } - - @Override - public int size() { - return numVectors; - } - @Override - public int ordToDoc(int ord) { - try { - randomAccessInput.seek(ord * length); - return randomAccessInput.readInt(); - } catch (IOException e) { - throw new UncheckedIOException(e); + @Override + public final void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException { + rawVectorDelegate.mergeOneField(fieldInfo, mergeState); + if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { + final int numVectors; + String name = null; + boolean success = false; + // build a float vector values with random access. 
In order to do that we dump the vectors to + // a temporary file + // and write the docID follow by the vector + try (IndexOutput out = mergeState.segmentInfo.dir.createTempOutput(mergeState.segmentInfo.name, "ivf_", IOContext.DEFAULT)) { + name = out.getName(); + // TODO do this better, we shouldn't have to write to a temp file, we should be able to + // to just from the merged vector values. + numVectors = writeFloatVectorValues(fieldInfo, out, MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)); + success = true; + } finally { + if (success == false && name != null) { + IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, name); + } + } + float[] globalCentroid = new float[fieldInfo.getVectorDimension()]; + int vectorCount = 0; + for (int idx = 0; idx < mergeState.knnVectorsReaders.length; idx++) { + if (mergeState.fieldInfos[idx] == null + || mergeState.fieldInfos[idx].hasVectorValues() == false + || mergeState.fieldInfos[idx].fieldInfo(fieldInfo.name) == null + || mergeState.fieldInfos[idx].fieldInfo(fieldInfo.name).hasVectorValues() == false) { + continue; + } + KnnVectorsReader knnReaders = mergeState.knnVectorsReaders[idx]; + IVFVectorsReader ivfReader = getIVFReader(knnReaders, fieldInfo.name); + if (ivfReader != null) { + FloatVectorValues floatVectorValues = knnReaders.getFloatVectorValues(fieldInfo.name); + if (floatVectorValues == null) { + continue; + } + int numVecs = floatVectorValues.size(); + float[] readerGlobalCentroid = ivfReader.getGlobalCentroid(fieldInfo); + if (readerGlobalCentroid != null) { + vectorCount += numVecs; + for (int i = 0; i < globalCentroid.length; i++) { + globalCentroid[i] += readerGlobalCentroid[i] * numVecs; + } + } + } + } + if (vectorCount > 0) { + for (int i = 0; i < globalCentroid.length; i++) { + globalCentroid[i] /= vectorCount; + } + } + try (IndexInput in = mergeState.segmentInfo.dir.openInput(name, IOContext.DEFAULT)) { + final FloatVectorValues floatVectorValues = 
getFloatVectorValues(fieldInfo, in, numVectors); + success = false; + CentroidAssignmentScorer centroidAssignmentScorer; + long centroidOffset; + long centroidLength; + String centroidTempName = null; + int numCentroids; + IndexOutput centroidTemp = null; + try { + centroidTemp = mergeState.segmentInfo.dir.createTempOutput(mergeState.segmentInfo.name, "civf_", IOContext.DEFAULT); + centroidTempName = centroidTemp.getName(); + numCentroids = calculateAndWriteCentroids(fieldInfo, floatVectorValues, centroidTemp, mergeState, globalCentroid); + success = true; + } finally { + if (success == false && centroidTempName != null) { + IOUtils.closeWhileHandlingException(centroidTemp); + IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, centroidTempName); + } + } + try { + if (numCentroids == 0) { + centroidOffset = ivfCentroids.getFilePointer(); + writeMeta(fieldInfo, centroidOffset, 0, new long[0], null); + CodecUtil.writeFooter(centroidTemp); + IOUtils.close(centroidTemp); + return; + } + CodecUtil.writeFooter(centroidTemp); + IOUtils.close(centroidTemp); + centroidOffset = ivfCentroids.alignFilePointer(Float.BYTES); + try (IndexInput centroidInput = mergeState.segmentInfo.dir.openInput(centroidTempName, IOContext.DEFAULT)) { + ivfCentroids.copyBytes(centroidInput, centroidInput.length() - CodecUtil.footerLength()); + centroidLength = ivfCentroids.getFilePointer() - centroidOffset; + centroidAssignmentScorer = createCentroidScorer(centroidInput, numCentroids, fieldInfo, globalCentroid); + assert centroidAssignmentScorer.size() == numCentroids; + // build a float vector values with random access + // build centroids + final long[] offsets = buildAndWritePostingsLists( + fieldInfo, + centroidAssignmentScorer, + floatVectorValues, + ivfClusters, + mergeState + ); + // write posting lists + + // TODO handle this correctly by creating new centroid + if (vectorCount == 0 && offsets.length > 0) { + throw new IllegalStateException("No global centroid found for 
field: " + fieldInfo.name); + } + assert offsets.length == centroidAssignmentScorer.size(); + writeMeta(fieldInfo, centroidOffset, centroidLength, offsets, globalCentroid); + } + } finally { + IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, name); + IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, centroidTempName); + } + } finally { + IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, name); + } } - } - }; - } - - private static int writeFloatVectorValues( - FieldInfo fieldInfo, IndexOutput out, FloatVectorValues floatVectorValues) - throws IOException { - int numVectors = 0; - final ByteBuffer buffer = - ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES) - .order(ByteOrder.LITTLE_ENDIAN); - final KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator(); - for (int docV = iterator.nextDoc(); docV != NO_MORE_DOCS; docV = iterator.nextDoc()) { - numVectors++; - float[] vector = floatVectorValues.vectorValue(iterator.index()); - out.writeInt(iterator.docID()); - buffer.asFloatBuffer().put(vector); - out.writeBytes(buffer.array(), buffer.array().length); } - return numVectors; - } - - private void writeMeta( - FieldInfo field, - long centroidOffset, - long centroidLength, - long[] offsets, - float[] globalCentroid) - throws IOException { - ivfMeta.writeInt(field.number); - ivfMeta.writeInt(field.getVectorEncoding().ordinal()); - ivfMeta.writeInt(distFuncToOrd(field.getVectorSimilarityFunction())); - ivfMeta.writeLong(centroidOffset); - ivfMeta.writeLong(centroidLength); - ivfMeta.writeVInt(offsets.length); - for (long offset : offsets) { - ivfMeta.writeLong(offset); + + private static FloatVectorValues getFloatVectorValues(FieldInfo fieldInfo, IndexInput randomAccessInput, int numVectors) { + if (numVectors == 0) { + return FloatVectorValues.fromFloats(List.of(), fieldInfo.getVectorDimension()); + } + final long length = (long) Float.BYTES * fieldInfo.getVectorDimension() + Integer.BYTES; 
+ final float[] vector = new float[fieldInfo.getVectorDimension()]; + return new FloatVectorValues() { + @Override + public float[] vectorValue(int ord) throws IOException { + randomAccessInput.seek(ord * length + Integer.BYTES); + randomAccessInput.readFloats(vector, 0, vector.length); + return vector; + } + + @Override + public FloatVectorValues copy() { + return this; + } + + @Override + public int dimension() { + return fieldInfo.getVectorDimension(); + } + + @Override + public int size() { + return numVectors; + } + + @Override + public int ordToDoc(int ord) { + try { + randomAccessInput.seek(ord * length); + return randomAccessInput.readInt(); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + }; + } + + private static int writeFloatVectorValues(FieldInfo fieldInfo, IndexOutput out, FloatVectorValues floatVectorValues) + throws IOException { + int numVectors = 0; + final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); + final KnnVectorValues.DocIndexIterator iterator = floatVectorValues.iterator(); + for (int docV = iterator.nextDoc(); docV != NO_MORE_DOCS; docV = iterator.nextDoc()) { + numVectors++; + float[] vector = floatVectorValues.vectorValue(iterator.index()); + out.writeInt(iterator.docID()); + buffer.asFloatBuffer().put(vector); + out.writeBytes(buffer.array(), buffer.array().length); + } + return numVectors; } - if (offsets.length > 0) { - final ByteBuffer buffer = - ByteBuffer.allocate(globalCentroid.length * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); - buffer.asFloatBuffer().put(globalCentroid); - ivfMeta.writeBytes(buffer.array(), buffer.array().length); - ivfMeta.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(globalCentroid, globalCentroid))); + + private void writeMeta(FieldInfo field, long centroidOffset, long centroidLength, long[] offsets, float[] globalCentroid) + throws IOException { + ivfMeta.writeInt(field.number); + 
ivfMeta.writeInt(field.getVectorEncoding().ordinal()); + ivfMeta.writeInt(distFuncToOrd(field.getVectorSimilarityFunction())); + ivfMeta.writeLong(centroidOffset); + ivfMeta.writeLong(centroidLength); + ivfMeta.writeVInt(offsets.length); + for (long offset : offsets) { + ivfMeta.writeLong(offset); + } + if (offsets.length > 0) { + final ByteBuffer buffer = ByteBuffer.allocate(globalCentroid.length * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); + buffer.asFloatBuffer().put(globalCentroid); + ivfMeta.writeBytes(buffer.array(), buffer.array().length); + ivfMeta.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(globalCentroid, globalCentroid))); + } } - } - private static int distFuncToOrd(VectorSimilarityFunction func) { - for (int i = 0; i < SIMILARITY_FUNCTIONS.size(); i++) { - if (SIMILARITY_FUNCTIONS.get(i).equals(func)) { - return (byte) i; - } + private static int distFuncToOrd(VectorSimilarityFunction func) { + for (int i = 0; i < SIMILARITY_FUNCTIONS.size(); i++) { + if (SIMILARITY_FUNCTIONS.get(i).equals(func)) { + return (byte) i; + } + } + throw new IllegalArgumentException("invalid distance function: " + func); } - throw new IllegalArgumentException("invalid distance function: " + func); - } - - @Override - public final void finish() throws IOException { - rawVectorDelegate.finish(); - if (ivfMeta != null) { - // write end of fields marker - ivfMeta.writeInt(-1); - CodecUtil.writeFooter(ivfMeta); + + @Override + public final void finish() throws IOException { + rawVectorDelegate.finish(); + if (ivfMeta != null) { + // write end of fields marker + ivfMeta.writeInt(-1); + CodecUtil.writeFooter(ivfMeta); + } + if (ivfCentroids != null) { + CodecUtil.writeFooter(ivfCentroids); + } + if (ivfClusters != null) { + CodecUtil.writeFooter(ivfClusters); + } } - if (ivfCentroids != null) { - CodecUtil.writeFooter(ivfCentroids); + + @Override + public final void close() throws IOException { + IOUtils.close(rawVectorDelegate, ivfMeta, ivfCentroids, ivfClusters); } - 
if (ivfClusters != null) { - CodecUtil.writeFooter(ivfClusters); + + @Override + public final long ramBytesUsed() { + return rawVectorDelegate.ramBytesUsed(); } - } - - @Override - public final void close() throws IOException { - IOUtils.close(rawVectorDelegate, ivfMeta, ivfCentroids, ivfClusters); - } - - @Override - public final long ramBytesUsed() { - return rawVectorDelegate.ramBytesUsed(); - } - - private record FieldWriter(FieldInfo fieldInfo, FlatFieldVectorsWriter delegate) {} - - interface CentroidAssignmentScorer { - CentroidAssignmentScorer EMPTY = - new CentroidAssignmentScorer() { - @Override - public int size() { - return 0; - } - - @Override - public float[] centroid(int centroidOrdinal) { - throw new IllegalStateException("No centroids"); - } - - @Override - public float score(int centroidOrdinal) { - throw new IllegalStateException("No centroids"); - } - - @Override - public void setScoringVector(float[] vector) { - throw new IllegalStateException("No centroids"); - } + + private record FieldWriter(FieldInfo fieldInfo, FlatFieldVectorsWriter delegate) {} + + interface CentroidAssignmentScorer { + CentroidAssignmentScorer EMPTY = new CentroidAssignmentScorer() { + @Override + public int size() { + return 0; + } + + @Override + public float[] centroid(int centroidOrdinal) { + throw new IllegalStateException("No centroids"); + } + + @Override + public float score(int centroidOrdinal) { + throw new IllegalStateException("No centroids"); + } + + @Override + public void setScoringVector(float[] vector) { + throw new IllegalStateException("No centroids"); + } }; - int size(); + int size(); - float[] centroid(int centroidOrdinal) throws IOException; + float[] centroid(int centroidOrdinal) throws IOException; - void setScoringVector(float[] vector); + void setScoringVector(float[] vector); - float score(int centroidOrdinal) throws IOException; - } + float score(int centroidOrdinal) throws IOException; + } } diff --git 
a/server/src/main/java/org/elasticsearch/index/codec/vectors/NeighborQueue.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/NeighborQueue.java index ce6f8d07baeb6..b4156988d2788 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/NeighborQueue.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/NeighborQueue.java @@ -27,136 +27,120 @@ */ class NeighborQueue { - private enum Order { - MIN_HEAP { - @Override - long apply(long v) { - return v; - } - }, - MAX_HEAP { - @Override - long apply(long v) { - // This cannot be just `-v` since Long.MIN_VALUE doesn't have a positive counterpart. It - // needs a function that returns MAX_VALUE for MIN_VALUE and vice-versa. - return -1 - v; - } - }; - - abstract long apply(long v); - } - - private final LongHeap heap; - private final Order order; - - NeighborQueue(int initialSize, boolean maxHeap) { - this.heap = new LongHeap(initialSize); - this.order = maxHeap ? Order.MAX_HEAP : Order.MIN_HEAP; - } - - /** - * @return the number of elements in the heap - */ - public int size() { - return heap.size(); - } - - /** - * Adds a new graph arc, extending the storage as needed. - * - * @param newNode the neighbor node id - * @param newScore the score of the neighbor, relative to some other node - */ - public void add(int newNode, float newScore) { - heap.push(encode(newNode, newScore)); - } - - /** - * If the heap is not full (size is less than the initialSize provided to the constructor), adds a - * new node-and-score element. If the heap is full, compares the score against the current top - * score, and replaces the top element if newScore is better than (greater than unless the heap is - * reversed), the current top score. 
- * - * @param newNode the neighbor node id - * @param newScore the score of the neighbor, relative to some other node - */ - public boolean insertWithOverflow(int newNode, float newScore) { - return heap.insertWithOverflow(encode(newNode, newScore)); - } - - /** - * Encodes the node ID and its similarity score as long, preserving the Lucene tie-breaking rule - * that when two scores are equal, the smaller node ID must win. - * - *

    The most significant 32 bits represent the float score, encoded as a sortable int. - * - *

    The least significant 32 bits represent the node ID. - * - *

    The bits representing the node ID are complemented to guarantee the win for the smaller node - * Id. - * - *

    The AND with 0xFFFFFFFFL (a long with first 32 bits as 1) is necessary to obtain a long that - * has: - *

  • The most significant 32 bits set to 0 - *
  • The least significant 32 bits represent the node ID. - * - * @param node the node ID - * @param score the node score - * @return the encoded score, node ID - */ - private long encode(int node, float score) { - return order.apply( - (((long) NumericUtils.floatToSortableInt(score)) << 32) | (0xFFFFFFFFL & ~node)); - } - - private float decodeScore(long heapValue) { - return NumericUtils.sortableIntToFloat((int) (order.apply(heapValue) >> 32)); - } - - private int decodeNodeId(long heapValue) { - return (int) ~(order.apply(heapValue)); - } - - /** Removes the top element and returns its node id. */ - public int pop() { - return decodeNodeId(heap.pop()); - } - - public void consumeNodes(int[] dest) { - if (dest.length < size()) { - throw new IllegalArgumentException( - "Destination array is too small. Expected at least " + size() + " elements."); + private enum Order { + MIN_HEAP { + @Override + long apply(long v) { + return v; + } + }, + MAX_HEAP { + @Override + long apply(long v) { + // This cannot be just `-v` since Long.MIN_VALUE doesn't have a positive counterpart. It + // needs a function that returns MAX_VALUE for MIN_VALUE and vice-versa. + return -1 - v; + } + }; + + abstract long apply(long v); } - for (int i = 0; i < size(); i++) { - dest[i] = decodeNodeId(heap.get(i + 1)); + + private final LongHeap heap; + private final Order order; + + NeighborQueue(int initialSize, boolean maxHeap) { + this.heap = new LongHeap(initialSize); + this.order = maxHeap ? Order.MAX_HEAP : Order.MIN_HEAP; + } + + /** + * @return the number of elements in the heap + */ + public int size() { + return heap.size(); + } + + /** + * Adds a new graph arc, extending the storage as needed. 
+ * + * @param newNode the neighbor node id + * @param newScore the score of the neighbor, relative to some other node + */ + public void add(int newNode, float newScore) { + heap.push(encode(newNode, newScore)); } - } - public int consumeNodesAndScoresMin(int[] dest, float[] scores) { - if (dest.length < size() || scores.length < size()) { - throw new IllegalArgumentException( - "Destination array is too small. Expected at least " + size() + " elements."); + /** + * If the heap is not full (size is less than the initialSize provided to the constructor), adds a + * new node-and-score element. If the heap is full, compares the score against the current top + * score, and replaces the top element if newScore is better than (greater than unless the heap is + * reversed), the current top score. + * + * @param newNode the neighbor node id + * @param newScore the score of the neighbor, relative to some other node + */ + public boolean insertWithOverflow(int newNode, float newScore) { + return heap.insertWithOverflow(encode(newNode, newScore)); } - float bestScore = Float.POSITIVE_INFINITY; - int bestIdx = 0; - for (int i = 0; i < size(); i++) { - long heapValue = heap.get(i + 1); - scores[i] = decodeScore(heapValue); - dest[i] = decodeNodeId(heapValue); - if (scores[i] < bestScore) { - bestScore = scores[i]; - bestIdx = i; - } + + /** + * Encodes the node ID and its similarity score as long, preserving the Lucene tie-breaking rule + * that when two scores are equal, the smaller node ID must win. 
+ * @param node the node ID + * @param score the node score + * @return the encoded score, node ID + */ + private long encode(int node, float score) { + return order.apply((((long) NumericUtils.floatToSortableInt(score)) << 32) | (0xFFFFFFFFL & ~node)); } - return bestIdx; - } - public void clear() { - heap.clear(); - } + private float decodeScore(long heapValue) { + return NumericUtils.sortableIntToFloat((int) (order.apply(heapValue) >> 32)); + } + + private int decodeNodeId(long heapValue) { + return (int) ~(order.apply(heapValue)); + } - @Override - public String toString() { - return "Neighbors[" + heap.size() + "]"; - } + /** Removes the top element and returns its node id. */ + public int pop() { + return decodeNodeId(heap.pop()); + } + + public void consumeNodes(int[] dest) { + if (dest.length < size()) { + throw new IllegalArgumentException("Destination array is too small. Expected at least " + size() + " elements."); + } + for (int i = 0; i < size(); i++) { + dest[i] = decodeNodeId(heap.get(i + 1)); + } + } + + public int consumeNodesAndScoresMin(int[] dest, float[] scores) { + if (dest.length < size() || scores.length < size()) { + throw new IllegalArgumentException("Destination array is too small. 
Expected at least " + size() + " elements."); + } + float bestScore = Float.POSITIVE_INFINITY; + int bestIdx = 0; + for (int i = 0; i < size(); i++) { + long heapValue = heap.get(i + 1); + scores[i] = decodeScore(heapValue); + dest[i] = decodeNodeId(heapValue); + if (scores[i] < bestScore) { + bestScore = scores[i]; + bestIdx = i; + } + } + return bestIdx; + } + + public void clear() { + heap.clear(); + } + + @Override + public String toString() { + return "Neighbors[" + heap.size() + "]"; + } } diff --git a/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat b/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat index cef8d09980814..14e68029abc3b 100644 --- a/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat +++ b/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat @@ -7,3 +7,4 @@ org.elasticsearch.index.codec.vectors.es816.ES816BinaryQuantizedVectorsFormat org.elasticsearch.index.codec.vectors.es816.ES816HnswBinaryQuantizedVectorsFormat org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat +org.elasticsearch.index.codec.vectors.IVFVectorsFormat diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormatTests.java new file mode 100644 index 0000000000000..a8ecb3fce066e --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormatTests.java @@ -0,0 +1,64 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ +package org.elasticsearch.index.codec.vectors; + +import com.carrotsearch.randomizedtesting.generators.RandomPicks; + +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; +import org.apache.lucene.tests.util.TestUtil; +import org.elasticsearch.common.logging.LogConfigurator; +import org.junit.Before; + +import java.util.List; + +public class IVFVectorsFormatTests extends BaseKnnVectorsFormatTestCase { + + static { + LogConfigurator.loadLog4jPlugins(); + LogConfigurator.configureESLogging(); // native access requires logging to be initialized + } + KnnVectorsFormat format; + + @Before + @Override + public void setUp() throws Exception { + format = new IVFVectorsFormat(random().nextInt(10, 1000)); + super.setUp(); + } + + @Override + protected VectorSimilarityFunction randomSimilarity() { + return RandomPicks.randomFrom( + random(), + List.of( + VectorSimilarityFunction.DOT_PRODUCT, + VectorSimilarityFunction.EUCLIDEAN, + VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT + ) + ); + } + + @Override + protected VectorEncoding randomVectorEncoding() { + return VectorEncoding.FLOAT32; + } + + @Override + public void testSearchWithVisitedLimit() { + // ivf doesn't enforce visitation limit + } + + @Override + protected Codec getCodec() { + assumeTrue("IVF format flat enabled", IVFVectorsFormat.IVF_FORMAT_FEATURE_FLAG.isEnabled()); + return TestUtil.alwaysKnnVectorsFormat(format); + } +} From 88e0f8b89d46cf560a449a702ee55f50c2aa1a98 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Tue, 29 Apr 2025 14:39:16 -0400 Subject: [PATCH 03/11] iter --- .../vectors/DefaultIVFVectorsWriter.java | 71 ++----------------- .../index/codec/vectors/IVFVectorsWriter.java | 8 +-- 2 files changed, 11 insertions(+), 68 deletions(-) diff 
--git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java index 45fdaa9c3dec0..cb1bc33c4e79c 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java @@ -389,7 +389,9 @@ long[] buildAndWritePostingsLists( clusters[i] = new IntArrayList(floatVectorValues.size() / centroidAssignmentScorer.size() / 4); } long nanoTime = System.nanoTime(); - assignCentroidsMerge(centroidAssignmentScorer, floatVectorValues, clusters); + // Can we do a pre-filter by finding the nearest centroids to the original vector centroids? + // We need to be careful on vecOrd vs. doc as we need random access to the raw vector for posting list writing + assignCentroids(centroidAssignmentScorer, floatVectorValues, clusters); if (mergeState.infoStream.isEnabled(IVF_VECTOR_COMPONENT)) { mergeState.infoStream.message(IVF_VECTOR_COMPONENT, "assignCentroids time ms: " + ((System.nanoTime() - nanoTime) / 1000000.0)); } @@ -413,8 +415,8 @@ long[] buildAndWritePostingsLists( postingsOutput.writeVInt(size); postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid))); // TODO we might want to consider putting the docIds in a separate file - // to aid with only having to fetch vectors from slower storage when they are required - // keeping them in the same file indicates we pull the entire file into cache + // to aid with only having to fetch vectors from slower storage when they are required + // keeping them in the same file indicates we pull the entire file into cache docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster.get(j)), size, postingsOutput); writePostingList(cluster, postingsOutput, binarizedByteVectorValues); } @@ -458,14 +460,10 @@ private static void printClusterQualityStatistics(IntArrayList[] 
clusters, InfoS } static void assignCentroids(CentroidAssignmentScorer scorer, FloatVectorValues vectors, IntArrayList[] clusters) throws IOException { - short numCentroids = (short) scorer.size(); - // If soar > 0, then we actually need to apply the projection, otherwise, its just the second - // nearest centroid + int numCentroids = scorer.size(); // we at most will look at the EXT_SOAR_LIMIT_CHECK_RATIO nearest centroids if possible int soarToCheck = (int) (numCentroids * EXT_SOAR_LIMIT_CHECK_RATIO); int soarClusterCheckCount = Math.min(numCentroids - 1, soarToCheck); - // if lambda is `0`, that just means overspill to the second nearest, so we will only check the - // second nearest NeighborQueue neighborsToCheck = new NeighborQueue(soarClusterCheckCount + 1, true); OrdScoreIterator ordScoreIterator = new OrdScoreIterator(soarClusterCheckCount + 1); float[] scratch = new float[vectors.dimension()]; @@ -482,7 +480,7 @@ static void assignCentroids(CentroidAssignmentScorer scorer, FloatVectorValues v // pop the best int sz = neighborsToCheck.size(); int best = neighborsToCheck.consumeNodesAndScoresMin(ordScoreIterator.ords, ordScoreIterator.scores); - // TODO yikes.... 
+ // reset the ordScoreIterator as it has consumed the ords and scores ordScoreIterator.idx = sz; bestScore = ordScoreIterator.getScore(best); bestCentroid = ordScoreIterator.getOrd(best); @@ -508,61 +506,6 @@ static void assignCentroids(CentroidAssignmentScorer scorer, FloatVectorValues v } } - static void assignCentroidsMerge( - CentroidAssignmentScorer scorer, - FloatVectorValues vectors, - IntArrayList[] clusters - ) throws IOException { - int numCentroids = scorer.size(); - // If soar > 0, then we actually need to apply the projection, otherwise, its just the second - // nearest centroid - // we at most will look at the EXT_SOAR_LIMIT_CHECK_RATIO nearest centroids if possible - int soarToCheck = (int) (numCentroids * EXT_SOAR_LIMIT_CHECK_RATIO); - int soarClusterCheckCount = Math.min(numCentroids - 1, soarToCheck); - // TODO is this the right to check? - // If cluster quality is higher, maybe we can reduce this... - NeighborQueue neighborsToCheck = new NeighborQueue(soarClusterCheckCount + 1, true); - OrdScoreIterator ordScoreIterator = new OrdScoreIterator(soarClusterCheckCount + 1); - float[] scratch = new float[vectors.dimension()]; - // Can we do a pre-filter by finding the nearest centroids to the original vector centroids? - // We need to be careful on vecOrd vs. 
doc as we need random access to the raw vector for posting list writing - for (int vecOrd = 0; vecOrd < vectors.size(); vecOrd++) { - float[] vector = vectors.vectorValue(vecOrd); - scorer.setScoringVector(vector); - int bestCentroid = 0; - float bestScore = Float.MAX_VALUE; - if (numCentroids > 1) { - for (short c = 0; c < numCentroids; c++) { - float squareDist = scorer.score(c); - neighborsToCheck.insertWithOverflow(c, squareDist); - } - int centroidCount = neighborsToCheck.size(); - int bestIdx = neighborsToCheck.consumeNodesAndScoresMin(ordScoreIterator.ords, ordScoreIterator.scores); - ordScoreIterator.idx = centroidCount; - bestCentroid = ordScoreIterator.getOrd(bestIdx); - bestScore = ordScoreIterator.getScore(bestIdx); - } - if (clusters[bestCentroid] == null) { - clusters[bestCentroid] = new IntArrayList(16); - } - clusters[bestCentroid].add(vecOrd); - if (soarClusterCheckCount > 0) { - assignCentroidSOAR( - ordScoreIterator, - vecOrd, - bestCentroid, - scorer.centroid(bestCentroid), - bestScore, - scratch, - scorer, - vectors, - clusters - ); - } - neighborsToCheck.clear(); - } - } - static void assignCentroidSOAR( OrdScoreIterator centroidsToCheck, int vecOrd, diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java index 07e0a389c47d1..0323aa6abfd08 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java @@ -262,7 +262,7 @@ public final void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) thro try (IndexOutput out = mergeState.segmentInfo.dir.createTempOutput(mergeState.segmentInfo.name, "ivf_", IOContext.DEFAULT)) { name = out.getName(); // TODO do this better, we shouldn't have to write to a temp file, we should be able to - // to just from the merged vector values. 
+ // to just from the merged vector values. numVectors = writeFloatVectorValues(fieldInfo, out, MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)); success = true; } finally { @@ -274,9 +274,9 @@ public final void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) thro int vectorCount = 0; for (int idx = 0; idx < mergeState.knnVectorsReaders.length; idx++) { if (mergeState.fieldInfos[idx] == null - || mergeState.fieldInfos[idx].hasVectorValues() == false - || mergeState.fieldInfos[idx].fieldInfo(fieldInfo.name) == null - || mergeState.fieldInfos[idx].fieldInfo(fieldInfo.name).hasVectorValues() == false) { + || mergeState.fieldInfos[idx].hasVectorValues() == false + || mergeState.fieldInfos[idx].fieldInfo(fieldInfo.name) == null + || mergeState.fieldInfos[idx].fieldInfo(fieldInfo.name).hasVectorValues() == false) { continue; } KnnVectorsReader knnReaders = mergeState.knnVectorsReaders[idx]; From 6fa5234d39ba07ee457de8acdad833c6d3b44cef Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Tue, 29 Apr 2025 14:52:15 -0400 Subject: [PATCH 04/11] iter --- .../index/codec/vectors/IVFVectorsFormat.java | 20 +++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormat.java index 230d7743726eb..0a41a1017debe 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormat.java @@ -26,11 +26,23 @@ * Codec format for Inverted File Vector indexes. This index expects to break the dimensional space * into clusters and assign each vector to a cluster generating a posting list of vectors. Clusters * are represented by centroids. - *

    - * Each posting list is individual quantized and stored in-line allowing for fast block scoring. + * The vector quantization format used here is a per-vector optimized scalar quantization. Also see {@link + * OptimizedScalarQuantizer}. Some of the key features are: + * + * The format is stored in three files: + * + *

    .cenivf (centroid data) file

    + *

    Which stores the raw and quantized centroid vectors. + * + *

    .clivf (cluster data) file

    + * + *

    Stores the quantized vectors for each cluster, inline and stored in blocks. Additionally, the docIds of + * each vector are stored. + * + *

    .mivf (centroid metadata) file

    + * + *

    Stores metadata including the number of centroids and their offsets in the clivf file

    * - *

    THe index is searcher by looking for the closest centroids to our vector query and then - * scoring the vectors in the posting list of the closest centroids. */ public class IVFVectorsFormat extends KnnVectorsFormat { From 3e61aee64a4d1668ee415a1455521fa5fb0e09be Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Wed, 30 Apr 2025 08:02:57 -0400 Subject: [PATCH 05/11] iter --- .../org/elasticsearch/benchmark/vector/OSQScorerBenchmark.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OSQScorerBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OSQScorerBenchmark.java index 486919f10bcf5..85ca13e6e8754 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OSQScorerBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OSQScorerBenchmark.java @@ -17,7 +17,7 @@ import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; import org.elasticsearch.common.logging.LogConfigurator; -import org.elasticsearch.simdvec.internal.vectorization.ES91OSQVectorsScorer; +import org.elasticsearch.simdvec.ES91OSQVectorsScorer; import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider; import org.openjdk.jmh.annotations.Benchmark; import org.openjdk.jmh.annotations.BenchmarkMode; From 2c4c806ca9d9cded429d04cbb35521f9a08d6ca5 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Wed, 30 Apr 2025 09:09:18 -0400 Subject: [PATCH 06/11] addressing forbidden api stuffs --- .../index/codec/vectors/IVFVectorsReader.java | 2 +- .../index/codec/vectors/IVFVectorsWriter.java | 13 +++++++------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java 
b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java index 6d4e6602cbac1..4319c7c47c82a 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsReader.java @@ -30,8 +30,8 @@ import org.apache.lucene.util.BitSet; import org.apache.lucene.util.Bits; import org.apache.lucene.util.FixedBitSet; -import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.hnsw.NeighborQueue; +import org.elasticsearch.core.IOUtils; import java.io.IOException; import java.util.function.IntPredicate; diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java index 0323aa6abfd08..8d478c494c001 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java @@ -29,9 +29,10 @@ import org.apache.lucene.store.IOContext; import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.IndexOutput; -import org.apache.lucene.util.IOUtils; import org.apache.lucene.util.InfoStream; import org.apache.lucene.util.VectorUtil; +import org.elasticsearch.core.IOUtils; +import org.elasticsearch.core.SuppressForbidden; import java.io.IOException; import java.io.UncheckedIOException; @@ -250,6 +251,7 @@ static IVFVectorsReader getIVFReader(KnnVectorsReader vectorsReader, String fiel } @Override + @SuppressForbidden(reason = "require usage of Lucene's IOUtils#deleteFilesIgnoringExceptions(...)") public final void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) throws IOException { rawVectorDelegate.mergeOneField(fieldInfo, mergeState); if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { @@ -267,7 +269,7 @@ public final void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) thro success = 
true; } finally { if (success == false && name != null) { - IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, name); + org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, name); } } float[] globalCentroid = new float[fieldInfo.getVectorDimension()]; @@ -318,7 +320,7 @@ public final void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) thro } finally { if (success == false && centroidTempName != null) { IOUtils.closeWhileHandlingException(centroidTemp); - IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, centroidTempName); + org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, centroidTempName); } } try { @@ -356,11 +358,10 @@ public final void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) thro writeMeta(fieldInfo, centroidOffset, centroidLength, offsets, globalCentroid); } } finally { - IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, name); - IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, centroidTempName); + org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, name, centroidTempName); } } finally { - IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, name); + org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, name); } } } From e069524126c38064fea6ba9cb8ddd5a9b7821024 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Wed, 30 Apr 2025 10:43:13 -0400 Subject: [PATCH 07/11] fixing headers --- .../index/codec/vectors/IVFVectorsFormatTests.java | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormatTests.java index a8ecb3fce066e..0f7bdce9c3e3d 100644 --- 
a/server/src/test/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormatTests.java @@ -1,8 +1,10 @@ /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". */ package org.elasticsearch.index.codec.vectors; From 42800272ce6ee8931bb6bcdfbe719501434aca8b Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Wed, 30 Apr 2025 13:52:23 -0400 Subject: [PATCH 08/11] fixing inc visit logic --- .../index/codec/vectors/DefaultIVFVectorsReader.java | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java index f555ce2ba9113..71fe1cf68807c 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java @@ -390,7 +390,9 @@ public int visit(KnnCollector knnCollector) throws IOException { knnCollector.collect(doc, score); } } - knnCollector.incVisitedCount(scoredDocs); + if (scoredDocs > 0) { + knnCollector.incVisitedCount(scoredDocs); + } return scoredDocs; } From f6bc2953dc77c8daf73ed76be524709af7f71f6d Mon Sep 17 00:00:00 2001 From: Benjamin Trent 
<4357155+benwtrent@users.noreply.github.com> Date: Fri, 2 May 2025 11:06:36 -0400 Subject: [PATCH 09/11] simplifying desired clusters --- .../index/codec/vectors/DefaultIVFVectorsWriter.java | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java index cb1bc33c4e79c..dd255948ceaf0 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java @@ -65,10 +65,7 @@ CentroidAssignmentScorer calculateAndWriteCentroids( } // calculate the centroids int maxNumClusters = ((floatVectorValues.size() - 1) / vectorPerCluster) + 1; - int desiredClusters = (int) Math.max(maxNumClusters / 16.0, Math.max(Math.sqrt(floatVectorValues.size()), maxNumClusters)); - if (floatVectorValues.size() / desiredClusters > vectorPerCluster) { - desiredClusters = ((floatVectorValues.size() - 1) / vectorPerCluster) + 1; - } + int desiredClusters = (int) Math.max(Math.sqrt(floatVectorValues.size()), maxNumClusters); final KMeans.Results kMeans = KMeans.cluster( floatVectorValues, desiredClusters, From a4f2c28b769034588e8ff0ef4da6aefcdd6016ee Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Mon, 5 May 2025 09:04:38 -0400 Subject: [PATCH 10/11] fixing test failure, addressing pr comments --- .../vectors/DefaultIVFVectorsWriter.java | 140 ++++++++++++------ .../index/codec/vectors/IVFVectorsWriter.java | 72 +++------ 2 files changed, 118 insertions(+), 94 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java index dd255948ceaf0..1c431b01e611c 100644 --- 
a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java @@ -223,44 +223,17 @@ static void writeCentroids(float[][] centroids, FieldInfo fieldInfo, float[] glo } } - record SegmentCentroid(int segment, int centroid, int centroidSize) {} - - @Override - protected int calculateAndWriteCentroids( + static float[][] gatherInitCentroids( + List centroidList, + List segmentCentroids, + int desiredClusters, FieldInfo fieldInfo, - FloatVectorValues floatVectorValues, - IndexOutput temporaryCentroidOutput, - MergeState mergeState, - float[] globalCentroid + MergeState mergeState ) throws IOException { - if (floatVectorValues.size() == 0) { - return 0; + if (centroidList.size() == 0) { + return null; } - int desiredClusters = ((floatVectorValues.size() - 1) / vectorPerCluster) + 1; - // init centroids from merge state - List centroidList = new ArrayList<>(); - List segmentCentroids = new ArrayList<>(desiredClusters); - - int segmentIdx = 0; long startTime = System.nanoTime(); - for (var reader : mergeState.knnVectorsReaders) { - IVFVectorsReader ivfVectorsReader = IVFVectorsFormat.getIVFReader(reader, fieldInfo.name); - if (ivfVectorsReader == null) { - continue; - } - - FloatVectorValues centroid = ivfVectorsReader.getCentroids(fieldInfo); - if (centroid == null) { - continue; - } - centroidList.add(centroid); - for (int i = 0; i < centroid.size(); i++) { - int size = ivfVectorsReader.centroidSize(fieldInfo.name, i); - segmentCentroids.add(new SegmentCentroid(segmentIdx, i, size)); - } - segmentIdx++; - } - // sort centroid list by floatvector size FloatVectorValues baseSegment = centroidList.get(0); for (var l : centroidList) { @@ -334,6 +307,9 @@ protected int calculateAndWriteCentroids( sum[label - 1] += segmentCentroid.centroidSize; } for (int i = 0; i < initCentroids.length; i++) { + if (sum[i] == 0 || sum[i] == 1) { + continue; + } for (int j 
= 0; j < initCentroids[i].length; j++) { initCentroids[i][j] /= sum[i]; } @@ -348,6 +324,67 @@ protected int calculateAndWriteCentroids( "Gathered initCentroids:" + initCentroids.length + " for desired: " + desiredClusters ); } + return initCentroids; + } + + record SegmentCentroid(int segment, int centroid, int centroidSize) {} + + /** + * Calculate the centroids for the given field and write them to the given + * temporary centroid output. + * When merging, we first bootstrap the KMeans algorithm with the centroids contained in the merging segments. + * To prevent centroids that are too similar from having an outsized impact, all centroids that are closer than + * the largest segments intra-cluster distance are merged into a single centroid. + * The resulting centroids are then used to initialize the KMeans algorithm. + * + * @param fieldInfo merging field info + * @param floatVectorValues the float vector values to merge + * @param temporaryCentroidOutput the temporary centroid output + * @param mergeState the merge state + * @param globalCentroid the global centroid, calculated by this method and used to quantize the centroids + * @return the number of centroids written + * @throws IOException if an I/O error occurs + */ + @Override + protected int calculateAndWriteCentroids( + FieldInfo fieldInfo, + FloatVectorValues floatVectorValues, + IndexOutput temporaryCentroidOutput, + MergeState mergeState, + float[] globalCentroid + ) throws IOException { + if (floatVectorValues.size() == 0) { + return 0; + } + int maxNumClusters = ((floatVectorValues.size() - 1) / vectorPerCluster) + 1; + int desiredClusters = (int) Math.max(Math.sqrt(floatVectorValues.size()), maxNumClusters); + // init centroids from merge state + List centroidList = new ArrayList<>(); + List segmentCentroids = new ArrayList<>(desiredClusters); + + int segmentIdx = 0; + for (var reader : mergeState.knnVectorsReaders) { + IVFVectorsReader ivfVectorsReader = IVFVectorsFormat.getIVFReader(reader, 
fieldInfo.name); + if (ivfVectorsReader == null) { + continue; + } + + FloatVectorValues centroid = ivfVectorsReader.getCentroids(fieldInfo); + if (centroid == null) { + continue; + } + centroidList.add(centroid); + for (int i = 0; i < centroid.size(); i++) { + int size = ivfVectorsReader.centroidSize(fieldInfo.name, i); + if (size == 0) { + continue; + } + segmentCentroids.add(new SegmentCentroid(segmentIdx, i, size)); + } + segmentIdx++; + } + + float[][] initCentroids = gatherInitCentroids(centroidList, segmentCentroids, desiredClusters, fieldInfo, mergeState); // FIXME: run a custom version of KMeans that is just better... long nanoTime = System.nanoTime(); @@ -369,6 +406,15 @@ protected int calculateAndWriteCentroids( float[][] centroids = kMeans.centroids(); // write them + // calculate the global centroid from all the centroids: + for (float[] centroid : centroids) { + for (int j = 0; j < centroid.length; j++) { + globalCentroid[j] += centroid[j]; + } + } + for (int j = 0; j < globalCentroid.length; j++) { + globalCentroid[j] /= centroids.length; + } writeCentroids(centroids, fieldInfo, globalCentroid, temporaryCentroidOutput); return centroids.length; } @@ -477,14 +523,11 @@ static void assignCentroids(CentroidAssignmentScorer scorer, FloatVectorValues v // pop the best int sz = neighborsToCheck.size(); int best = neighborsToCheck.consumeNodesAndScoresMin(ordScoreIterator.ords, ordScoreIterator.scores); - // reset the ordScoreIterator as it has consumed the ords and scores - ordScoreIterator.idx = sz; + // Set the size to the number of neighbors we actually found + ordScoreIterator.setSize(sz); bestScore = ordScoreIterator.getScore(best); bestCentroid = ordScoreIterator.getOrd(best); } - if (clusters[bestCentroid] == null) { - clusters[bestCentroid] = new IntArrayList(16); - } clusters[bestCentroid].add(docID); if (soarClusterCheckCount > 0) { assignCentroidSOAR( @@ -495,7 +538,7 @@ static void assignCentroids(CentroidAssignmentScorer scorer, 
FloatVectorValues v bestScore, scratch, scorer, - vectors, + vector, clusters ); } @@ -511,10 +554,9 @@ static void assignCentroidSOAR( float bestScore, float[] scratch, CentroidAssignmentScorer scorer, - FloatVectorValues vectors, + float[] vector, IntArrayList[] clusters ) throws IOException { - float[] vector = vectors.vectorValue(vecOrd); ESVectorUtil.subtract(vector, bestCentroid, scratch); int bestSecondaryCentroid = -1; float minDist = Float.MAX_VALUE; @@ -546,6 +588,14 @@ static class OrdScoreIterator { this.scores = new float[size]; } + int setSize(int size) { + if (size > ords.length) { + throw new IllegalArgumentException("size must be <= " + ords.length); + } + this.idx = size; + return size; + } + int getOrd(int idx) { return ords[idx]; } @@ -606,7 +656,7 @@ static class OffHeapCentroidAssignmentScorer implements CentroidAssignmentScorer private final int dimension; private final float[] scratch; private float[] q; - private final long centroidByteSize; + private final long rawCentroidOffset; private int currOrd = -1; OffHeapCentroidAssignmentScorer(IndexInput centroidsInput, int numCentroids, FieldInfo info) { @@ -614,7 +664,7 @@ static class OffHeapCentroidAssignmentScorer implements CentroidAssignmentScorer this.numCentroids = numCentroids; this.dimension = info.getVectorDimension(); this.scratch = new float[dimension]; - this.centroidByteSize = dimension + 3 * Float.BYTES + Short.BYTES; + this.rawCentroidOffset = (dimension + 3 * Float.BYTES + Short.BYTES) * numCentroids; } @Override @@ -627,7 +677,7 @@ public float[] centroid(int centroidOrdinal) throws IOException { if (centroidOrdinal == currOrd) { return scratch; } - centroidsInput.seek(numCentroids * centroidByteSize + (long) centroidOrdinal * dimension * Float.BYTES); + centroidsInput.seek(rawCentroidOffset + (long) centroidOrdinal * dimension * Float.BYTES); centroidsInput.readFloats(scratch, 0, dimension); this.currOrd = centroidOrdinal; return scratch; diff --git 
a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java index 8d478c494c001..4e9c4ee47e3f6 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java @@ -192,7 +192,6 @@ public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { floatVectorValues, ivfClusters ); - // write posting lists writeMeta(fieldWriter.fieldInfo, centroidOffset, centroidLength, offsets, globalCentroid); } } @@ -256,54 +255,25 @@ public final void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) thro rawVectorDelegate.mergeOneField(fieldInfo, mergeState); if (fieldInfo.getVectorEncoding().equals(VectorEncoding.FLOAT32)) { final int numVectors; - String name = null; + String tempRawVectorsFileName = null; boolean success = false; // build a float vector values with random access. In order to do that we dump the vectors to // a temporary file // and write the docID follow by the vector try (IndexOutput out = mergeState.segmentInfo.dir.createTempOutput(mergeState.segmentInfo.name, "ivf_", IOContext.DEFAULT)) { - name = out.getName(); + tempRawVectorsFileName = out.getName(); // TODO do this better, we shouldn't have to write to a temp file, we should be able to - // to just from the merged vector values. + // to just from the merged vector values, the tricky part is the random access. 
numVectors = writeFloatVectorValues(fieldInfo, out, MergedVectorValues.mergeFloatVectorValues(fieldInfo, mergeState)); + CodecUtil.writeFooter(out); success = true; } finally { - if (success == false && name != null) { - org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, name); + if (success == false && tempRawVectorsFileName != null) { + org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, tempRawVectorsFileName); } } - float[] globalCentroid = new float[fieldInfo.getVectorDimension()]; - int vectorCount = 0; - for (int idx = 0; idx < mergeState.knnVectorsReaders.length; idx++) { - if (mergeState.fieldInfos[idx] == null - || mergeState.fieldInfos[idx].hasVectorValues() == false - || mergeState.fieldInfos[idx].fieldInfo(fieldInfo.name) == null - || mergeState.fieldInfos[idx].fieldInfo(fieldInfo.name).hasVectorValues() == false) { - continue; - } - KnnVectorsReader knnReaders = mergeState.knnVectorsReaders[idx]; - IVFVectorsReader ivfReader = getIVFReader(knnReaders, fieldInfo.name); - if (ivfReader != null) { - FloatVectorValues floatVectorValues = knnReaders.getFloatVectorValues(fieldInfo.name); - if (floatVectorValues == null) { - continue; - } - int numVecs = floatVectorValues.size(); - float[] readerGlobalCentroid = ivfReader.getGlobalCentroid(fieldInfo); - if (readerGlobalCentroid != null) { - vectorCount += numVecs; - for (int i = 0; i < globalCentroid.length; i++) { - globalCentroid[i] += readerGlobalCentroid[i] * numVecs; - } - } - } - } - if (vectorCount > 0) { - for (int i = 0; i < globalCentroid.length; i++) { - globalCentroid[i] /= vectorCount; - } - } - try (IndexInput in = mergeState.segmentInfo.dir.openInput(name, IOContext.DEFAULT)) { + try (IndexInput in = mergeState.segmentInfo.dir.openInput(tempRawVectorsFileName, IOContext.DEFAULT)) { + float[] calculatedGlobalCentroid = new float[fieldInfo.getVectorDimension()]; final FloatVectorValues floatVectorValues = 
getFloatVectorValues(fieldInfo, in, numVectors); success = false; CentroidAssignmentScorer centroidAssignmentScorer; @@ -315,7 +285,13 @@ public final void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) thro try { centroidTemp = mergeState.segmentInfo.dir.createTempOutput(mergeState.segmentInfo.name, "civf_", IOContext.DEFAULT); centroidTempName = centroidTemp.getName(); - numCentroids = calculateAndWriteCentroids(fieldInfo, floatVectorValues, centroidTemp, mergeState, globalCentroid); + numCentroids = calculateAndWriteCentroids( + fieldInfo, + floatVectorValues, + centroidTemp, + mergeState, + calculatedGlobalCentroid + ); success = true; } finally { if (success == false && centroidTempName != null) { @@ -337,7 +313,7 @@ public final void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) thro try (IndexInput centroidInput = mergeState.segmentInfo.dir.openInput(centroidTempName, IOContext.DEFAULT)) { ivfCentroids.copyBytes(centroidInput, centroidInput.length() - CodecUtil.footerLength()); centroidLength = ivfCentroids.getFilePointer() - centroidOffset; - centroidAssignmentScorer = createCentroidScorer(centroidInput, numCentroids, fieldInfo, globalCentroid); + centroidAssignmentScorer = createCentroidScorer(centroidInput, numCentroids, fieldInfo, calculatedGlobalCentroid); assert centroidAssignmentScorer.size() == numCentroids; // build a float vector values with random access // build centroids @@ -348,20 +324,18 @@ public final void mergeOneField(FieldInfo fieldInfo, MergeState mergeState) thro ivfClusters, mergeState ); - // write posting lists - - // TODO handle this correctly by creating new centroid - if (vectorCount == 0 && offsets.length > 0) { - throw new IllegalStateException("No global centroid found for field: " + fieldInfo.name); - } assert offsets.length == centroidAssignmentScorer.size(); - writeMeta(fieldInfo, centroidOffset, centroidLength, offsets, globalCentroid); + writeMeta(fieldInfo, centroidOffset, centroidLength, offsets, 
calculatedGlobalCentroid); } } finally { - org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, name, centroidTempName); + org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions( + mergeState.segmentInfo.dir, + tempRawVectorsFileName, + centroidTempName + ); } } finally { - org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, name); + org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, tempRawVectorsFileName); } } } From a925b5486ba228ef09de3cfe24bc3e294ebf420c Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Wed, 7 May 2025 12:49:02 -0400 Subject: [PATCH 11/11] addressing comments --- .../vectors/DefaultIVFVectorsReader.java | 6 +- .../index/codec/vectors/NeighborQueue.java | 13 ++ .../codec/vectors/NeighborQueueTests.java | 119 ++++++++++++++++++ 3 files changed, 135 insertions(+), 3 deletions(-) create mode 100644 server/src/test/java/org/elasticsearch/index/codec/vectors/NeighborQueueTests.java diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java index 71fe1cf68807c..e09cf474d09ea 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java @@ -80,11 +80,11 @@ public int size() { @Override public float[] centroid(int centroidOrdinal) throws IOException { - readQuantizedCentroid(centroidOrdinal); + readQuantizedAndRawCentroid(centroidOrdinal); return centroid; } - private void readQuantizedCentroid(int centroidOrdinal) throws IOException { + private void readQuantizedAndRawCentroid(int centroidOrdinal) throws IOException { if (centroidOrdinal == currentCentroid) { return; } @@ -97,7 +97,7 @@ private void readQuantizedCentroid(int 
centroidOrdinal) throws IOException { @Override public float score(int centroidOrdinal) throws IOException { - readQuantizedCentroid(centroidOrdinal); + readQuantizedAndRawCentroid(centroidOrdinal); return int4QuantizedScore( quantized, queryParams, diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/NeighborQueue.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/NeighborQueue.java index b4156988d2788..f27e85d46cddb 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/NeighborQueue.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/NeighborQueue.java @@ -95,6 +95,19 @@ private long encode(int node, float score) { return order.apply((((long) NumericUtils.floatToSortableInt(score)) << 32) | (0xFFFFFFFFL & ~node)); } + /** Returns the top element's node id. */ + int topNode() { + return decodeNodeId(heap.top()); + } + + /** + * Returns the top element's node score. For the min heap this is the minimum score. For the max + * heap this is the maximum score. + */ + float topScore() { + return decodeScore(heap.top()); + } + private float decodeScore(long heapValue) { return NumericUtils.sortableIntToFloat((int) (order.apply(heapValue) >> 32)); } diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/NeighborQueueTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/NeighborQueueTests.java new file mode 100644 index 0000000000000..7238f58d746dc --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/NeighborQueueTests.java @@ -0,0 +1,119 @@ +/* + * @notice + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * Modifications copyright (C) 2025 Elasticsearch B.V. + */ + +package org.elasticsearch.index.codec.vectors; + +import org.elasticsearch.test.ESTestCase; + +/** + * copied and modified from Lucene + */ +public class NeighborQueueTests extends ESTestCase { + public void testNeighborsProduct() { + // make sure we have the sign correct + NeighborQueue nn = new NeighborQueue(2, false); + assertTrue(nn.insertWithOverflow(2, 0.5f)); + assertTrue(nn.insertWithOverflow(1, 0.2f)); + assertTrue(nn.insertWithOverflow(3, 1f)); + assertEquals(0.5f, nn.topScore(), 0); + nn.pop(); + assertEquals(1f, nn.topScore(), 0); + nn.pop(); + } + + public void testNeighborsMaxHeap() { + NeighborQueue nn = new NeighborQueue(2, true); + assertTrue(nn.insertWithOverflow(2, 2)); + assertTrue(nn.insertWithOverflow(1, 1)); + assertFalse(nn.insertWithOverflow(3, 3)); + assertEquals(2f, nn.topScore(), 0); + nn.pop(); + assertEquals(1f, nn.topScore(), 0); + } + + public void testTopMaxHeap() { + NeighborQueue nn = new NeighborQueue(2, true); + nn.add(1, 2); + nn.add(2, 1); + // lower scores are better; highest score on top + assertEquals(2, nn.topScore(), 0); + assertEquals(1, nn.topNode()); + } + + public void testTopMinHeap() { + NeighborQueue nn = new NeighborQueue(2, false); + nn.add(1, 0.5f); + nn.add(2, -0.5f); + // higher scores are better; lowest score on top + assertEquals(-0.5f, nn.topScore(), 0); + assertEquals(2, nn.topNode()); + } + + public void testClear() { + NeighborQueue nn = new NeighborQueue(2, false); + nn.add(1, 1.1f); + nn.add(2, -2.2f); + nn.clear(); + + 
assertEquals(0, nn.size()); + } + + public void testMaxSizeQueue() { + NeighborQueue nn = new NeighborQueue(2, false); + nn.add(1, 1); + nn.add(2, 2); + assertEquals(2, nn.size()); + assertEquals(1, nn.topNode()); + + // insertWithOverflow does not extend the queue + nn.insertWithOverflow(3, 3); + assertEquals(2, nn.size()); + assertEquals(2, nn.topNode()); + + // add does extend the queue beyond maxSize + nn.add(4, 1); + assertEquals(3, nn.size()); + } + + public void testUnboundedQueue() { + NeighborQueue nn = new NeighborQueue(1, true); + float maxScore = -2; + int maxNode = -1; + for (int i = 0; i < 256; i++) { + // initial size is 32 + float score = random().nextFloat(); + if (score > maxScore) { + maxScore = score; + maxNode = i; + } + nn.add(i, score); + } + assertEquals(maxScore, nn.topScore(), 0); + assertEquals(maxNode, nn.topNode()); + } + + public void testInvalidArguments() { + expectThrows(IllegalArgumentException.class, () -> new NeighborQueue(0, false)); + } + + public void testToString() { + assertEquals("Neighbors[0]", new NeighborQueue(2, false).toString()); + } + +}