diff --git a/server/src/main/java/module-info.java b/server/src/main/java/module-info.java index cd375474797be..a1c2d219b75d5 100644 --- a/server/src/main/java/module-info.java +++ b/server/src/main/java/module-info.java @@ -7,7 +7,6 @@ * License v3.0 only", or the "Server Side Public License, v 1". */ -import org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat; import org.elasticsearch.plugins.internal.RestExtension; import org.elasticsearch.reservedstate.ReservedStateHandlerProvider; @@ -463,7 +462,8 @@ org.elasticsearch.index.codec.vectors.es816.ES816HnswBinaryQuantizedVectorsFormat, org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat, org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat, - ES920DiskBBQVectorsFormat; + org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat, + org.elasticsearch.index.codec.vectors.diskbbq.next.ESNextDiskBBQVectorsFormat; provides org.apache.lucene.codecs.Codec with diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/CentroidAssignments.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/CentroidAssignments.java index aa1e80a3b28ee..33ea8085191c3 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/CentroidAssignments.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/CentroidAssignments.java @@ -9,9 +9,9 @@ package org.elasticsearch.index.codec.vectors.diskbbq; -record CentroidAssignments(int numCentroids, float[][] centroids, int[] assignments, int[] overspillAssignments) { +public record CentroidAssignments(int numCentroids, float[][] centroids, int[] assignments, int[] overspillAssignments) { - CentroidAssignments(float[][] centroids, int[] assignments, int[] overspillAssignments) { + public CentroidAssignments(float[][] centroids, int[] assignments, int[] overspillAssignments) { this(centroids.length, centroids, assignments, overspillAssignments); assert assignments.length == overspillAssignments.length || overspillAssignments.length == 0 : "assignments and overspillAssignments must have the same length"; diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/CentroidSupplier.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/CentroidSupplier.java index 9794508047c7f..08898e8844aca 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/CentroidSupplier.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/CentroidSupplier.java @@ -14,7 +14,7 @@ /** * An interface for that supply centroids. */ -interface CentroidSupplier { +public interface CentroidSupplier { CentroidSupplier EMPTY = new CentroidSupplier() { @Override public int size() { diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/DiskBBQBulkWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/DiskBBQBulkWriter.java index 95a624f43b01e..040b1b31aa235 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/DiskBBQBulkWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/DiskBBQBulkWriter.java @@ -20,7 +20,7 @@ * This class provides the structure for writing vectors in bulk, with specific * implementations for different bit sizes strategies. 
*/ -abstract class DiskBBQBulkWriter { +public abstract class DiskBBQBulkWriter { protected final int bulkSize; protected final IndexOutput out; @@ -29,18 +29,18 @@ protected DiskBBQBulkWriter(int bulkSize, IndexOutput out) { this.out = out; } - abstract void writeVectors(QuantizedVectorValues qvv, CheckedIntConsumer docsWriter) throws IOException; + public abstract void writeVectors(QuantizedVectorValues qvv, CheckedIntConsumer docsWriter) throws IOException; - static class OneBitDiskBBQBulkWriter extends DiskBBQBulkWriter { + public static class OneBitDiskBBQBulkWriter extends DiskBBQBulkWriter { private final OptimizedScalarQuantizer.QuantizationResult[] corrections; - OneBitDiskBBQBulkWriter(int bulkSize, IndexOutput out) { + public OneBitDiskBBQBulkWriter(int bulkSize, IndexOutput out) { super(bulkSize, out); this.corrections = new OptimizedScalarQuantizer.QuantizationResult[bulkSize]; } @Override - void writeVectors(QuantizedVectorValues qvv, CheckedIntConsumer docsWriter) throws IOException { + public void writeVectors(QuantizedVectorValues qvv, CheckedIntConsumer docsWriter) throws IOException { int limit = qvv.count() - bulkSize + 1; int i = 0; for (; i < limit; i += bulkSize) { @@ -93,16 +93,16 @@ private void writeCorrection(OptimizedScalarQuantizer.QuantizationResult correct } } - static class SevenBitDiskBBQBulkWriter extends DiskBBQBulkWriter { + public static class SevenBitDiskBBQBulkWriter extends DiskBBQBulkWriter { private final OptimizedScalarQuantizer.QuantizationResult[] corrections; - SevenBitDiskBBQBulkWriter(int bulkSize, IndexOutput out) { + public SevenBitDiskBBQBulkWriter(int bulkSize, IndexOutput out) { super(bulkSize, out); this.corrections = new OptimizedScalarQuantizer.QuantizationResult[bulkSize]; } @Override - void writeVectors(QuantizedVectorValues qvv, CheckedIntConsumer docsWriter) throws IOException { + public void writeVectors(QuantizedVectorValues qvv, CheckedIntConsumer docsWriter) throws IOException { int limit = qvv.count() - bulkSize + 1; int i = 0; for (; i < limit; i += bulkSize) { diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/DocIdsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/DocIdsWriter.java index 0c28d9ef0ba44..b536aa123df79 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/DocIdsWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/DocIdsWriter.java @@ -30,7 +30,7 @@ * *

It is copied from the BKD implementation. */ -final class DocIdsWriter { +public final class DocIdsWriter { private static final byte CONTINUOUS_IDS = (byte) -2; private static final byte DELTA_BPV_16 = (byte) 16; @@ -40,7 +40,7 @@ final class DocIdsWriter { private int[] scratch = new int[0]; - DocIdsWriter() {} + public DocIdsWriter() {} /** * Calculate the best encoding that will be used to write blocks of doc ids of blockSize. @@ -51,7 +51,7 @@ final class DocIdsWriter { * @param blockSize the block size * @return the byte encoding to use for the blocks */ - byte calculateBlockEncoding(IntToIntFunction docIds, int count, int blockSize) { + public byte calculateBlockEncoding(IntToIntFunction docIds, int count, int blockSize) { if (count == 0) { return CONTINUOUS_IDS; } @@ -90,7 +90,7 @@ byte calculateBlockEncoding(IntToIntFunction docIds, int count, int blockSize) { } } - void writeDocIds(IntToIntFunction docIds, int count, byte encoding, DataOutput out) throws IOException { + public void writeDocIds(IntToIntFunction docIds, int count, byte encoding, DataOutput out) throws IOException { if (count == 0) { return; } @@ -206,7 +206,7 @@ private static int[] sortedAndMaxAndMin2Max(IntToIntFunction docIds, int count) return new int[] { (strictlySorted && min2max == count) ? 1 : 0, max, min2max }; } - void writeDocIds(IntToIntFunction docIds, int count, DataOutput out) throws IOException { + public void writeDocIds(IntToIntFunction docIds, int count, DataOutput out) throws IOException { if (count == 0) { return; } @@ -253,7 +253,7 @@ void writeDocIds(IntToIntFunction docIds, int count, DataOutput out) throws IOEx } } - void readInts(IndexInput in, int count, byte encoding, int[] docIDs) throws IOException { + public void readInts(IndexInput in, int count, byte encoding, int[] docIDs) throws IOException { if (count == 0) { return; } @@ -271,7 +271,7 @@ void readInts(IndexInput in, int count, byte encoding, int[] docIDs) throws IOEx } /** Read {@code count} integers into {@code docIDs}. 
*/ - void readInts(IndexInput in, int count, int[] docIDs) throws IOException { + public void readInts(IndexInput in, int count, int[] docIDs) throws IOException { if (count == 0) { return; } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsReader.java index 937aedd9236e5..c9b5c05a62a10 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsReader.java @@ -43,7 +43,8 @@ public ES920DiskBBQVectorsReader(SegmentReadState state, Map addField(FieldInfo fieldInfo) throws IOExc return rawVectorDelegate; } - abstract CentroidAssignments calculateCentroids(FieldInfo fieldInfo, FloatVectorValues floatVectorValues, float[] globalCentroid) + public abstract CentroidAssignments calculateCentroids(FieldInfo fieldInfo, FloatVectorValues floatVectorValues, float[] globalCentroid) throws IOException; - record CentroidOffsetAndLength(LongValues offsets, LongValues lengths) {} + public record CentroidOffsetAndLength(LongValues offsets, LongValues lengths) {} - abstract void writeCentroids( + public abstract void writeCentroids( FieldInfo fieldInfo, CentroidSupplier centroidSupplier, float[] globalCentroid, @@ -140,7 +140,7 @@ abstract void writeCentroids( IndexOutput centroidOutput ) throws IOException; - abstract CentroidOffsetAndLength buildAndWritePostingsLists( + public abstract CentroidOffsetAndLength buildAndWritePostingsLists( FieldInfo fieldInfo, CentroidSupplier centroidSupplier, FloatVectorValues floatVectorValues, @@ -150,7 +150,7 @@ abstract CentroidOffsetAndLength buildAndWritePostingsLists( int[] overspillAssignments ) throws IOException; - abstract CentroidOffsetAndLength buildAndWritePostingsLists( + public abstract CentroidOffsetAndLength buildAndWritePostingsLists( FieldInfo fieldInfo, CentroidSupplier centroidSupplier, FloatVectorValues floatVectorValues, @@ -161,7 +161,7 @@ abstract CentroidOffsetAndLength buildAndWritePostingsLists( int[] overspillAssignments ) throws IOException; - abstract CentroidSupplier createCentroidSupplier( + public abstract CentroidSupplier createCentroidSupplier( IndexInput centroidsInput, int numCentroids, FieldInfo fieldInfo, diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/IntSorter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/IntSorter.java index cf8b9051e2347..fd5ccc39cfe73 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/IntSorter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/IntSorter.java @@ -12,12 +12,12 @@ import org.apache.lucene.util.IntroSorter; import org.apache.lucene.util.hnsw.IntToIntFunction; -class IntSorter extends IntroSorter { +public class IntSorter extends IntroSorter { int pivot = -1; private final int[] arr; private final IntToIntFunction func; - IntSorter(int[] arr, IntToIntFunction func) { + public IntSorter(int[] arr, IntToIntFunction func) { this.arr = arr; this.func = func; } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/IntToBooleanFunction.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/IntToBooleanFunction.java index f843d91b2ebcb..2286800a9ed55 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/IntToBooleanFunction.java +++ 
b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/IntToBooleanFunction.java @@ -13,6 +13,6 @@ * Functional interface representing a function that takes an integer input * and produces a boolean output. */ -interface IntToBooleanFunction { +public interface IntToBooleanFunction { boolean apply(int value); } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/QuantizedVectorValues.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/QuantizedVectorValues.java index af2ab9223093a..3afa799cd8755 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/QuantizedVectorValues.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/QuantizedVectorValues.java @@ -18,7 +18,7 @@ * Provides methods to iterate through the vectors and retrieve * associated quantization correction data. */ -interface QuantizedVectorValues { +public interface QuantizedVectorValues { int count(); byte[] next() throws IOException; diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsFormat.java new file mode 100644 index 0000000000000..abacdfc9e2a8a --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsFormat.java @@ -0,0 +1,143 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.codec.vectors.diskbbq.next; + +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.KnnVectorsWriter; +import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil; +import org.apache.lucene.codecs.hnsw.FlatVectorsFormat; +import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.codecs.lucene99.Lucene99FlatVectorsFormat; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.SegmentWriteState; +import org.elasticsearch.common.util.Maps; +import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer; + +import java.io.IOException; +import java.util.Collections; +import java.util.Map; + +/** + * Codec format for Inverted File Vector indexes. This index expects to break the dimensional space + * into clusters and assign each vector to a cluster generating a posting list of vectors. Clusters + * are represented by centroids. + * The vector quantization format used here is a per-vector optimized scalar quantization. Also see {@link + * OptimizedScalarQuantizer}. Some of key features are: + * + * The format is stored in three files: + * + *

+ * <h2>.cenivf (centroid data) file</h2>
+ * <p> Which stores the raw and quantized centroid vectors.
+ *
+ * <h2>.clivf (cluster data) file</h2>
+ *
+ * <p> Stores the quantized vectors for each cluster, inline and stored in blocks. Additionally, the docIds of
+ * each vector are stored.
+ *
+ * <h2>.mivf (centroid metadata) file</h2>
+ *
+ * <p> Stores metadata including the number of centroids and their offsets in the clivf file</p>
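+ *
+ * <p> A minimal construction sketch (illustrative only, not part of this patch; the parameter values are
+ * assumptions that simply stay within the bounds enforced by the constructor):
+ * <pre>{@code
+ * // defaults: 384 vectors per cluster, 16 centroids per parent cluster
+ * KnnVectorsFormat format = new ESNextDiskBBQVectorsFormat();
+ * // or pick the clustering granularity explicitly
+ * // (vectorsPerCluster in [64, 65536], centroidsPerParentCluster in [2, 256])
+ * KnnVectorsFormat tuned = new ESNextDiskBBQVectorsFormat(512, 32);
+ * }</pre>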

+ * + */ +public class ESNextDiskBBQVectorsFormat extends KnnVectorsFormat { + + public static final String NAME = "ESNextDiskBBQVectorsFormat"; + // centroid ordinals -> centroid values, offsets + public static final String CENTROID_EXTENSION = "cenivf"; + // offsets contained in cen_ivf, [vector ordinals, actually just docIds](long varint), quantized + // vectors (OSQ bit) + public static final String CLUSTER_EXTENSION = "clivf"; + static final String IVF_META_EXTENSION = "mivf"; + + public static final int VERSION_START = 0; + public static final int VERSION_CURRENT = VERSION_START; + + private static final FlatVectorsFormat rawVectorFormat = new Lucene99FlatVectorsFormat( + FlatVectorScorerUtil.getLucene99FlatVectorsScorer() + ); + private static final Map supportedFormats = Map.of(rawVectorFormat.getName(), rawVectorFormat); + + // This dynamically sets the cluster probe based on the `k` requested and the number of clusters. + // useful when searching with 'efSearch' type parameters instead of requiring a specific ratio. + public static final float DYNAMIC_VISIT_RATIO = 0.0f; + public static final int DEFAULT_VECTORS_PER_CLUSTER = 384; + public static final int MIN_VECTORS_PER_CLUSTER = 64; + public static final int MAX_VECTORS_PER_CLUSTER = 1 << 16; // 65536 + public static final int DEFAULT_CENTROIDS_PER_PARENT_CLUSTER = 16; + public static final int MIN_CENTROIDS_PER_PARENT_CLUSTER = 2; + public static final int MAX_CENTROIDS_PER_PARENT_CLUSTER = 1 << 8; // 256 + + private final int vectorPerCluster; + private final int centroidsPerParentCluster; + + public ESNextDiskBBQVectorsFormat(int vectorPerCluster, int centroidsPerParentCluster) { + super(NAME); + if (vectorPerCluster < MIN_VECTORS_PER_CLUSTER || vectorPerCluster > MAX_VECTORS_PER_CLUSTER) { + throw new IllegalArgumentException( + "vectorsPerCluster must be between " + + MIN_VECTORS_PER_CLUSTER + + " and " + + MAX_VECTORS_PER_CLUSTER + + ", got: " + + vectorPerCluster + ); + } + if (centroidsPerParentCluster < MIN_CENTROIDS_PER_PARENT_CLUSTER || centroidsPerParentCluster > MAX_CENTROIDS_PER_PARENT_CLUSTER) { + throw new IllegalArgumentException( + "centroidsPerParentCluster must be between " + + MIN_CENTROIDS_PER_PARENT_CLUSTER + + " and " + + MAX_CENTROIDS_PER_PARENT_CLUSTER + + ", got: " + + centroidsPerParentCluster + ); + } + this.vectorPerCluster = vectorPerCluster; + this.centroidsPerParentCluster = centroidsPerParentCluster; + } + + /** Constructs a format using the given graph construction parameters and scalar quantization. 
*/ + public ESNextDiskBBQVectorsFormat() { + this(DEFAULT_VECTORS_PER_CLUSTER, DEFAULT_CENTROIDS_PER_PARENT_CLUSTER); + } + + @Override + public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { + return new ESNextDiskBBQVectorsWriter( + rawVectorFormat.getName(), + state, + rawVectorFormat.fieldsWriter(state), + vectorPerCluster, + centroidsPerParentCluster + ); + } + + @Override + public KnnVectorsReader fieldsReader(SegmentReadState state) throws IOException { + Map readers = Maps.newHashMapWithExpectedSize(supportedFormats.size()); + for (var fe : supportedFormats.entrySet()) { + readers.put(fe.getKey(), fe.getValue().fieldsReader(state)); + } + + return new ESNextDiskBBQVectorsReader(state, Collections.unmodifiableMap(readers)); + } + + @Override + public int getMaxDimensions(String fieldName) { + return 4096; + } + + @Override + public String toString() { + return "ESNextDiskBBQVectorsFormat(" + "vectorPerCluster=" + vectorPerCluster + ')'; + } + +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsReader.java new file mode 100644 index 0000000000000..a6b50468e9640 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsReader.java @@ -0,0 +1,593 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.codec.vectors.diskbbq.next; + +import org.apache.lucene.codecs.hnsw.FlatVectorsReader; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.SegmentReadState; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.KnnCollector; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.VectorUtil; +import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer; +import org.elasticsearch.index.codec.vectors.cluster.NeighborQueue; +import org.elasticsearch.index.codec.vectors.diskbbq.DocIdsWriter; +import org.elasticsearch.index.codec.vectors.diskbbq.IVFVectorsReader; +import org.elasticsearch.simdvec.ES91OSQVectorsScorer; +import org.elasticsearch.simdvec.ES92Int7VectorsScorer; +import org.elasticsearch.simdvec.ESVectorUtil; + +import java.io.IOException; +import java.util.Map; + +import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.QUERY_BITS; +import static org.apache.lucene.index.VectorSimilarityFunction.COSINE; +import static org.elasticsearch.index.codec.vectors.BQSpaceUtils.transposeHalfByte; +import static org.elasticsearch.index.codec.vectors.BQVectorUtils.discretize; +import static org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer.DEFAULT_LAMBDA; +import static org.elasticsearch.simdvec.ES91OSQVectorsScorer.BULK_SIZE; + +/** + * Default implementation of {@link IVFVectorsReader}. It scores the posting lists centroids using + * brute force and then scores the top ones using the posting list. 
+ */ +public class ESNextDiskBBQVectorsReader extends IVFVectorsReader { + + public ESNextDiskBBQVectorsReader(SegmentReadState state, Map rawVectorsReader) throws IOException { + super(state, rawVectorsReader); + } + + CentroidIterator getPostingListPrefetchIterator(CentroidIterator centroidIterator, IndexInput postingListSlice) throws IOException { + return new CentroidIterator() { + CentroidOffsetAndLength nextOffsetAndLength = centroidIterator.hasNext() + ? centroidIterator.nextPostingListOffsetAndLength() + : null; + + { + // prefetch the first one + if (nextOffsetAndLength != null) { + prefetch(nextOffsetAndLength); + } + } + + void prefetch(CentroidOffsetAndLength offsetAndLength) throws IOException { + postingListSlice.prefetch(offsetAndLength.offset(), offsetAndLength.length()); + } + + @Override + public boolean hasNext() { + return nextOffsetAndLength != null; + } + + @Override + public CentroidOffsetAndLength nextPostingListOffsetAndLength() throws IOException { + CentroidOffsetAndLength offsetAndLength = nextOffsetAndLength; + if (centroidIterator.hasNext()) { + nextOffsetAndLength = centroidIterator.nextPostingListOffsetAndLength(); + prefetch(nextOffsetAndLength); + } else { + nextOffsetAndLength = null; // indicate we reached the end + } + return offsetAndLength; + } + }; + } + + @Override + public CentroidIterator getCentroidIterator( + FieldInfo fieldInfo, + int numCentroids, + IndexInput centroids, + float[] targetQuery, + IndexInput postingListSlice, + float visitRatio + ) throws IOException { + final FieldEntry fieldEntry = fields.get(fieldInfo.number); + final float globalCentroidDp = fieldEntry.globalCentroidDp(); + final OptimizedScalarQuantizer scalarQuantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); + final int[] scratch = new int[targetQuery.length]; + final OptimizedScalarQuantizer.QuantizationResult queryParams = scalarQuantizer.scalarQuantize( + targetQuery, + new float[targetQuery.length], + scratch, + (byte) 7, + fieldEntry.globalCentroid() + ); + final byte[] quantized = new byte[targetQuery.length]; + for (int i = 0; i < quantized.length; i++) { + quantized[i] = (byte) scratch[i]; + } + final ES92Int7VectorsScorer scorer = ESVectorUtil.getES92Int7VectorsScorer(centroids, fieldInfo.getVectorDimension()); + centroids.seek(0L); + int numParents = centroids.readVInt(); + + CentroidIterator centroidIterator; + if (numParents > 0) { + // equivalent to (float) centroidsPerParentCluster / 2 + float centroidOversampling = (float) fieldEntry.numCentroids() / (2 * numParents); + centroidIterator = getCentroidIteratorWithParents( + fieldInfo, + centroids, + numParents, + numCentroids, + scorer, + quantized, + queryParams, + globalCentroidDp, + visitRatio * centroidOversampling + ); + } else { + centroidIterator = getCentroidIteratorNoParent( + fieldInfo, + centroids, + numCentroids, + scorer, + quantized, + queryParams, + globalCentroidDp + ); + } + return getPostingListPrefetchIterator(centroidIterator, postingListSlice); + } + + private static CentroidIterator getCentroidIteratorNoParent( + FieldInfo fieldInfo, + IndexInput centroids, + int numCentroids, + ES92Int7VectorsScorer scorer, + byte[] quantizeQuery, + OptimizedScalarQuantizer.QuantizationResult queryParams, + float globalCentroidDp + ) throws IOException { + final NeighborQueue neighborQueue = new NeighborQueue(numCentroids, true); + score( + neighborQueue, + numCentroids, + 0, + scorer, + quantizeQuery, + queryParams, + globalCentroidDp, + 
fieldInfo.getVectorSimilarityFunction(), + new float[ES92Int7VectorsScorer.BULK_SIZE] + ); + long offset = centroids.getFilePointer(); + return new CentroidIterator() { + @Override + public boolean hasNext() { + return neighborQueue.size() > 0; + } + + @Override + public CentroidOffsetAndLength nextPostingListOffsetAndLength() throws IOException { + int centroidOrdinal = neighborQueue.pop(); + centroids.seek(offset + (long) Long.BYTES * 2 * centroidOrdinal); + long postingListOffset = centroids.readLong(); + long postingListLength = centroids.readLong(); + return new CentroidOffsetAndLength(postingListOffset, postingListLength); + } + }; + } + + private static CentroidIterator getCentroidIteratorWithParents( + FieldInfo fieldInfo, + IndexInput centroids, + int numParents, + int numCentroids, + ES92Int7VectorsScorer scorer, + byte[] quantizeQuery, + OptimizedScalarQuantizer.QuantizationResult queryParams, + float globalCentroidDp, + float centroidRatio + ) throws IOException { + // build the three queues we are going to use + final NeighborQueue parentsQueue = new NeighborQueue(numParents, true); + final int maxChildrenSize = centroids.readVInt(); + final NeighborQueue currentParentQueue = new NeighborQueue(maxChildrenSize, true); + final int bufferSize = (int) Math.min(Math.max(centroidRatio * numCentroids, 1), numCentroids); + final NeighborQueue neighborQueue = new NeighborQueue(bufferSize, true); + // score the parents + final float[] scores = new float[ES92Int7VectorsScorer.BULK_SIZE]; + score( + parentsQueue, + numParents, + 0, + scorer, + quantizeQuery, + queryParams, + globalCentroidDp, + fieldInfo.getVectorSimilarityFunction(), + scores + ); + final long centroidQuantizeSize = fieldInfo.getVectorDimension() + 3 * Float.BYTES + Integer.BYTES; + final long offset = centroids.getFilePointer(); + final long childrenOffset = offset + (long) Long.BYTES * numParents; + // populate the children's queue by reading parents one by one + while (parentsQueue.size() > 0 && neighborQueue.size() < bufferSize) { + final int pop = parentsQueue.pop(); + populateOneChildrenGroup( + currentParentQueue, + centroids, + offset + 2L * Integer.BYTES * pop, + childrenOffset, + centroidQuantizeSize, + fieldInfo, + scorer, + quantizeQuery, + queryParams, + globalCentroidDp, + scores + ); + while (currentParentQueue.size() > 0 && neighborQueue.size() < bufferSize) { + final float score = currentParentQueue.topScore(); + final int children = currentParentQueue.pop(); + neighborQueue.add(children, score); + } + } + final long childrenFileOffsets = childrenOffset + centroidQuantizeSize * numCentroids; + return new CentroidIterator() { + @Override + public boolean hasNext() { + return neighborQueue.size() > 0; + } + + @Override + public CentroidOffsetAndLength nextPostingListOffsetAndLength() throws IOException { + int centroidOrdinal = nextCentroid(); + centroids.seek(childrenFileOffsets + (long) Long.BYTES * 2 * centroidOrdinal); + long postingListOffset = centroids.readLong(); + long postingListLength = centroids.readLong(); + return new CentroidOffsetAndLength(postingListOffset, postingListLength); + } + + private int nextCentroid() throws IOException { + if (currentParentQueue.size() > 0) { + // return next centroid and maybe add a children from the current parent queue + return neighborQueue.popAndAddRaw(currentParentQueue.popRaw()); + } else if (parentsQueue.size() > 0) { + // current parent queue is empty, populate it again with the next parent + int pop = parentsQueue.pop(); + populateOneChildrenGroup( + 
currentParentQueue, + centroids, + offset + 2L * Integer.BYTES * pop, + childrenOffset, + centroidQuantizeSize, + fieldInfo, + scorer, + quantizeQuery, + queryParams, + globalCentroidDp, + scores + ); + return nextCentroid(); + } else { + return neighborQueue.pop(); + } + } + }; + } + + private static void populateOneChildrenGroup( + NeighborQueue neighborQueue, + IndexInput centroids, + long parentOffset, + long childrenOffset, + long centroidQuantizeSize, + FieldInfo fieldInfo, + ES92Int7VectorsScorer scorer, + byte[] quantizeQuery, + OptimizedScalarQuantizer.QuantizationResult queryParams, + float globalCentroidDp, + float[] scores + ) throws IOException { + centroids.seek(parentOffset); + int childrenOrdinal = centroids.readInt(); + int numChildren = centroids.readInt(); + centroids.seek(childrenOffset + centroidQuantizeSize * childrenOrdinal); + score( + neighborQueue, + numChildren, + childrenOrdinal, + scorer, + quantizeQuery, + queryParams, + globalCentroidDp, + fieldInfo.getVectorSimilarityFunction(), + scores + ); + } + + private static void score( + NeighborQueue neighborQueue, + int size, + int scoresOffset, + ES92Int7VectorsScorer scorer, + byte[] quantizeQuery, + OptimizedScalarQuantizer.QuantizationResult queryCorrections, + float centroidDp, + VectorSimilarityFunction similarityFunction, + float[] scores + ) throws IOException { + int limit = size - ES92Int7VectorsScorer.BULK_SIZE + 1; + int i = 0; + for (; i < limit; i += ES92Int7VectorsScorer.BULK_SIZE) { + scorer.scoreBulk( + quantizeQuery, + queryCorrections.lowerInterval(), + queryCorrections.upperInterval(), + queryCorrections.quantizedComponentSum(), + queryCorrections.additionalCorrection(), + similarityFunction, + centroidDp, + scores + ); + for (int j = 0; j < ES92Int7VectorsScorer.BULK_SIZE; j++) { + neighborQueue.add(scoresOffset + i + j, scores[j]); + } + } + + for (; i < size; i++) { + float score = scorer.score( + quantizeQuery, + queryCorrections.lowerInterval(), + queryCorrections.upperInterval(), + queryCorrections.quantizedComponentSum(), + queryCorrections.additionalCorrection(), + similarityFunction, + centroidDp + ); + neighborQueue.add(scoresOffset + i, score); + } + } + + @Override + public PostingVisitor getPostingVisitor(FieldInfo fieldInfo, IndexInput indexInput, float[] target, Bits acceptDocs) + throws IOException { + FieldEntry entry = fields.get(fieldInfo.number); + final int maxPostingListSize = indexInput.readVInt(); + return new MemorySegmentPostingsVisitor(target, indexInput, entry, fieldInfo, maxPostingListSize, acceptDocs); + } + + @Override + public Map getOffHeapByteSize(FieldInfo fieldInfo) { + return Map.of(); + } + + private static class MemorySegmentPostingsVisitor implements PostingVisitor { + final long quantizedByteLength; + final IndexInput indexInput; + final float[] target; + final FieldEntry entry; + final FieldInfo fieldInfo; + final Bits acceptDocs; + private final ES91OSQVectorsScorer osqVectorsScorer; + final float[] scores = new float[BULK_SIZE]; + final float[] correctionsLower = new float[BULK_SIZE]; + final float[] correctionsUpper = new float[BULK_SIZE]; + final int[] correctionsSum = new int[BULK_SIZE]; + final float[] correctionsAdd = new float[BULK_SIZE]; + final int[] docIdsScratch = new int[BULK_SIZE]; + byte docEncoding; + int docBase = 0; + + int vectors; + boolean quantized = false; + float centroidDp; + final float[] centroid; + long slicePos; + OptimizedScalarQuantizer.QuantizationResult queryCorrections; + + final float[] scratch; + final int[] 
quantizationScratch; + final byte[] quantizedQueryScratch; + final OptimizedScalarQuantizer quantizer; + final DocIdsWriter idsWriter = new DocIdsWriter(); + final float[] correctiveValues = new float[3]; + final long quantizedVectorByteSize; + + MemorySegmentPostingsVisitor( + float[] target, + IndexInput indexInput, + FieldEntry entry, + FieldInfo fieldInfo, + int maxPostingListSize, + Bits acceptDocs + ) throws IOException { + this.target = target; + this.indexInput = indexInput; + this.entry = entry; + this.fieldInfo = fieldInfo; + this.acceptDocs = acceptDocs; + centroid = new float[fieldInfo.getVectorDimension()]; + scratch = new float[target.length]; + quantizationScratch = new int[target.length]; + final int discretizedDimensions = discretize(fieldInfo.getVectorDimension(), 64); + quantizedQueryScratch = new byte[QUERY_BITS * discretizedDimensions / 8]; + quantizedByteLength = discretizedDimensions / 8 + (Float.BYTES * 3) + Short.BYTES; + quantizedVectorByteSize = (discretizedDimensions / 8); + quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction(), DEFAULT_LAMBDA, 1); + osqVectorsScorer = ESVectorUtil.getES91OSQVectorsScorer(indexInput, fieldInfo.getVectorDimension()); + } + + @Override + public int resetPostingsScorer(long offset) throws IOException { + quantized = false; + indexInput.seek(offset); + indexInput.readFloats(centroid, 0, centroid.length); + centroidDp = Float.intBitsToFloat(indexInput.readInt()); + vectors = indexInput.readVInt(); + docEncoding = indexInput.readByte(); + docBase = 0; + slicePos = indexInput.getFilePointer(); + return vectors; + } + + private float scoreIndividually() throws IOException { + float maxScore = Float.NEGATIVE_INFINITY; + // score individually, first the quantized byte chunk + for (int j = 0; j < BULK_SIZE; j++) { + int doc = docIdsScratch[j]; + if (doc != -1) { + float qcDist = osqVectorsScorer.quantizeScore(quantizedQueryScratch); + scores[j] = qcDist; + } else { + indexInput.skipBytes(quantizedVectorByteSize); + } + } + // read in all corrections + indexInput.readFloats(correctionsLower, 0, BULK_SIZE); + indexInput.readFloats(correctionsUpper, 0, BULK_SIZE); + for (int j = 0; j < BULK_SIZE; j++) { + correctionsSum[j] = Short.toUnsignedInt(indexInput.readShort()); + } + indexInput.readFloats(correctionsAdd, 0, BULK_SIZE); + // Now apply corrections + for (int j = 0; j < BULK_SIZE; j++) { + int doc = docIdsScratch[j]; + if (doc != -1) { + scores[j] = osqVectorsScorer.score( + queryCorrections.lowerInterval(), + queryCorrections.upperInterval(), + queryCorrections.quantizedComponentSum(), + queryCorrections.additionalCorrection(), + fieldInfo.getVectorSimilarityFunction(), + centroidDp, + correctionsLower[j], + correctionsUpper[j], + correctionsSum[j], + correctionsAdd[j], + scores[j] + ); + if (scores[j] > maxScore) { + maxScore = scores[j]; + } + } + } + return maxScore; + } + + private static int docToBulkScore(int[] docIds, Bits acceptDocs) { + assert acceptDocs != null : "acceptDocs must not be null"; + int docToScore = ES91OSQVectorsScorer.BULK_SIZE; + for (int i = 0; i < ES91OSQVectorsScorer.BULK_SIZE; i++) { + if (acceptDocs.get(docIds[i]) == false) { + docIds[i] = -1; + docToScore--; + } + } + return docToScore; + } + + private void collectBulk(KnnCollector knnCollector, float[] scores) { + for (int i = 0; i < ES91OSQVectorsScorer.BULK_SIZE; i++) { + final int doc = docIdsScratch[i]; + if (doc != -1) { + knnCollector.collect(doc, scores[i]); + } + } + } + + private void readDocIds(int count) throws 
IOException { + idsWriter.readInts(indexInput, count, docEncoding, docIdsScratch); + // reconstitute from the deltas + for (int j = 0; j < count; j++) { + docBase += docIdsScratch[j]; + docIdsScratch[j] = docBase; + } + } + + @Override + public int visit(KnnCollector knnCollector) throws IOException { + indexInput.seek(slicePos); + // block processing + int scoredDocs = 0; + int limit = vectors - BULK_SIZE + 1; + int i = 0; + // read Docs + for (; i < limit; i += BULK_SIZE) { + // read the doc ids + readDocIds(BULK_SIZE); + final int docsToBulkScore = acceptDocs == null ? BULK_SIZE : docToBulkScore(docIdsScratch, acceptDocs); + if (docsToBulkScore == 0) { + indexInput.skipBytes(quantizedByteLength * BULK_SIZE); + continue; + } + quantizeQueryIfNecessary(); + final float maxScore; + if (docsToBulkScore < BULK_SIZE / 2) { + maxScore = scoreIndividually(); + } else { + maxScore = osqVectorsScorer.scoreBulk( + quantizedQueryScratch, + queryCorrections.lowerInterval(), + queryCorrections.upperInterval(), + queryCorrections.quantizedComponentSum(), + queryCorrections.additionalCorrection(), + fieldInfo.getVectorSimilarityFunction(), + centroidDp, + scores + ); + } + if (knnCollector.minCompetitiveSimilarity() < maxScore) { + collectBulk(knnCollector, scores); + } + scoredDocs += docsToBulkScore; + } + // process tail + // read the doc ids + if (i < vectors) { + readDocIds(vectors - i); + } + int count = 0; + for (; i < vectors; i++) { + int doc = docIdsScratch[count++]; + if (acceptDocs == null || acceptDocs.get(doc)) { + quantizeQueryIfNecessary(); + float qcDist = osqVectorsScorer.quantizeScore(quantizedQueryScratch); + indexInput.readFloats(correctiveValues, 0, 3); + final int quantizedComponentSum = Short.toUnsignedInt(indexInput.readShort()); + float score = osqVectorsScorer.score( + queryCorrections.lowerInterval(), + queryCorrections.upperInterval(), + queryCorrections.quantizedComponentSum(), + queryCorrections.additionalCorrection(), + fieldInfo.getVectorSimilarityFunction(), + centroidDp, + correctiveValues[0], + correctiveValues[1], + quantizedComponentSum, + correctiveValues[2], + qcDist + ); + scoredDocs++; + knnCollector.collect(doc, score); + } else { + indexInput.skipBytes(quantizedByteLength); + } + } + if (scoredDocs > 0) { + knnCollector.incVisitedCount(scoredDocs); + } + return scoredDocs; + } + + private void quantizeQueryIfNecessary() { + if (quantized == false) { + assert fieldInfo.getVectorSimilarityFunction() != COSINE || VectorUtil.isUnitVector(target); + queryCorrections = quantizer.scalarQuantize(target, scratch, quantizationScratch, (byte) 4, centroid); + transposeHalfByte(quantizationScratch, quantizedQueryScratch); + quantized = true; + } + } + } + +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsWriter.java new file mode 100644 index 0000000000000..6799743d9dc9f --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsWriter.java @@ -0,0 +1,765 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.index.codec.vectors.diskbbq.next; + +import org.apache.lucene.codecs.hnsw.FlatVectorsWriter; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FloatVectorValues; +import org.apache.lucene.index.MergeState; +import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; +import org.apache.lucene.store.IndexOutput; +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.IntToIntFunction; +import org.apache.lucene.util.packed.PackedInts; +import org.apache.lucene.util.packed.PackedLongValues; +import org.elasticsearch.core.SuppressForbidden; +import org.elasticsearch.index.codec.vectors.BQVectorUtils; +import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer; +import org.elasticsearch.index.codec.vectors.cluster.HierarchicalKMeans; +import org.elasticsearch.index.codec.vectors.cluster.KMeansResult; +import org.elasticsearch.index.codec.vectors.diskbbq.CentroidAssignments; +import org.elasticsearch.index.codec.vectors.diskbbq.CentroidSupplier; +import org.elasticsearch.index.codec.vectors.diskbbq.DiskBBQBulkWriter; +import org.elasticsearch.index.codec.vectors.diskbbq.DocIdsWriter; +import org.elasticsearch.index.codec.vectors.diskbbq.IVFVectorsWriter; +import org.elasticsearch.index.codec.vectors.diskbbq.IntSorter; +import org.elasticsearch.index.codec.vectors.diskbbq.IntToBooleanFunction; +import org.elasticsearch.index.codec.vectors.diskbbq.QuantizedVectorValues; +import org.elasticsearch.logging.LogManager; +import org.elasticsearch.logging.Logger; +import org.elasticsearch.simdvec.ES91OSQVectorsScorer; +import org.elasticsearch.simdvec.ES92Int7VectorsScorer; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.AbstractList; +import java.util.Arrays; + +import static org.elasticsearch.index.codec.vectors.cluster.HierarchicalKMeans.NO_SOAR_ASSIGNMENT; + +/** + * Default implementation of {@link IVFVectorsWriter}. It uses {@link HierarchicalKMeans} algorithm to + * partition the vector space, and then stores the centroids and posting list in a sequential + * fashion. 
+ */ +public class ESNextDiskBBQVectorsWriter extends IVFVectorsWriter { + private static final Logger logger = LogManager.getLogger(ESNextDiskBBQVectorsWriter.class); + + private final int vectorPerCluster; + private final int centroidsPerParentCluster; + + public ESNextDiskBBQVectorsWriter( + String rawVectorFormatName, + SegmentWriteState state, + FlatVectorsWriter rawVectorDelegate, + int vectorPerCluster, + int centroidsPerParentCluster + ) throws IOException { + super(state, rawVectorFormatName, rawVectorDelegate); + this.vectorPerCluster = vectorPerCluster; + this.centroidsPerParentCluster = centroidsPerParentCluster; + } + + @Override + public CentroidOffsetAndLength buildAndWritePostingsLists( + FieldInfo fieldInfo, + CentroidSupplier centroidSupplier, + FloatVectorValues floatVectorValues, + IndexOutput postingsOutput, + long fileOffset, + int[] assignments, + int[] overspillAssignments + ) throws IOException { + int[] centroidVectorCount = new int[centroidSupplier.size()]; + for (int i = 0; i < assignments.length; i++) { + centroidVectorCount[assignments[i]]++; + // if soar assignments are present, count them as well + if (overspillAssignments.length > i && overspillAssignments[i] != NO_SOAR_ASSIGNMENT) { + centroidVectorCount[overspillAssignments[i]]++; + } + } + + int maxPostingListSize = 0; + int[][] assignmentsByCluster = new int[centroidSupplier.size()][]; + for (int c = 0; c < centroidSupplier.size(); c++) { + int size = centroidVectorCount[c]; + maxPostingListSize = Math.max(maxPostingListSize, size); + assignmentsByCluster[c] = new int[size]; + } + Arrays.fill(centroidVectorCount, 0); + + for (int i = 0; i < assignments.length; i++) { + int c = assignments[i]; + assignmentsByCluster[c][centroidVectorCount[c]++] = i; + // if soar assignments are present, add them to the cluster as well + if (overspillAssignments.length > i) { + int s = overspillAssignments[i]; + if (s != NO_SOAR_ASSIGNMENT) { + assignmentsByCluster[s][centroidVectorCount[s]++] = i; + } + } + } + // write the max posting list size + postingsOutput.writeVInt(maxPostingListSize); + // write the posting lists + final PackedLongValues.Builder offsets = PackedLongValues.monotonicBuilder(PackedInts.COMPACT); + final PackedLongValues.Builder lengths = PackedLongValues.monotonicBuilder(PackedInts.COMPACT); + DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(ES91OSQVectorsScorer.BULK_SIZE, postingsOutput); + OnHeapQuantizedVectors onHeapQuantizedVectors = new OnHeapQuantizedVectors( + floatVectorValues, + fieldInfo.getVectorDimension(), + new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()) + ); + final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); + final int[] docIds = new int[maxPostingListSize]; + final int[] docDeltas = new int[maxPostingListSize]; + final int[] clusterOrds = new int[maxPostingListSize]; + DocIdsWriter idsWriter = new DocIdsWriter(); + for (int c = 0; c < centroidSupplier.size(); c++) { + float[] centroid = centroidSupplier.centroid(c); + int[] cluster = assignmentsByCluster[c]; + long offset = postingsOutput.alignFilePointer(Float.BYTES) - fileOffset; + offsets.add(offset); + buffer.asFloatBuffer().put(centroid); + // write raw centroid for quantizing the query vectors + postingsOutput.writeBytes(buffer.array(), buffer.array().length); + // write centroid dot product for quantizing the query vectors + postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, 
centroid))); + int size = cluster.length; + // write docIds + postingsOutput.writeVInt(size); + for (int j = 0; j < size; j++) { + docIds[j] = floatVectorValues.ordToDoc(cluster[j]); + clusterOrds[j] = j; + } + // sort cluster.buffer by docIds values, this way cluster ordinals are sorted by docIds + new IntSorter(clusterOrds, i -> docIds[i]).sort(0, size); + // encode doc deltas + for (int j = 0; j < size; j++) { + docDeltas[j] = j == 0 ? docIds[clusterOrds[j]] : docIds[clusterOrds[j]] - docIds[clusterOrds[j - 1]]; + } + onHeapQuantizedVectors.reset(centroid, size, ord -> cluster[clusterOrds[ord]]); + byte encoding = idsWriter.calculateBlockEncoding(i -> docDeltas[i], size, ES91OSQVectorsScorer.BULK_SIZE); + postingsOutput.writeByte(encoding); + bulkWriter.writeVectors(onHeapQuantizedVectors, i -> { + // for vector i we write `bulk` size docs or the remaining docs + idsWriter.writeDocIds(d -> docDeltas[i + d], Math.min(ES91OSQVectorsScorer.BULK_SIZE, size - i), encoding, postingsOutput); + }); + lengths.add(postingsOutput.getFilePointer() - fileOffset - offset); + } + + if (logger.isDebugEnabled()) { + printClusterQualityStatistics(assignmentsByCluster); + } + + return new CentroidOffsetAndLength(offsets.build(), lengths.build()); + } + + @Override + @SuppressForbidden(reason = "require usage of Lucene's IOUtils#deleteFilesIgnoringExceptions(...)") + public CentroidOffsetAndLength buildAndWritePostingsLists( + FieldInfo fieldInfo, + CentroidSupplier centroidSupplier, + FloatVectorValues floatVectorValues, + IndexOutput postingsOutput, + long fileOffset, + MergeState mergeState, + int[] assignments, + int[] overspillAssignments + ) throws IOException { + // first, quantize all the vectors into a temporary file + String quantizedVectorsTempName = null; + boolean success = false; + try ( + IndexOutput quantizedVectorsTemp = mergeState.segmentInfo.dir.createTempOutput( + mergeState.segmentInfo.name, + "qvec_", + IOContext.DEFAULT + ) + ) { + quantizedVectorsTempName = quantizedVectorsTemp.getName(); + OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); + int[] quantized = new int[fieldInfo.getVectorDimension()]; + byte[] binary = new byte[BQVectorUtils.discretize(fieldInfo.getVectorDimension(), 64) / 8]; + float[] scratch = new float[fieldInfo.getVectorDimension()]; + for (int i = 0; i < assignments.length; i++) { + int c = assignments[i]; + float[] centroid = centroidSupplier.centroid(c); + float[] vector = floatVectorValues.vectorValue(i); + boolean overspill = overspillAssignments.length > i && overspillAssignments[i] != NO_SOAR_ASSIGNMENT; + OptimizedScalarQuantizer.QuantizationResult result = quantizer.scalarQuantize( + vector, + scratch, + quantized, + (byte) 1, + centroid + ); + BQVectorUtils.packAsBinary(quantized, binary); + writeQuantizedValue(quantizedVectorsTemp, binary, result); + if (overspill) { + int s = overspillAssignments[i]; + // write the overspill vector as well + result = quantizer.scalarQuantize(vector, scratch, quantized, (byte) 1, centroidSupplier.centroid(s)); + BQVectorUtils.packAsBinary(quantized, binary); + writeQuantizedValue(quantizedVectorsTemp, binary, result); + } else { + // write a zero vector for the overspill + Arrays.fill(binary, (byte) 0); + OptimizedScalarQuantizer.QuantizationResult zeroResult = new OptimizedScalarQuantizer.QuantizationResult(0f, 0f, 0f, 0); + writeQuantizedValue(quantizedVectorsTemp, binary, zeroResult); + } + } + success = true; + } finally { + if (success == false && 
quantizedVectorsTempName != null) { + org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, quantizedVectorsTempName); + } + } + int[] centroidVectorCount = new int[centroidSupplier.size()]; + for (int i = 0; i < assignments.length; i++) { + centroidVectorCount[assignments[i]]++; + // if soar assignments are present, count them as well + if (overspillAssignments.length > i && overspillAssignments[i] != NO_SOAR_ASSIGNMENT) { + centroidVectorCount[overspillAssignments[i]]++; + } + } + + int maxPostingListSize = 0; + int[][] assignmentsByCluster = new int[centroidSupplier.size()][]; + boolean[][] isOverspillByCluster = new boolean[centroidSupplier.size()][]; + for (int c = 0; c < centroidSupplier.size(); c++) { + int size = centroidVectorCount[c]; + maxPostingListSize = Math.max(maxPostingListSize, size); + assignmentsByCluster[c] = new int[size]; + isOverspillByCluster[c] = new boolean[size]; + } + Arrays.fill(centroidVectorCount, 0); + + for (int i = 0; i < assignments.length; i++) { + int c = assignments[i]; + assignmentsByCluster[c][centroidVectorCount[c]++] = i; + // if soar assignments are present, add them to the cluster as well + if (overspillAssignments.length > i) { + int s = overspillAssignments[i]; + if (s != NO_SOAR_ASSIGNMENT) { + assignmentsByCluster[s][centroidVectorCount[s]] = i; + isOverspillByCluster[s][centroidVectorCount[s]++] = true; + } + } + } + // now we can read the quantized vectors from the temporary file + try (IndexInput quantizedVectorsInput = mergeState.segmentInfo.dir.openInput(quantizedVectorsTempName, IOContext.DEFAULT)) { + final PackedLongValues.Builder offsets = PackedLongValues.monotonicBuilder(PackedInts.COMPACT); + final PackedLongValues.Builder lengths = PackedLongValues.monotonicBuilder(PackedInts.COMPACT); + OffHeapQuantizedVectors offHeapQuantizedVectors = new OffHeapQuantizedVectors( + quantizedVectorsInput, + fieldInfo.getVectorDimension() + ); + DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(ES91OSQVectorsScorer.BULK_SIZE, postingsOutput); + final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); + // write the max posting list size + postingsOutput.writeVInt(maxPostingListSize); + // write the posting lists + final int[] docIds = new int[maxPostingListSize]; + final int[] docDeltas = new int[maxPostingListSize]; + final int[] clusterOrds = new int[maxPostingListSize]; + DocIdsWriter idsWriter = new DocIdsWriter(); + for (int c = 0; c < centroidSupplier.size(); c++) { + float[] centroid = centroidSupplier.centroid(c); + int[] cluster = assignmentsByCluster[c]; + boolean[] isOverspill = isOverspillByCluster[c]; + long offset = postingsOutput.alignFilePointer(Float.BYTES) - fileOffset; + offsets.add(offset); + // write raw centroid for quantizing the query vectors + buffer.asFloatBuffer().put(centroid); + postingsOutput.writeBytes(buffer.array(), buffer.array().length); + // write centroid dot product for quantizing the query vectors + postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid))); + // write docIds + int size = cluster.length; + postingsOutput.writeVInt(size); + for (int j = 0; j < size; j++) { + docIds[j] = floatVectorValues.ordToDoc(cluster[j]); + clusterOrds[j] = j; + } + // sort cluster.buffer by docIds values, this way cluster ordinals are sorted by docIds + new IntSorter(clusterOrds, i -> docIds[i]).sort(0, size); + // encode doc deltas + for (int j = 0; j < size; 
j++) { + docDeltas[j] = j == 0 ? docIds[clusterOrds[j]] : docIds[clusterOrds[j]] - docIds[clusterOrds[j - 1]]; + } + byte encoding = idsWriter.calculateBlockEncoding(i -> docDeltas[i], size, ES91OSQVectorsScorer.BULK_SIZE); + postingsOutput.writeByte(encoding); + offHeapQuantizedVectors.reset(size, ord -> isOverspill[clusterOrds[ord]], ord -> cluster[clusterOrds[ord]]); + // write vectors + bulkWriter.writeVectors(offHeapQuantizedVectors, i -> { + // for vector i we write `bulk` size docs or the remaining docs + idsWriter.writeDocIds( + d -> docDeltas[d + i], + Math.min(ES91OSQVectorsScorer.BULK_SIZE, size - i), + encoding, + postingsOutput + ); + }); + lengths.add(postingsOutput.getFilePointer() - fileOffset - offset); + } + + if (logger.isDebugEnabled()) { + printClusterQualityStatistics(assignmentsByCluster); + } + return new CentroidOffsetAndLength(offsets.build(), lengths.build()); + } finally { + org.apache.lucene.util.IOUtils.deleteFilesIgnoringExceptions(mergeState.segmentInfo.dir, quantizedVectorsTempName); + } + } + + private static void printClusterQualityStatistics(int[][] clusters) { + float min = Float.MAX_VALUE; + float max = Float.MIN_VALUE; + float mean = 0; + float m2 = 0; + // iteratively compute the variance & mean + int count = 0; + for (int[] cluster : clusters) { + count += 1; + if (cluster == null) { + continue; + } + float delta = cluster.length - mean; + mean += delta / count; + m2 += delta * (cluster.length - mean); + min = Math.min(min, cluster.length); + max = Math.max(max, cluster.length); + } + float variance = m2 / (clusters.length - 1); + logger.debug( + "Centroid count: {} min: {} max: {} mean: {} stdDev: {} variance: {}", + clusters.length, + min, + max, + mean, + Math.sqrt(variance), + variance + ); + } + + @Override + public CentroidSupplier createCentroidSupplier( + IndexInput centroidsInput, + int numCentroids, + FieldInfo fieldInfo, + float[] globalCentroid + ) { + return new OffHeapCentroidSupplier(centroidsInput, numCentroids, fieldInfo); + } + + @Override + public void writeCentroids( + FieldInfo fieldInfo, + CentroidSupplier centroidSupplier, + float[] globalCentroid, + CentroidOffsetAndLength centroidOffsetAndLength, + IndexOutput centroidOutput + ) throws IOException { + // TODO do we want to store these distances as well for future use? 
+ // TODO: sort centroids by global centroid (was doing so previously here) + // TODO: sorting tanks recall possibly because centroids ordinals no longer are aligned + if (centroidSupplier.size() > centroidsPerParentCluster * centroidsPerParentCluster) { + writeCentroidsWithParents(fieldInfo, centroidSupplier, globalCentroid, centroidOffsetAndLength, centroidOutput); + } else { + writeCentroidsWithoutParents(fieldInfo, centroidSupplier, globalCentroid, centroidOffsetAndLength, centroidOutput); + } + } + + private void writeCentroidsWithParents( + FieldInfo fieldInfo, + CentroidSupplier centroidSupplier, + float[] globalCentroid, + CentroidOffsetAndLength centroidOffsetAndLength, + IndexOutput centroidOutput + ) throws IOException { + DiskBBQBulkWriter.SevenBitDiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.SevenBitDiskBBQBulkWriter( + ES92Int7VectorsScorer.BULK_SIZE, + centroidOutput + ); + final OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); + final CentroidGroups centroidGroups = buildCentroidGroups(fieldInfo, centroidSupplier); + centroidOutput.writeVInt(centroidGroups.centroids.length); + centroidOutput.writeVInt(centroidGroups.maxVectorsPerCentroidLength); + QuantizedCentroids parentQuantizeCentroid = new QuantizedCentroids( + CentroidSupplier.fromArray(centroidGroups.centroids), + fieldInfo.getVectorDimension(), + osq, + globalCentroid + ); + bulkWriter.writeVectors(parentQuantizeCentroid, null); + int offset = 0; + for (int i = 0; i < centroidGroups.centroids().length; i++) { + centroidOutput.writeInt(offset); + centroidOutput.writeInt(centroidGroups.vectors()[i].length); + offset += centroidGroups.vectors()[i].length; + } + + QuantizedCentroids childrenQuantizeCentroid = new QuantizedCentroids( + centroidSupplier, + fieldInfo.getVectorDimension(), + osq, + globalCentroid + ); + for (int i = 0; i < centroidGroups.centroids().length; i++) { + final int[] centroidAssignments = centroidGroups.vectors()[i]; + childrenQuantizeCentroid.reset(idx -> centroidAssignments[idx], centroidAssignments.length); + bulkWriter.writeVectors(childrenQuantizeCentroid, null); + } + // write the centroid offsets at the end of the file + for (int i = 0; i < centroidGroups.centroids().length; i++) { + final int[] centroidAssignments = centroidGroups.vectors()[i]; + for (int assignment : centroidAssignments) { + centroidOutput.writeLong(centroidOffsetAndLength.offsets().get(assignment)); + centroidOutput.writeLong(centroidOffsetAndLength.lengths().get(assignment)); + } + } + } + + private void writeCentroidsWithoutParents( + FieldInfo fieldInfo, + CentroidSupplier centroidSupplier, + float[] globalCentroid, + CentroidOffsetAndLength centroidOffsetAndLength, + IndexOutput centroidOutput + ) throws IOException { + centroidOutput.writeVInt(0); + DiskBBQBulkWriter.SevenBitDiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.SevenBitDiskBBQBulkWriter( + ES92Int7VectorsScorer.BULK_SIZE, + centroidOutput + ); + final OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); + QuantizedCentroids quantizedCentroids = new QuantizedCentroids( + centroidSupplier, + fieldInfo.getVectorDimension(), + osq, + globalCentroid + ); + bulkWriter.writeVectors(quantizedCentroids, null); + // write the centroid offsets at the end of the file + for (int i = 0; i < centroidSupplier.size(); i++) { + centroidOutput.writeLong(centroidOffsetAndLength.offsets().get(i)); + 
centroidOutput.writeLong(centroidOffsetAndLength.lengths().get(i));
+        }
+    }
+
+    private record CentroidGroups(float[][] centroids, int[][] vectors, int maxVectorsPerCentroidLength) {}
+
+    private CentroidGroups buildCentroidGroups(FieldInfo fieldInfo, CentroidSupplier centroidSupplier) throws IOException {
+        final FloatVectorValues floatVectorValues = FloatVectorValues.fromFloats(new AbstractList<>() {
+            @Override
+            public float[] get(int index) {
+                try {
+                    return centroidSupplier.centroid(index);
+                } catch (IOException e) {
+                    throw new UncheckedIOException(e);
+                }
+            }
+
+            @Override
+            public int size() {
+                return centroidSupplier.size();
+            }
+        }, fieldInfo.getVectorDimension());
+        // we use the HierarchicalKMeans to partition the space of all vectors across merging segments
+        // these are small numbers so we run it with all the centroids.
+        final KMeansResult kMeansResult = new HierarchicalKMeans(
+            fieldInfo.getVectorDimension(),
+            HierarchicalKMeans.MAX_ITERATIONS_DEFAULT,
+            HierarchicalKMeans.SAMPLES_PER_CLUSTER_DEFAULT,
+            HierarchicalKMeans.MAXK,
+            -1 // disable SOAR assignments
+        ).cluster(floatVectorValues, centroidsPerParentCluster);
+        final int[] centroidVectorCount = new int[kMeansResult.centroids().length];
+        for (int i = 0; i < kMeansResult.assignments().length; i++) {
+            centroidVectorCount[kMeansResult.assignments()[i]]++;
+        }
+        final int[][] vectorsPerCentroid = new int[kMeansResult.centroids().length][];
+        int maxVectorsPerCentroidLength = 0;
+        for (int i = 0; i < kMeansResult.centroids().length; i++) {
+            vectorsPerCentroid[i] = new int[centroidVectorCount[i]];
+            maxVectorsPerCentroidLength = Math.max(maxVectorsPerCentroidLength, centroidVectorCount[i]);
+        }
+        Arrays.fill(centroidVectorCount, 0);
+        for (int i = 0; i < kMeansResult.assignments().length; i++) {
+            final int c = kMeansResult.assignments()[i];
+            vectorsPerCentroid[c][centroidVectorCount[c]++] = i;
+        }
+        return new CentroidGroups(kMeansResult.centroids(), vectorsPerCentroid, maxVectorsPerCentroidLength);
+    }
+
+    /**
+     * Calculate the centroids for the given field.
+     * We use the {@link HierarchicalKMeans} algorithm to partition the space of all vectors across merging segments.
+     *
+     * @param fieldInfo merging field info
+     * @param floatVectorValues the float vector values to merge
+     * @param globalCentroid the global centroid, calculated by this method and used to quantize the centroids
+     * @return the vector assignments, soar assignments, and, if asked, the centroids themselves that were computed
+     * @throws IOException if an I/O error occurs
+     */
+    @Override
+    public CentroidAssignments calculateCentroids(FieldInfo fieldInfo, FloatVectorValues floatVectorValues, float[] globalCentroid)
+        throws IOException {
+
+        long nanoTime = System.nanoTime();
+
+        // TODO: consider hinting / bootstrapping hierarchical kmeans with the prior segments centroids
+        CentroidAssignments centroidAssignments = buildCentroidAssignments(floatVectorValues, vectorPerCluster);
+        float[][] centroids = centroidAssignments.centroids();
+        // TODO: for flush we are doing this over the vectors and here centroids which seems duplicative
+        // preliminary tests suggest recall is good using only centroids but need to do further evaluation
+        // TODO: push this logic into vector util?
+ for (float[] centroid : centroids) { + for (int j = 0; j < centroid.length; j++) { + globalCentroid[j] += centroid[j]; + } + } + for (int j = 0; j < globalCentroid.length; j++) { + globalCentroid[j] /= centroids.length; + } + + if (logger.isDebugEnabled()) { + logger.debug("calculate centroids and assign vectors time ms: {}", (System.nanoTime() - nanoTime) / 1000000.0); + logger.debug("final centroid count: {}", centroids.length); + } + return centroidAssignments; + } + + static CentroidAssignments buildCentroidAssignments(FloatVectorValues floatVectorValues, int vectorPerCluster) throws IOException { + KMeansResult kMeansResult = new HierarchicalKMeans(floatVectorValues.dimension()).cluster(floatVectorValues, vectorPerCluster); + float[][] centroids = kMeansResult.centroids(); + int[] assignments = kMeansResult.assignments(); + int[] soarAssignments = kMeansResult.soarAssignments(); + return new CentroidAssignments(centroids, assignments, soarAssignments); + } + + static void writeQuantizedValue(IndexOutput indexOutput, byte[] binaryValue, OptimizedScalarQuantizer.QuantizationResult corrections) + throws IOException { + indexOutput.writeBytes(binaryValue, binaryValue.length); + indexOutput.writeInt(Float.floatToIntBits(corrections.lowerInterval())); + indexOutput.writeInt(Float.floatToIntBits(corrections.upperInterval())); + indexOutput.writeInt(Float.floatToIntBits(corrections.additionalCorrection())); + assert corrections.quantizedComponentSum() >= 0 && corrections.quantizedComponentSum() <= 0xffff; + indexOutput.writeShort((short) corrections.quantizedComponentSum()); + } + + static class OffHeapCentroidSupplier implements CentroidSupplier { + private final IndexInput centroidsInput; + private final int numCentroids; + private final int dimension; + private final float[] scratch; + private int currOrd = -1; + + OffHeapCentroidSupplier(IndexInput centroidsInput, int numCentroids, FieldInfo info) { + this.centroidsInput = centroidsInput; + this.numCentroids = numCentroids; + this.dimension = info.getVectorDimension(); + this.scratch = new float[dimension]; + } + + @Override + public int size() { + return numCentroids; + } + + @Override + public float[] centroid(int centroidOrdinal) throws IOException { + if (centroidOrdinal == currOrd) { + return scratch; + } + centroidsInput.seek((long) centroidOrdinal * dimension * Float.BYTES); + centroidsInput.readFloats(scratch, 0, dimension); + this.currOrd = centroidOrdinal; + return scratch; + } + } + + static class QuantizedCentroids implements QuantizedVectorValues { + private final CentroidSupplier supplier; + private final OptimizedScalarQuantizer quantizer; + private final byte[] quantizedVector; + private final int[] quantizedVectorScratch; + private final float[] floatVectorScratch; + private OptimizedScalarQuantizer.QuantizationResult corrections; + private final float[] centroid; + private int currOrd = -1; + private IntToIntFunction ordTransformer = i -> i; + int size; + + QuantizedCentroids(CentroidSupplier supplier, int dimension, OptimizedScalarQuantizer quantizer, float[] centroid) { + this.supplier = supplier; + this.quantizer = quantizer; + this.quantizedVector = new byte[dimension]; + this.floatVectorScratch = new float[dimension]; + this.quantizedVectorScratch = new int[dimension]; + this.centroid = centroid; + size = supplier.size(); + } + + @Override + public int count() { + return size; + } + + void reset(IntToIntFunction ordTransformer, int size) { + this.ordTransformer = ordTransformer; + this.currOrd = -1; + this.size = 
size; + this.corrections = null; + } + + @Override + public byte[] next() throws IOException { + if (currOrd >= count() - 1) { + throw new IllegalStateException("No more vectors to read, current ord: " + currOrd + ", count: " + count()); + } + currOrd++; + float[] vector = supplier.centroid(ordTransformer.apply(currOrd)); + corrections = quantizer.scalarQuantize(vector, floatVectorScratch, quantizedVectorScratch, (byte) 7, centroid); + for (int i = 0; i < quantizedVectorScratch.length; i++) { + quantizedVector[i] = (byte) quantizedVectorScratch[i]; + } + return quantizedVector; + } + + @Override + public OptimizedScalarQuantizer.QuantizationResult getCorrections() throws IOException { + return corrections; + } + } + + static class OnHeapQuantizedVectors implements QuantizedVectorValues { + private final FloatVectorValues vectorValues; + private final OptimizedScalarQuantizer quantizer; + private final byte[] quantizedVector; + private final int[] quantizedVectorScratch; + private final float[] floatVectorScratch; + private OptimizedScalarQuantizer.QuantizationResult corrections; + private float[] currentCentroid; + private IntToIntFunction ordTransformer = null; + private int currOrd = -1; + private int count; + + OnHeapQuantizedVectors(FloatVectorValues vectorValues, int dimension, OptimizedScalarQuantizer quantizer) { + this.vectorValues = vectorValues; + this.quantizer = quantizer; + this.quantizedVector = new byte[BQVectorUtils.discretize(dimension, 64) / 8]; + this.floatVectorScratch = new float[dimension]; + this.quantizedVectorScratch = new int[dimension]; + this.corrections = null; + } + + private void reset(float[] centroid, int count, IntToIntFunction ordTransformer) { + this.currentCentroid = centroid; + this.ordTransformer = ordTransformer; + this.currOrd = -1; + this.count = count; + } + + @Override + public int count() { + return count; + } + + @Override + public byte[] next() throws IOException { + if (currOrd >= count() - 1) { + throw new IllegalStateException("No more vectors to read, current ord: " + currOrd + ", count: " + count()); + } + currOrd++; + int ord = ordTransformer.apply(currOrd); + float[] vector = vectorValues.vectorValue(ord); + corrections = quantizer.scalarQuantize(vector, floatVectorScratch, quantizedVectorScratch, (byte) 1, currentCentroid); + BQVectorUtils.packAsBinary(quantizedVectorScratch, quantizedVector); + return quantizedVector; + } + + @Override + public OptimizedScalarQuantizer.QuantizationResult getCorrections() throws IOException { + if (currOrd == -1) { + throw new IllegalStateException("No vector read yet, call next first"); + } + return corrections; + } + } + + static class OffHeapQuantizedVectors implements QuantizedVectorValues { + private final IndexInput quantizedVectorsInput; + private final byte[] binaryScratch; + private final float[] corrections = new float[3]; + + private final int vectorByteSize; + private short bitSum; + private int currOrd = -1; + private int count; + private IntToBooleanFunction isOverspill = null; + private IntToIntFunction ordTransformer = null; + + OffHeapQuantizedVectors(IndexInput quantizedVectorsInput, int dimension) { + this.quantizedVectorsInput = quantizedVectorsInput; + this.binaryScratch = new byte[BQVectorUtils.discretize(dimension, 64) / 8]; + this.vectorByteSize = (binaryScratch.length + 3 * Float.BYTES + Short.BYTES); + } + + private void reset(int count, IntToBooleanFunction isOverspill, IntToIntFunction ordTransformer) { + this.count = count; + this.isOverspill = isOverspill; + 
this.ordTransformer = ordTransformer; + this.currOrd = -1; + } + + @Override + public int count() { + return count; + } + + @Override + public byte[] next() throws IOException { + if (currOrd >= count - 1) { + throw new IllegalStateException("No more vectors to read, current ord: " + currOrd + ", count: " + count); + } + currOrd++; + int ord = ordTransformer.apply(currOrd); + boolean isOverspill = this.isOverspill.apply(currOrd); + return getVector(ord, isOverspill); + } + + @Override + public OptimizedScalarQuantizer.QuantizationResult getCorrections() throws IOException { + if (currOrd == -1) { + throw new IllegalStateException("No vector read yet, call readQuantizedVector first"); + } + return new OptimizedScalarQuantizer.QuantizationResult(corrections[0], corrections[1], corrections[2], bitSum); + } + + byte[] getVector(int ord, boolean isOverspill) throws IOException { + readQuantizedVector(ord, isOverspill); + return binaryScratch; + } + + public void readQuantizedVector(int ord, boolean isOverspill) throws IOException { + long offset = (long) ord * (vectorByteSize * 2L) + (isOverspill ? vectorByteSize : 0); + quantizedVectorsInput.seek(offset); + quantizedVectorsInput.readBytes(binaryScratch, 0, binaryScratch.length); + quantizedVectorsInput.readFloats(corrections, 0, 3); + bitSum = quantizedVectorsInput.readShort(); + } + } +} diff --git a/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat b/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat index ff02849e96c46..693b34fa5c01f 100644 --- a/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat +++ b/server/src/main/resources/META-INF/services/org.apache.lucene.codecs.KnnVectorsFormat @@ -8,3 +8,4 @@ org.elasticsearch.index.codec.vectors.es816.ES816HnswBinaryQuantizedVectorsForma org.elasticsearch.index.codec.vectors.es818.ES818BinaryQuantizedVectorsFormat org.elasticsearch.index.codec.vectors.es818.ES818HnswBinaryQuantizedVectorsFormat org.elasticsearch.index.codec.vectors.diskbbq.ES920DiskBBQVectorsFormat +org.elasticsearch.index.codec.vectors.diskbbq.next.ESNextDiskBBQVectorsFormat diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsFormatTests.java new file mode 100644 index 0000000000000..47fa8d0c2d381 --- /dev/null +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/diskbbq/next/ESNextDiskBBQVectorsFormatTests.java @@ -0,0 +1,293 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". 
+ */ +package org.elasticsearch.index.codec.vectors.diskbbq.next; + +import com.carrotsearch.randomizedtesting.generators.RandomPicks; + +import org.apache.lucene.codecs.Codec; +import org.apache.lucene.codecs.FilterCodec; +import org.apache.lucene.codecs.KnnVectorsFormat; +import org.apache.lucene.codecs.KnnVectorsReader; +import org.apache.lucene.codecs.perfield.PerFieldKnnVectorsFormat; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.index.CodecReader; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.VectorEncoding; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.AcceptDocs; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.BaseKnnVectorsFormatTestCase; +import org.apache.lucene.tests.util.TestUtil; +import org.elasticsearch.common.logging.LogConfigurator; +import org.junit.Before; + +import java.io.IOException; +import java.util.List; +import java.util.Locale; +import java.util.concurrent.atomic.AtomicBoolean; + +import static java.lang.String.format; +import static org.elasticsearch.index.codec.vectors.diskbbq.next.ESNextDiskBBQVectorsFormat.MAX_CENTROIDS_PER_PARENT_CLUSTER; +import static org.elasticsearch.index.codec.vectors.diskbbq.next.ESNextDiskBBQVectorsFormat.MAX_VECTORS_PER_CLUSTER; +import static org.elasticsearch.index.codec.vectors.diskbbq.next.ESNextDiskBBQVectorsFormat.MIN_CENTROIDS_PER_PARENT_CLUSTER; +import static org.elasticsearch.index.codec.vectors.diskbbq.next.ESNextDiskBBQVectorsFormat.MIN_VECTORS_PER_CLUSTER; +import static org.hamcrest.Matchers.anEmptyMap; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.oneOf; + +public class ESNextDiskBBQVectorsFormatTests extends BaseKnnVectorsFormatTestCase { + + static { + LogConfigurator.loadLog4jPlugins(); + LogConfigurator.configureESLogging(); // native access requires logging to be initialized + } + KnnVectorsFormat format; + + @Before + @Override + public void setUp() throws Exception { + if (rarely()) { + format = new ESNextDiskBBQVectorsFormat( + random().nextInt(2 * MIN_VECTORS_PER_CLUSTER, ESNextDiskBBQVectorsFormat.MAX_VECTORS_PER_CLUSTER), + random().nextInt(8, ESNextDiskBBQVectorsFormat.MAX_CENTROIDS_PER_PARENT_CLUSTER) + ); + } else { + // run with low numbers to force many clusters with parents + format = new ESNextDiskBBQVectorsFormat( + random().nextInt(MIN_VECTORS_PER_CLUSTER, 2 * MIN_VECTORS_PER_CLUSTER), + random().nextInt(MIN_CENTROIDS_PER_PARENT_CLUSTER, 8) + ); + } + super.setUp(); + } + + @Override + protected VectorSimilarityFunction randomSimilarity() { + return RandomPicks.randomFrom( + random(), + List.of( + VectorSimilarityFunction.DOT_PRODUCT, + VectorSimilarityFunction.EUCLIDEAN, + VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT + ) + ); + } + + @Override + protected VectorEncoding randomVectorEncoding() { + return VectorEncoding.FLOAT32; + } + + @Override + public void testSearchWithVisitedLimit() { + // ivf doesn't enforce visitation limit + } + + @Override + protected Codec getCodec() { + return TestUtil.alwaysKnnVectorsFormat(format); + } + + @Override + protected void assertOffHeapByteSize(LeafReader r, String 
fieldName) throws IOException { + var fieldInfo = r.getFieldInfos().fieldInfo(fieldName); + + if (r instanceof CodecReader codecReader) { + KnnVectorsReader knnVectorsReader = codecReader.getVectorReader(); + if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader fieldsReader) { + knnVectorsReader = fieldsReader.getFieldReader(fieldName); + } + var offHeap = knnVectorsReader.getOffHeapByteSize(fieldInfo); + long totalByteSize = offHeap.values().stream().mapToLong(Long::longValue).sum(); + // IVF doesn't report stats at the moment + assertThat(offHeap, anEmptyMap()); + assertThat(totalByteSize, equalTo(0L)); + } else { + throw new AssertionError("unexpected:" + r.getClass()); + } + } + + @Override + public void testAdvance() throws Exception { + // TODO re-enable with hierarchical IVF, clustering as it is is flaky + } + + public void testToString() { + FilterCodec customCodec = new FilterCodec("foo", Codec.getDefault()) { + @Override + public KnnVectorsFormat knnVectorsFormat() { + return new ESNextDiskBBQVectorsFormat(128, 4); + } + }; + String expectedPattern = "ESNextDiskBBQVectorsFormat(vectorPerCluster=128)"; + + var defaultScorer = format(Locale.ROOT, expectedPattern, "DefaultFlatVectorScorer"); + var memSegScorer = format(Locale.ROOT, expectedPattern, "Lucene99MemorySegmentFlatVectorsScorer"); + assertThat(customCodec.knnVectorsFormat().toString(), is(oneOf(defaultScorer, memSegScorer))); + } + + public void testLimits() { + expectThrows(IllegalArgumentException.class, () -> new ESNextDiskBBQVectorsFormat(MIN_VECTORS_PER_CLUSTER - 1, 16)); + expectThrows(IllegalArgumentException.class, () -> new ESNextDiskBBQVectorsFormat(MAX_VECTORS_PER_CLUSTER + 1, 16)); + expectThrows(IllegalArgumentException.class, () -> new ESNextDiskBBQVectorsFormat(128, MIN_CENTROIDS_PER_PARENT_CLUSTER - 1)); + expectThrows(IllegalArgumentException.class, () -> new ESNextDiskBBQVectorsFormat(128, MAX_CENTROIDS_PER_PARENT_CLUSTER + 1)); + } + + public void testSimpleOffHeapSize() throws IOException { + float[] vector = randomVector(random().nextInt(12, 500)); + try (Directory dir = newDirectory(); IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { + Document doc = new Document(); + doc.add(new KnnFloatVectorField("f", vector, VectorSimilarityFunction.EUCLIDEAN)); + w.addDocument(doc); + w.commit(); + try (IndexReader reader = DirectoryReader.open(w)) { + LeafReader r = getOnlyLeafReader(reader); + if (r instanceof CodecReader codecReader) { + KnnVectorsReader knnVectorsReader = codecReader.getVectorReader(); + if (knnVectorsReader instanceof PerFieldKnnVectorsFormat.FieldsReader fieldsReader) { + knnVectorsReader = fieldsReader.getFieldReader("f"); + } + var fieldInfo = r.getFieldInfos().fieldInfo("f"); + var offHeap = knnVectorsReader.getOffHeapByteSize(fieldInfo); + assertEquals(0, offHeap.size()); + } + } + } + } + + public void testFewVectorManyTimes() throws IOException { + int numDifferentVectors = random().nextInt(1, 20); + float[][] vectors = new float[numDifferentVectors][]; + int dimensions = random().nextInt(12, 500); + for (int i = 0; i < numDifferentVectors; i++) { + vectors[i] = randomVector(dimensions); + } + int numDocs = random().nextInt(100, 10_000); + try (Directory dir = newDirectory(); IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) { + for (int i = 0; i < numDocs; i++) { + float[] vector = vectors[random().nextInt(numDifferentVectors)]; + Document doc = new Document(); + doc.add(new KnnFloatVectorField("f", vector, 
VectorSimilarityFunction.EUCLIDEAN));
+                w.addDocument(doc);
+            }
+            w.commit();
+            if (rarely()) {
+                w.forceMerge(1);
+            }
+            try (IndexReader reader = DirectoryReader.open(w)) {
+                List<LeafReaderContext> subReaders = reader.leaves();
+                for (LeafReaderContext r : subReaders) {
+                    LeafReader leafReader = r.reader();
+                    float[] vector = randomVector(dimensions);
+                    TopDocs topDocs = leafReader.searchNearestVectors(
+                        "f",
+                        vector,
+                        10,
+                        AcceptDocs.fromLiveDocs(leafReader.getLiveDocs(), leafReader.maxDoc()),
+                        Integer.MAX_VALUE
+                    );
+                    assertEquals(Math.min(leafReader.maxDoc(), 10), topDocs.scoreDocs.length);
+                }
+
+            }
+        }
+    }
+
+    public void testOneRepeatedVector() throws IOException {
+        int dimensions = random().nextInt(12, 500);
+        float[] repeatedVector = randomVector(dimensions);
+        int numDocs = random().nextInt(100, 10_000);
+        try (Directory dir = newDirectory(); IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
+            for (int i = 0; i < numDocs; i++) {
+                float[] vector = random().nextInt(3) == 0 ? repeatedVector : randomVector(dimensions);
+                Document doc = new Document();
+                doc.add(new KnnFloatVectorField("f", vector, VectorSimilarityFunction.EUCLIDEAN));
+                w.addDocument(doc);
+            }
+            w.commit();
+            if (rarely()) {
+                w.forceMerge(1);
+            }
+            try (IndexReader reader = DirectoryReader.open(w)) {
+                List<LeafReaderContext> subReaders = reader.leaves();
+                for (LeafReaderContext r : subReaders) {
+                    LeafReader leafReader = r.reader();
+                    float[] vector = randomVector(dimensions);
+                    TopDocs topDocs = leafReader.searchNearestVectors(
+                        "f",
+                        vector,
+                        10,
+                        AcceptDocs.fromLiveDocs(leafReader.getLiveDocs(), leafReader.maxDoc()),
+                        Integer.MAX_VALUE
+                    );
+                    assertEquals(Math.min(leafReader.maxDoc(), 10), topDocs.scoreDocs.length);
+                }
+
+            }
+        }
+    }
+
+    // this is a modified version of Lucene's TestSearchWithThreads test case
+    public void testWithThreads() throws Exception {
+        final int numThreads = random().nextInt(2, 5);
+        final int numSearches = atLeast(100);
+        final int numDocs = atLeast(1000);
+        final int dimensions = random().nextInt(12, 500);
+        try (Directory dir = newDirectory(); IndexWriter w = new IndexWriter(dir, newIndexWriterConfig())) {
+            for (int docCount = 0; docCount < numDocs; docCount++) {
+                final Document doc = new Document();
+                doc.add(new KnnFloatVectorField("f", randomVector(dimensions), VectorSimilarityFunction.EUCLIDEAN));
+                w.addDocument(doc);
+            }
+            w.forceMerge(1);
+            try (IndexReader reader = DirectoryReader.open(w)) {
+                final AtomicBoolean failed = new AtomicBoolean();
+                Thread[] threads = new Thread[numThreads];
+                for (int threadID = 0; threadID < numThreads; threadID++) {
+                    threads[threadID] = new Thread(() -> {
+                        try {
+                            long totSearch = 0;
+                            for (; totSearch < numSearches && failed.get() == false; totSearch++) {
+                                float[] vector = randomVector(dimensions);
+                                LeafReader leafReader = getOnlyLeafReader(reader);
+                                leafReader.searchNearestVectors(
+                                    "f",
+                                    vector,
+                                    10,
+                                    AcceptDocs.fromLiveDocs(leafReader.getLiveDocs(), leafReader.maxDoc()),
+                                    Integer.MAX_VALUE
+                                );
+                            }
+                            assertTrue(totSearch > 0);
+                        } catch (Exception exc) {
+                            failed.set(true);
+                            throw new RuntimeException(exc);
+                        }
+                    });
+                    threads[threadID].setDaemon(true);
+                }
+
+                for (Thread t : threads) {
+                    t.start();
+                }
+
+                for (Thread t : threads) {
+                    t.join();
+                }
+            }
+        }
+    }
+}
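
Note: the test class above exercises the new format through Lucene's randomized test framework. The sketch below condenses that flow into a single self-contained harness for readers who want to try ESNextDiskBBQVectorsFormat outside the test suite. It is illustrative only: the constructor arguments (128 vectors per cluster, 4 centroids per parent cluster) mirror testToString, the searchNearestVectors/AcceptDocs call shape is copied verbatim from the tests, and the ByteBuffersDirectory, document count, vector dimension, and class/field names are arbitrary choices made for the sketch, not values taken from this change.

import org.apache.lucene.codecs.Codec;
import org.apache.lucene.document.Document;
import org.apache.lucene.document.KnnFloatVectorField;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.IndexWriter;
import org.apache.lucene.index.IndexWriterConfig;
import org.apache.lucene.index.LeafReader;
import org.apache.lucene.index.VectorSimilarityFunction;
import org.apache.lucene.search.AcceptDocs;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.store.ByteBuffersDirectory;
import org.apache.lucene.store.Directory;
import org.apache.lucene.tests.util.TestUtil;
import org.elasticsearch.index.codec.vectors.diskbbq.next.ESNextDiskBBQVectorsFormat;

import java.util.Random;

public class ESNextDiskBBQUsageSketch {
    public static void main(String[] args) throws Exception {
        Random random = new Random(42);
        int dims = 64; // arbitrary dimension for the sketch
        // wrap the format in a codec, the same way getCodec() does in the test class above
        Codec codec = TestUtil.alwaysKnnVectorsFormat(new ESNextDiskBBQVectorsFormat(128, 4));
        try (Directory dir = new ByteBuffersDirectory();
             IndexWriter writer = new IndexWriter(dir, new IndexWriterConfig().setCodec(codec))) {
            for (int i = 0; i < 2_000; i++) {
                Document doc = new Document();
                doc.add(new KnnFloatVectorField("f", randomVector(random, dims), VectorSimilarityFunction.EUCLIDEAN));
                writer.addDocument(doc);
            }
            // force a merge so the merge-time clustering and centroid-writing path runs
            writer.forceMerge(1);
            try (IndexReader reader = DirectoryReader.open(writer)) {
                LeafReader leaf = reader.leaves().get(0).reader();
                TopDocs topDocs = leaf.searchNearestVectors(
                    "f",
                    randomVector(random, dims),
                    10,
                    AcceptDocs.fromLiveDocs(leaf.getLiveDocs(), leaf.maxDoc()),
                    Integer.MAX_VALUE
                );
                System.out.println("hits: " + topDocs.scoreDocs.length);
            }
        }
    }

    private static float[] randomVector(Random random, int dims) {
        float[] v = new float[dims];
        for (int i = 0; i < dims; i++) {
            v[i] = random.nextFloat();
        }
        return v;
    }
}

The forceMerge(1) call is the interesting step: calculateCentroids and writeCentroids in the writer above run at flush and merge time, and with the small cluster size used here a few thousand vectors should yield multiple centroids. Whether the parent/child centroid layout (writeCentroidsWithParents) is taken depends on the centroidsPerParentCluster threshold checked in writeCentroids, so a larger corpus may be needed to exercise that branch.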