diff --git a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java index 17257dcb73d59..29dbe5613469c 100644 --- a/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java +++ b/qa/vector/src/main/java/org/elasticsearch/test/knn/KnnIndexTester.java @@ -101,7 +101,7 @@ private static String formatIndexPath(CmdLineArgs args) { static Codec createCodec(CmdLineArgs args) { final KnnVectorsFormat format; if (args.indexType() == IndexType.IVF) { - format = new IVFVectorsFormat(args.ivfClusterSize()); + format = new IVFVectorsFormat(args.ivfClusterSize(), IVFVectorsFormat.DEFAULT_CENTROIDS_PER_PARENT_CLUSTER); } else { if (args.quantizeBits() == 1) { if (args.indexType() == IndexType.FLAT) { diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java index 47c6bb99eabb5..a00ee5c0e0205 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java @@ -29,8 +29,6 @@ import static org.apache.lucene.codecs.lucene102.Lucene102BinaryQuantizedVectorsFormat.QUERY_BITS; import static org.apache.lucene.index.VectorSimilarityFunction.COSINE; -import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN; -import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT; import static org.elasticsearch.index.codec.vectors.BQSpaceUtils.transposeHalfByte; import static org.elasticsearch.index.codec.vectors.BQVectorUtils.discretize; import static org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer.DEFAULT_LAMBDA; @@ -41,7 +39,9 @@ * brute force and then scores the top ones using the posting list. */ public class DefaultIVFVectorsReader extends IVFVectorsReader implements OffHeapStats { - private static final float FOUR_BIT_SCALE = 1f / ((1 << 4) - 1); + + // The percentage of centroids that are scored to keep recall + public static final double CENTROID_SAMPLING_PERCENTAGE = 0.2; public DefaultIVFVectorsReader(SegmentReadState state, FlatVectorsReader rawVectorsReader) throws IOException { super(state, rawVectorsReader); @@ -54,8 +54,12 @@ CentroidIterator getCentroidIterator(FieldInfo fieldInfo, int numCentroids, Inde final float globalCentroidDp = fieldEntry.globalCentroidDp(); final OptimizedScalarQuantizer scalarQuantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); final int[] scratch = new int[targetQuery.length]; + float[] targetQueryCopy = ArrayUtil.copyArray(targetQuery); + if (fieldInfo.getVectorSimilarityFunction() == COSINE) { + VectorUtil.l2normalize(targetQueryCopy); + } final OptimizedScalarQuantizer.QuantizationResult queryParams = scalarQuantizer.scalarQuantize( - ArrayUtil.copyArray(targetQuery), + targetQueryCopy, scratch, (byte) 4, fieldEntry.globalCentroid() @@ -65,67 +69,227 @@ CentroidIterator getCentroidIterator(FieldInfo fieldInfo, int numCentroids, Inde quantized[i] = (byte) scratch[i]; } final ES91Int4VectorsScorer scorer = ESVectorUtil.getES91Int4VectorsScorer(centroids, fieldInfo.getVectorDimension()); - NeighborQueue queue = new NeighborQueue(fieldEntry.numCentroids(), true); centroids.seek(0L); - final float[] centroidCorrectiveValues = new float[3]; - for (int i = 0; i < numCentroids; i++) { - final float qcDist = scorer.int4DotProduct(quantized); - centroids.readFloats(centroidCorrectiveValues, 0, 3); - final int quantizedCentroidComponentSum = Short.toUnsignedInt(centroids.readShort()); - float score = int4QuantizedScore( - qcDist, + int numParents = centroids.readVInt(); + if (numParents > 0) { + return getCentroidIteratorWithParents( + fieldInfo, + centroids, + numParents, + numCentroids, + scorer, + quantized, queryParams, - fieldInfo.getVectorDimension(), - centroidCorrectiveValues, - quantizedCentroidComponentSum, - globalCentroidDp, - fieldInfo.getVectorSimilarityFunction() + globalCentroidDp ); - queue.add(i, score); } - final long offset = centroids.getFilePointer(); + return getCentroidIteratorNoParent(fieldInfo, centroids, numCentroids, scorer, quantized, queryParams, globalCentroidDp); + } + + private static CentroidIterator getCentroidIteratorNoParent( + FieldInfo fieldInfo, + IndexInput centroids, + int numCentroids, + ES91Int4VectorsScorer scorer, + byte[] quantizeQuery, + OptimizedScalarQuantizer.QuantizationResult queryParams, + float globalCentroidDp + ) throws IOException { + final NeighborQueue neighborQueue = new NeighborQueue(numCentroids, true); + score( + neighborQueue, + numCentroids, + 0, + scorer, + quantizeQuery, + queryParams, + globalCentroidDp, + fieldInfo.getVectorSimilarityFunction(), + new float[ES91Int4VectorsScorer.BULK_SIZE] + ); + long offset = centroids.getFilePointer(); return new CentroidIterator() { @Override public boolean hasNext() { - return queue.size() > 0; + return neighborQueue.size() > 0; } @Override public long nextPostingListOffset() throws IOException { - int centroidOrdinal = queue.pop(); + int centroidOrdinal = neighborQueue.pop(); centroids.seek(offset + (long) Long.BYTES * centroidOrdinal); return centroids.readLong(); } }; } - // TODO can we do this in off-heap blocks? - private float int4QuantizedScore( - float qcDist, + private static CentroidIterator getCentroidIteratorWithParents( + FieldInfo fieldInfo, + IndexInput centroids, + int numParents, + int numCentroids, + ES91Int4VectorsScorer scorer, + byte[] quantizeQuery, + OptimizedScalarQuantizer.QuantizationResult queryParams, + float globalCentroidDp + ) throws IOException { + // build the three queues we are going to use + final NeighborQueue parentsQueue = new NeighborQueue(numParents, true); + final int maxChildrenSize = centroids.readVInt(); + final NeighborQueue currentParentQueue = new NeighborQueue(maxChildrenSize, true); + final int bufferSize = (int) Math.max(numCentroids * CENTROID_SAMPLING_PERCENTAGE, 1); + final NeighborQueue neighborQueue = new NeighborQueue(bufferSize, true); + // score the parents + final float[] scores = new float[ES91Int4VectorsScorer.BULK_SIZE]; + score( + parentsQueue, + numParents, + 0, + scorer, + quantizeQuery, + queryParams, + globalCentroidDp, + fieldInfo.getVectorSimilarityFunction(), + scores + ); + final long centroidQuantizeSize = fieldInfo.getVectorDimension() + 3 * Float.BYTES + Short.BYTES; + final long offset = centroids.getFilePointer(); + final long childrenOffset = offset + (long) Long.BYTES * numParents; + // populate the children's queue by reading parents one by one + while (parentsQueue.size() > 0 && neighborQueue.size() < bufferSize) { + final int pop = parentsQueue.pop(); + populateOneChildrenGroup( + currentParentQueue, + centroids, + offset + 2L * Integer.BYTES * pop, + childrenOffset, + centroidQuantizeSize, + fieldInfo, + scorer, + quantizeQuery, + queryParams, + globalCentroidDp, + scores + ); + while (currentParentQueue.size() > 0 && neighborQueue.size() < bufferSize) { + final float score = currentParentQueue.topScore(); + final int children = currentParentQueue.pop(); + neighborQueue.add(children, score); + } + } + final long childrenFileOffsets = childrenOffset + centroidQuantizeSize * numCentroids; + return new CentroidIterator() { + @Override + public boolean hasNext() { + return neighborQueue.size() > 0; + } + + @Override + public long nextPostingListOffset() throws IOException { + int centroidOrdinal = neighborQueue.pop(); + updateQueue(); // add one children if available so the queue remains fully populated + centroids.seek(childrenFileOffsets + (long) Long.BYTES * centroidOrdinal); + return centroids.readLong(); + } + + private void updateQueue() throws IOException { + if (currentParentQueue.size() > 0) { + // add a children from the current parent queue + float score = currentParentQueue.topScore(); + int children = currentParentQueue.pop(); + neighborQueue.add(children, score); + } else if (parentsQueue.size() > 0) { + // add a new parent from the current parent queue + int pop = parentsQueue.pop(); + populateOneChildrenGroup( + currentParentQueue, + centroids, + offset + 2L * Integer.BYTES * pop, + childrenOffset, + centroidQuantizeSize, + fieldInfo, + scorer, + quantizeQuery, + queryParams, + globalCentroidDp, + scores + ); + updateQueue(); + } + } + }; + } + + private static void populateOneChildrenGroup( + NeighborQueue neighborQueue, + IndexInput centroids, + long parentOffset, + long childrenOffset, + long centroidQuantizeSize, + FieldInfo fieldInfo, + ES91Int4VectorsScorer scorer, + byte[] quantizeQuery, + OptimizedScalarQuantizer.QuantizationResult queryParams, + float globalCentroidDp, + float[] scores + ) throws IOException { + centroids.seek(parentOffset); + int childrenOrdinal = centroids.readInt(); + int numChildren = centroids.readInt(); + centroids.seek(childrenOffset + centroidQuantizeSize * childrenOrdinal); + score( + neighborQueue, + numChildren, + childrenOrdinal, + scorer, + quantizeQuery, + queryParams, + globalCentroidDp, + fieldInfo.getVectorSimilarityFunction(), + scores + ); + } + + private static void score( + NeighborQueue neighborQueue, + int size, + int scoresOffset, + ES91Int4VectorsScorer scorer, + byte[] quantizeQuery, OptimizedScalarQuantizer.QuantizationResult queryCorrections, - int dims, - float[] targetCorrections, - int targetComponentSum, float centroidDp, - VectorSimilarityFunction similarityFunction - ) { - float ax = targetCorrections[0]; - float lx = (targetCorrections[1] - ax) * FOUR_BIT_SCALE; - float ay = queryCorrections.lowerInterval(); - float ly = (queryCorrections.upperInterval() - ay) * FOUR_BIT_SCALE; - float y1 = queryCorrections.quantizedComponentSum(); - float score = ax * ay * dims + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * qcDist; - if (similarityFunction == EUCLIDEAN) { - score = queryCorrections.additionalCorrection() + targetCorrections[2] - 2 * score; - return Math.max(1 / (1f + score), 0); - } else { - // For cosine and max inner product, we need to apply the additional correction, which is - // assumed to be the non-centered dot-product between the vector and the centroid - score += queryCorrections.additionalCorrection() + targetCorrections[2] - centroidDp; - if (similarityFunction == MAXIMUM_INNER_PRODUCT) { - return VectorUtil.scaleMaxInnerProductScore(score); + VectorSimilarityFunction similarityFunction, + float[] scores + ) throws IOException { + int limit = size - ES91Int4VectorsScorer.BULK_SIZE + 1; + int i = 0; + for (; i < limit; i += ES91Int4VectorsScorer.BULK_SIZE) { + scorer.scoreBulk( + quantizeQuery, + queryCorrections.lowerInterval(), + queryCorrections.upperInterval(), + queryCorrections.quantizedComponentSum(), + queryCorrections.additionalCorrection(), + similarityFunction, + centroidDp, + scores + ); + for (int j = 0; j < ES91Int4VectorsScorer.BULK_SIZE; j++) { + neighborQueue.add(scoresOffset + i + j, scores[j]); } - return Math.max((1f + score) / 2f, 0); + } + + for (; i < size; i++) { + float score = scorer.score( + quantizeQuery, + queryCorrections.lowerInterval(), + queryCorrections.upperInterval(), + queryCorrections.quantizedComponentSum(), + queryCorrections.additionalCorrection(), + similarityFunction, + centroidDp + ); + neighborQueue.add(scoresOffset + i, score); } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java index f47ecc549831a..58f09cf70d4bd 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java @@ -26,11 +26,14 @@ import org.elasticsearch.index.codec.vectors.cluster.KMeansResult; import org.elasticsearch.logging.LogManager; import org.elasticsearch.logging.Logger; +import org.elasticsearch.simdvec.ES91Int4VectorsScorer; import org.elasticsearch.simdvec.ES91OSQVectorsScorer; import java.io.IOException; +import java.io.UncheckedIOException; import java.nio.ByteBuffer; import java.nio.ByteOrder; +import java.util.AbstractList; import java.util.Arrays; /** @@ -42,10 +45,17 @@ public class DefaultIVFVectorsWriter extends IVFVectorsWriter { private static final Logger logger = LogManager.getLogger(DefaultIVFVectorsWriter.class); private final int vectorPerCluster; + private final int centroidsPerParentCluster; - public DefaultIVFVectorsWriter(SegmentWriteState state, FlatVectorsWriter rawVectorDelegate, int vectorPerCluster) throws IOException { + public DefaultIVFVectorsWriter( + SegmentWriteState state, + FlatVectorsWriter rawVectorDelegate, + int vectorPerCluster, + int centroidsPerParentCluster + ) throws IOException { super(state, rawVectorDelegate); this.vectorPerCluster = vectorPerCluster; + this.centroidsPerParentCluster = centroidsPerParentCluster; } @Override @@ -288,34 +298,136 @@ void writeCentroids( LongValues offsets, IndexOutput centroidOutput ) throws IOException { - - final OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); - int[] quantizedScratch = new int[fieldInfo.getVectorDimension()]; - float[] centroidScratch = new float[fieldInfo.getVectorDimension()]; - final byte[] quantized = new byte[fieldInfo.getVectorDimension()]; // TODO do we want to store these distances as well for future use? // TODO: sort centroids by global centroid (was doing so previously here) // TODO: sorting tanks recall possibly because centroids ordinals no longer are aligned - for (int i = 0; i < centroidSupplier.size(); i++) { - float[] centroid = centroidSupplier.centroid(i); - System.arraycopy(centroid, 0, centroidScratch, 0, centroid.length); - OptimizedScalarQuantizer.QuantizationResult result = osq.scalarQuantize( - centroidScratch, - quantizedScratch, - (byte) 4, - globalCentroid - ); - for (int j = 0; j < quantizedScratch.length; j++) { - quantized[j] = (byte) quantizedScratch[j]; + if (centroidSupplier.size() > centroidsPerParentCluster * centroidsPerParentCluster) { + writeCentroidsWithParents(fieldInfo, centroidSupplier, globalCentroid, offsets, centroidOutput); + } else { + writeCentroidsWithoutParents(fieldInfo, centroidSupplier, globalCentroid, offsets, centroidOutput); + } + } + + private void writeCentroidsWithParents( + FieldInfo fieldInfo, + CentroidSupplier centroidSupplier, + float[] globalCentroid, + LongValues offsets, + IndexOutput centroidOutput + ) throws IOException { + DiskBBQBulkWriter.FourBitDiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.FourBitDiskBBQBulkWriter( + ES91Int4VectorsScorer.BULK_SIZE, + centroidOutput + ); + final OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); + final CentroidGroups centroidGroups = buildCentroidGroups(fieldInfo, centroidSupplier); + centroidOutput.writeVInt(centroidGroups.centroids.length); + centroidOutput.writeVInt(centroidGroups.maxVectorsPerCentroidLength); + QuantizedCentroids parentQuantizeCentroid = new QuantizedCentroids( + new OnHeapCentroidSupplier(centroidGroups.centroids), + fieldInfo.getVectorDimension(), + osq, + globalCentroid + ); + bulkWriter.writeVectors(parentQuantizeCentroid); + int offset = 0; + for (int i = 0; i < centroidGroups.centroids().length; i++) { + centroidOutput.writeInt(offset); + centroidOutput.writeInt(centroidGroups.vectors()[i].length); + offset += centroidGroups.vectors()[i].length; + } + + QuantizedCentroids childrenQuantizeCentroid = new QuantizedCentroids( + centroidSupplier, + fieldInfo.getVectorDimension(), + osq, + globalCentroid + ); + for (int i = 0; i < centroidGroups.centroids().length; i++) { + final int[] centroidAssignments = centroidGroups.vectors()[i]; + childrenQuantizeCentroid.reset(idx -> centroidAssignments[idx], centroidAssignments.length); + bulkWriter.writeVectors(childrenQuantizeCentroid); + } + // write the centroid offsets at the end of the file + for (int i = 0; i < centroidGroups.centroids().length; i++) { + final int[] centroidAssignments = centroidGroups.vectors()[i]; + for (int assignment : centroidAssignments) { + centroidOutput.writeLong(offsets.get(assignment)); } - writeQuantizedValue(centroidOutput, quantized, result); } + } + + private void writeCentroidsWithoutParents( + FieldInfo fieldInfo, + CentroidSupplier centroidSupplier, + float[] globalCentroid, + LongValues offsets, + IndexOutput centroidOutput + ) throws IOException { + centroidOutput.writeVInt(0); + DiskBBQBulkWriter.FourBitDiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.FourBitDiskBBQBulkWriter( + ES91Int4VectorsScorer.BULK_SIZE, + centroidOutput + ); + final OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); + QuantizedCentroids quantizedCentroids = new QuantizedCentroids( + centroidSupplier, + fieldInfo.getVectorDimension(), + osq, + globalCentroid + ); + bulkWriter.writeVectors(quantizedCentroids); // write the centroid offsets at the end of the file for (int i = 0; i < centroidSupplier.size(); i++) { centroidOutput.writeLong(offsets.get(i)); } } + private record CentroidGroups(float[][] centroids, int[][] vectors, int maxVectorsPerCentroidLength) {} + + private CentroidGroups buildCentroidGroups(FieldInfo fieldInfo, CentroidSupplier centroidSupplier) throws IOException { + final FloatVectorValues floatVectorValues = FloatVectorValues.fromFloats(new AbstractList<>() { + @Override + public float[] get(int index) { + try { + return centroidSupplier.centroid(index); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + @Override + public int size() { + return centroidSupplier.size(); + } + }, fieldInfo.getVectorDimension()); + // we use the HierarchicalKMeans to partition the space of all vectors across merging segments + // this are small numbers so we run it wih all the centroids. + final KMeansResult kMeansResult = new HierarchicalKMeans( + fieldInfo.getVectorDimension(), + 6, + floatVectorValues.size(), + floatVectorValues.size(), + -1 // disable SOAR assignments + ).cluster(floatVectorValues, centroidsPerParentCluster); + final int[] centroidVectorCount = new int[kMeansResult.centroids().length]; + for (int i = 0; i < kMeansResult.assignments().length; i++) { + centroidVectorCount[kMeansResult.assignments()[i]]++; + } + final int[][] vectorsPerCentroid = new int[kMeansResult.centroids().length][]; + int maxVectorsPerCentroidLength = 0; + for (int i = 0; i < kMeansResult.centroids().length; i++) { + vectorsPerCentroid[i] = new int[centroidVectorCount[i]]; + maxVectorsPerCentroidLength = Math.max(maxVectorsPerCentroidLength, centroidVectorCount[i]); + } + Arrays.fill(centroidVectorCount, 0); + for (int i = 0; i < kMeansResult.assignments().length; i++) { + final int c = kMeansResult.assignments()[i]; + vectorsPerCentroid[c][centroidVectorCount[c]++] = i; + } + return new CentroidGroups(kMeansResult.centroids(), vectorsPerCentroid, maxVectorsPerCentroidLength); + } + /** * Calculate the centroids for the given field. * We use the {@link HierarchicalKMeans} algorithm to partition the space of all vectors across merging segments @@ -415,6 +527,63 @@ interface IntToBooleanFunction { boolean apply(int ord); } + static class QuantizedCentroids implements QuantizedVectorValues { + private final CentroidSupplier supplier; + private final OptimizedScalarQuantizer quantizer; + private final byte[] quantizedVector; + private final int[] quantizedVectorScratch; + private final float[] floatVectorScratch; + private OptimizedScalarQuantizer.QuantizationResult corrections; + private final float[] centroid; + private int currOrd = -1; + private IntToIntFunction ordTransformer = i -> i; + int size; + + QuantizedCentroids(CentroidSupplier supplier, int dimension, OptimizedScalarQuantizer quantizer, float[] centroid) { + this.supplier = supplier; + this.quantizer = quantizer; + this.quantizedVector = new byte[dimension]; + this.floatVectorScratch = new float[dimension]; + this.quantizedVectorScratch = new int[dimension]; + this.centroid = centroid; + size = supplier.size(); + } + + @Override + public int count() { + return size; + } + + void reset(IntToIntFunction ordTransformer, int size) { + this.ordTransformer = ordTransformer; + this.currOrd = -1; + this.size = size; + this.corrections = null; + } + + @Override + public byte[] next() throws IOException { + if (currOrd >= count() - 1) { + throw new IllegalStateException("No more vectors to read, current ord: " + currOrd + ", count: " + count()); + } + currOrd++; + float[] vector = supplier.centroid(ordTransformer.apply(currOrd)); + // Its possible that the vectors are on-heap and we cannot mutate them as we may quantize twice + // due to overspill, so we copy the vector to a scratch array + System.arraycopy(vector, 0, floatVectorScratch, 0, vector.length); + corrections = quantizer.scalarQuantize(floatVectorScratch, quantizedVectorScratch, (byte) 4, centroid); + for (int i = 0; i < quantizedVectorScratch.length; i++) { + quantizedVector[i] = (byte) quantizedVectorScratch[i]; + } + return quantizedVector; + } + + @Override + public OptimizedScalarQuantizer.QuantizationResult getCorrections() throws IOException { + return corrections; + } + } + static class OnHeapQuantizedVectors implements QuantizedVectorValues { private final FloatVectorValues vectorValues; private final OptimizedScalarQuantizer quantizer; diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DiskBBQBulkWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DiskBBQBulkWriter.java index 662878270ea09..9da77fb77661a 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DiskBBQBulkWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DiskBBQBulkWriter.java @@ -84,4 +84,34 @@ void writeVectors(DefaultIVFVectorsWriter.QuantizedVectorValues qvv) throws IOEx } } } + + static class FourBitDiskBBQBulkWriter extends DiskBBQBulkWriter { + private final OptimizedScalarQuantizer.QuantizationResult[] corrections; + + FourBitDiskBBQBulkWriter(int bulkSize, IndexOutput out) { + super(bulkSize, out); + this.corrections = new OptimizedScalarQuantizer.QuantizationResult[bulkSize]; + } + + @Override + void writeVectors(DefaultIVFVectorsWriter.QuantizedVectorValues qvv) throws IOException { + int limit = qvv.count() - bulkSize + 1; + int i = 0; + for (; i < limit; i += bulkSize) { + for (int j = 0; j < bulkSize; j++) { + byte[] qv = qvv.next(); + corrections[j] = qvv.getCorrections(); + out.writeBytes(qv, qv.length); + } + writeCorrections(corrections, out); + } + // write tail + for (; i < qvv.count(); ++i) { + byte[] qv = qvv.next(); + OptimizedScalarQuantizer.QuantizationResult correction = qvv.getCorrections(); + out.writeBytes(qv, qv.length); + writeCorrection(correction, out); + } + } + } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormat.java index 7a18558703423..aa8921cee24c4 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormat.java @@ -65,10 +65,14 @@ public class IVFVectorsFormat extends KnnVectorsFormat { public static final int DEFAULT_VECTORS_PER_CLUSTER = 384; public static final int MIN_VECTORS_PER_CLUSTER = 64; public static final int MAX_VECTORS_PER_CLUSTER = 1 << 16; // 65536 + public static final int DEFAULT_CENTROIDS_PER_PARENT_CLUSTER = 16; + public static final int MIN_CENTROIDS_PER_PARENT_CLUSTER = 2; + public static final int MAX_CENTROIDS_PER_PARENT_CLUSTER = 1 << 8; // 256 private final int vectorPerCluster; + private final int centroidsPerParentCluster; - public IVFVectorsFormat(int vectorPerCluster) { + public IVFVectorsFormat(int vectorPerCluster, int centroidsPerParentCluster) { super(NAME); if (vectorPerCluster < MIN_VECTORS_PER_CLUSTER || vectorPerCluster > MAX_VECTORS_PER_CLUSTER) { throw new IllegalArgumentException( @@ -80,17 +84,28 @@ public IVFVectorsFormat(int vectorPerCluster) { + vectorPerCluster ); } + if (centroidsPerParentCluster < MIN_CENTROIDS_PER_PARENT_CLUSTER || centroidsPerParentCluster > MAX_CENTROIDS_PER_PARENT_CLUSTER) { + throw new IllegalArgumentException( + "centroidsPerParentCluster must be between " + + MIN_CENTROIDS_PER_PARENT_CLUSTER + + " and " + + MAX_CENTROIDS_PER_PARENT_CLUSTER + + ", got: " + + centroidsPerParentCluster + ); + } this.vectorPerCluster = vectorPerCluster; + this.centroidsPerParentCluster = centroidsPerParentCluster; } /** Constructs a format using the given graph construction parameters and scalar quantization. */ public IVFVectorsFormat() { - this(DEFAULT_VECTORS_PER_CLUSTER); + this(DEFAULT_VECTORS_PER_CLUSTER, DEFAULT_CENTROIDS_PER_PARENT_CLUSTER); } @Override public KnnVectorsWriter fieldsWriter(SegmentWriteState state) throws IOException { - return new DefaultIVFVectorsWriter(state, rawVectorFormat.fieldsWriter(state), vectorPerCluster); + return new DefaultIVFVectorsWriter(state, rawVectorFormat.fieldsWriter(state), vectorPerCluster, centroidsPerParentCluster); } @Override diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java index fc13a4b9faa1a..22a78cfbae835 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/HierarchicalKMeans.java @@ -34,7 +34,7 @@ public HierarchicalKMeans(int dimension) { this(dimension, MAX_ITERATIONS_DEFAULT, SAMPLES_PER_CLUSTER_DEFAULT, MAXK, DEFAULT_SOAR_LAMBDA); } - HierarchicalKMeans(int dimension, int maxIterations, int samplesPerCluster, int clustersPerNeighborhood, float soarLambda) { + public HierarchicalKMeans(int dimension, int maxIterations, int samplesPerCluster, int clustersPerNeighborhood, float soarLambda) { this.dimension = dimension; this.maxIterations = maxIterations; this.samplesPerCluster = samplesPerCluster; @@ -79,7 +79,7 @@ public KMeansResult cluster(FloatVectorValues vectors, int targetSize) throws IO if (kMeansIntermediate.centroids().length > 1 && kMeansIntermediate.centroids().length < vectors.size()) { int localSampleSize = Math.min(kMeansIntermediate.centroids().length * samplesPerCluster / 2, vectors.size()); KMeansLocal kMeansLocal = new KMeansLocal(localSampleSize, maxIterations); - kMeansLocal.cluster(vectors, kMeansIntermediate, clustersPerNeighborhood, DEFAULT_SOAR_LAMBDA); + kMeansLocal.cluster(vectors, kMeansIntermediate, clustersPerNeighborhood, soarLambda); } return kMeansIntermediate; diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java index a3be558128577..a1e480fb73266 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java @@ -307,7 +307,7 @@ private void doCluster(FloatVectorValues vectors, KMeansIntermediate kMeansInter neighborhoods = computeNeighborhoods(centroids, clustersPerNeighborhood); } cluster(vectors, kMeansIntermediate, neighborhoods); - if (neighborAware) { + if (neighborAware && soarLambda >= 0) { assert kMeansIntermediate.soarAssignments().length == 0; kMeansIntermediate.setSoarAssignments(new int[vectors.size()]); assignSpilled(vectors, kMeansIntermediate, neighborhoods, soarLambda); diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 9019edc435eaf..c9c14d027ebfd 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -2312,7 +2312,7 @@ static class BBQIVFIndexOptions extends QuantizedIndexOptions { @Override KnnVectorsFormat getVectorsFormat(ElementType elementType) { assert elementType == ElementType.FLOAT; - return new IVFVectorsFormat(clusterSize); + return new IVFVectorsFormat(clusterSize, IVFVectorsFormat.DEFAULT_CENTROIDS_PER_PARENT_CLUSTER); } @Override diff --git a/server/src/test/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormatTests.java b/server/src/test/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormatTests.java index 8499aa9a17320..2c0d2f3fc7449 100644 --- a/server/src/test/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormatTests.java +++ b/server/src/test/java/org/elasticsearch/index/codec/vectors/IVFVectorsFormatTests.java @@ -37,7 +37,9 @@ import java.util.concurrent.atomic.AtomicBoolean; import static java.lang.String.format; +import static org.elasticsearch.index.codec.vectors.IVFVectorsFormat.MAX_CENTROIDS_PER_PARENT_CLUSTER; import static org.elasticsearch.index.codec.vectors.IVFVectorsFormat.MAX_VECTORS_PER_CLUSTER; +import static org.elasticsearch.index.codec.vectors.IVFVectorsFormat.MIN_CENTROIDS_PER_PARENT_CLUSTER; import static org.elasticsearch.index.codec.vectors.IVFVectorsFormat.MIN_VECTORS_PER_CLUSTER; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.oneOf; @@ -53,7 +55,18 @@ public class IVFVectorsFormatTests extends BaseKnnVectorsFormatTestCase { @Before @Override public void setUp() throws Exception { - format = new IVFVectorsFormat(random().nextInt(MIN_VECTORS_PER_CLUSTER, IVFVectorsFormat.MAX_VECTORS_PER_CLUSTER)); + if (rarely()) { + format = new IVFVectorsFormat( + random().nextInt(2 * MIN_VECTORS_PER_CLUSTER, IVFVectorsFormat.MAX_VECTORS_PER_CLUSTER), + random().nextInt(8, IVFVectorsFormat.MAX_CENTROIDS_PER_PARENT_CLUSTER) + ); + } else { + // run with low numbers to force many clusters with parents + format = new IVFVectorsFormat( + random().nextInt(MIN_VECTORS_PER_CLUSTER, 2 * MIN_VECTORS_PER_CLUSTER), + random().nextInt(MIN_CENTROIDS_PER_PARENT_CLUSTER, 8) + ); + } super.setUp(); } @@ -93,7 +106,7 @@ public void testToString() { FilterCodec customCodec = new FilterCodec("foo", Codec.getDefault()) { @Override public KnnVectorsFormat knnVectorsFormat() { - return new IVFVectorsFormat(128); + return new IVFVectorsFormat(128, 4); } }; String expectedPattern = "IVFVectorsFormat(vectorPerCluster=128)"; @@ -104,8 +117,10 @@ public KnnVectorsFormat knnVectorsFormat() { } public void testLimits() { - expectThrows(IllegalArgumentException.class, () -> new IVFVectorsFormat(MIN_VECTORS_PER_CLUSTER - 1)); - expectThrows(IllegalArgumentException.class, () -> new IVFVectorsFormat(MAX_VECTORS_PER_CLUSTER + 1)); + expectThrows(IllegalArgumentException.class, () -> new IVFVectorsFormat(MIN_VECTORS_PER_CLUSTER - 1, 16)); + expectThrows(IllegalArgumentException.class, () -> new IVFVectorsFormat(MAX_VECTORS_PER_CLUSTER + 1, 16)); + expectThrows(IllegalArgumentException.class, () -> new IVFVectorsFormat(128, MIN_CENTROIDS_PER_PARENT_CLUSTER - 1)); + expectThrows(IllegalArgumentException.class, () -> new IVFVectorsFormat(128, MAX_CENTROIDS_PER_PARENT_CLUSTER + 1)); } public void testSimpleOffHeapSize() throws IOException { diff --git a/server/src/test/java/org/elasticsearch/search/vectors/AbstractDiversifyingChildrenIVFKnnVectorQueryTestCase.java b/server/src/test/java/org/elasticsearch/search/vectors/AbstractDiversifyingChildrenIVFKnnVectorQueryTestCase.java index f73d1e5a31999..ce08d631399d6 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/AbstractDiversifyingChildrenIVFKnnVectorQueryTestCase.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/AbstractDiversifyingChildrenIVFKnnVectorQueryTestCase.java @@ -93,7 +93,10 @@ static Document makeParent(int[] children) { @Before public void setUp() throws Exception { super.setUp(); - format = new IVFVectorsFormat(random().nextInt(IVFVectorsFormat.MIN_VECTORS_PER_CLUSTER, IVFVectorsFormat.MAX_VECTORS_PER_CLUSTER)); + format = new IVFVectorsFormat( + random().nextInt(IVFVectorsFormat.MIN_VECTORS_PER_CLUSTER, IVFVectorsFormat.MAX_VECTORS_PER_CLUSTER), + random().nextInt(IVFVectorsFormat.MIN_CENTROIDS_PER_PARENT_CLUSTER, IVFVectorsFormat.MAX_CENTROIDS_PER_PARENT_CLUSTER) + ); } abstract Query getDiversifyingChildrenKnnQuery( diff --git a/server/src/test/java/org/elasticsearch/search/vectors/AbstractIVFKnnVectorQueryTestCase.java b/server/src/test/java/org/elasticsearch/search/vectors/AbstractIVFKnnVectorQueryTestCase.java index 70fd11c97a8c4..e602f9098b602 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/AbstractIVFKnnVectorQueryTestCase.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/AbstractIVFKnnVectorQueryTestCase.java @@ -98,7 +98,7 @@ abstract class AbstractIVFKnnVectorQueryTestCase extends LuceneTestCase { @Before public void setUp() throws Exception { super.setUp(); - format = new IVFVectorsFormat(128); + format = new IVFVectorsFormat(128, 4); } abstract AbstractIVFKnnVectorQuery getKnnVectorQuery(String field, float[] query, int k, Query queryFilter, int nProbe);