diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/CentroidAssignments.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/CentroidAssignments.java
index b5e1276d0747a..e92ece41077a6 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/CentroidAssignments.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/CentroidAssignments.java
@@ -9,10 +9,11 @@
 
 package org.elasticsearch.index.codec.vectors;
 
-record CentroidAssignments(int numCentroids, float[][] centroids, int[][] assignmentsByCluster) {
+record CentroidAssignments(int numCentroids, float[][] centroids, int[] assignments, int[] overspillAssignments) {
 
-    CentroidAssignments(float[][] centroids, int[][] assignmentsByCluster) {
-        this(centroids.length, centroids, assignmentsByCluster);
-        assert centroids.length == assignmentsByCluster.length;
+    CentroidAssignments(float[][] centroids, int[] assignments, int[] overspillAssignments) {
+        this(centroids.length, centroids, assignments, overspillAssignments);
+        assert assignments.length == overspillAssignments.length || overspillAssignments.length == 0
+            : "assignments and overspillAssignments must have the same length";
     }
 }
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java
index 84abb1bea543f..e94b728f934e4 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java
@@ -14,9 +14,11 @@
 import org.apache.lucene.index.FloatVectorValues;
 import org.apache.lucene.index.MergeState;
 import org.apache.lucene.index.SegmentWriteState;
+import org.apache.lucene.store.IOContext;
 import org.apache.lucene.store.IndexInput;
 import org.apache.lucene.store.IndexOutput;
 import org.apache.lucene.util.VectorUtil;
+import org.apache.lucene.util.hnsw.IntToIntFunction;
 import org.elasticsearch.index.codec.vectors.cluster.HierarchicalKMeans;
 import org.elasticsearch.index.codec.vectors.cluster.KMeansResult;
 import org.elasticsearch.logging.LogManager;
@@ -49,32 +51,58 @@ long[] buildAndWritePostingsLists(
         CentroidSupplier centroidSupplier,
         FloatVectorValues floatVectorValues,
         IndexOutput postingsOutput,
-        int[][] assignmentsByCluster
+        int[] assignments,
+        int[] overspillAssignments
     ) throws IOException {
+        int[] centroidVectorCount = new int[centroidSupplier.size()];
+        for (int i = 0; i < assignments.length; i++) {
+            centroidVectorCount[assignments[i]]++;
+            // if soar assignments are present, count them as well
+            if (overspillAssignments.length > i && overspillAssignments[i] != -1) {
+                centroidVectorCount[overspillAssignments[i]]++;
+            }
+        }
+
+        int[][] assignmentsByCluster = new int[centroidSupplier.size()][];
+        for (int c = 0; c < centroidSupplier.size(); c++) {
+            assignmentsByCluster[c] = new int[centroidVectorCount[c]];
+        }
+        Arrays.fill(centroidVectorCount, 0);
+
+        for (int i = 0; i < assignments.length; i++) {
+            int c = assignments[i];
+            assignmentsByCluster[c][centroidVectorCount[c]++] = i;
+            // if soar assignments are present, add them to the cluster as well
+            if (overspillAssignments.length > i) {
+                int s = overspillAssignments[i];
+                if (s != -1) {
+                    assignmentsByCluster[s][centroidVectorCount[s]++] = i;
+                }
+            }
+        }
         // write the posting lists
         final long[] offsets = new long[centroidSupplier.size()];
-        OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
         DocIdsWriter docIdsWriter = new DocIdsWriter();
-        DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(
-            ES91OSQVectorsScorer.BULK_SIZE,
-            quantizer,
+        DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(ES91OSQVectorsScorer.BULK_SIZE, postingsOutput);
+        OnHeapQuantizedVectors onHeapQuantizedVectors = new OnHeapQuantizedVectors(
             floatVectorValues,
-            postingsOutput
+            fieldInfo.getVectorDimension(),
+            new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction())
         );
         for (int c = 0; c < centroidSupplier.size(); c++) {
            float[] centroid = centroidSupplier.centroid(c);
-            // TODO: add back in sorting vectors by distance to centroid
            int[] cluster = assignmentsByCluster[c];
            // TODO align???
            offsets[c] = postingsOutput.getFilePointer();
            int size = cluster.length;
            postingsOutput.writeVInt(size);
            postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
+            onHeapQuantizedVectors.reset(centroid, size, ord -> cluster[ord]);
            // TODO we might want to consider putting the docIds in a separate file
            // to aid with only having to fetch vectors from slower storage when they are required
            // keeping them in the same file indicates we pull the entire file into cache
            docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster[j]), size, postingsOutput);
-            bulkWriter.writeOrds(j -> cluster[j], cluster.length, centroid);
+            bulkWriter.writeVectors(onHeapQuantizedVectors);
         }
 
         if (logger.isDebugEnabled()) {
@@ -84,6 +112,124 @@ long[] buildAndWritePostingsLists(
         return offsets;
     }
 
+    @Override
+    long[] buildAndWritePostingsLists(
+        FieldInfo fieldInfo,
+        CentroidSupplier centroidSupplier,
+        FloatVectorValues floatVectorValues,
+        IndexOutput postingsOutput,
+        MergeState mergeState,
+        int[] assignments,
+        int[] overspillAssignments
+    ) throws IOException {
+        // first, quantize all the vectors into a temporary file
+        String quantizedVectorsTempName = null;
+        IndexOutput quantizedVectorsTemp = null;
+        boolean success = false;
+        try {
+            quantizedVectorsTemp = mergeState.segmentInfo.dir.createTempOutput(mergeState.segmentInfo.name, "qvec_", IOContext.DEFAULT);
+            quantizedVectorsTempName = quantizedVectorsTemp.getName();
+            OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
+            int[] quantized = new int[fieldInfo.getVectorDimension()];
+            byte[] binary = new byte[BQVectorUtils.discretize(fieldInfo.getVectorDimension(), 64) / 8];
+            float[] overspillScratch = new float[fieldInfo.getVectorDimension()];
+            for (int i = 0; i < assignments.length; i++) {
+                int c = assignments[i];
+                float[] centroid = centroidSupplier.centroid(c);
+                float[] vector = floatVectorValues.vectorValue(i);
+                boolean overspill = overspillAssignments.length > i && overspillAssignments[i] != -1;
+                // if overspilling, this means we quantize twice, and quantization mutates the in-memory representation of the vector
+                // so, make a copy of the vector to avoid mutating it
+                if (overspill) {
+                    System.arraycopy(vector, 0, overspillScratch, 0, fieldInfo.getVectorDimension());
+                }
+
+                OptimizedScalarQuantizer.QuantizationResult result = quantizer.scalarQuantize(vector, quantized, (byte) 1, centroid);
+                BQVectorUtils.packAsBinary(quantized, binary);
+                writeQuantizedValue(quantizedVectorsTemp, binary, result);
+                if (overspill) {
+                    int s = overspillAssignments[i];
+                    // write the overspill vector as well
+                    result = quantizer.scalarQuantize(overspillScratch, quantized, (byte) 1, centroidSupplier.centroid(s));
+                    BQVectorUtils.packAsBinary(quantized, binary);
+                    writeQuantizedValue(quantizedVectorsTemp, binary, result);
+                } else {
+                    // write a zero vector for the overspill
+                    Arrays.fill(binary, (byte) 0);
+                    OptimizedScalarQuantizer.QuantizationResult zeroResult = new OptimizedScalarQuantizer.QuantizationResult(0f, 0f, 0f, 0);
+                    writeQuantizedValue(quantizedVectorsTemp, binary, zeroResult);
+                }
+            }
+            // close the temporary file so we can read it later
+            quantizedVectorsTemp.close();
+            success = true;
+        } finally {
+            if (success == false && quantizedVectorsTemp != null) {
+                mergeState.segmentInfo.dir.deleteFile(quantizedVectorsTemp.getName());
+            }
+        }
+        int[] centroidVectorCount = new int[centroidSupplier.size()];
+        for (int i = 0; i < assignments.length; i++) {
+            centroidVectorCount[assignments[i]]++;
+            // if soar assignments are present, count them as well
+            if (overspillAssignments.length > i && overspillAssignments[i] != -1) {
+                centroidVectorCount[overspillAssignments[i]]++;
+            }
+        }
+
+        int[][] assignmentsByCluster = new int[centroidSupplier.size()][];
+        boolean[][] isOverspillByCluster = new boolean[centroidSupplier.size()][];
+        for (int c = 0; c < centroidSupplier.size(); c++) {
+            assignmentsByCluster[c] = new int[centroidVectorCount[c]];
+            isOverspillByCluster[c] = new boolean[centroidVectorCount[c]];
+        }
+        Arrays.fill(centroidVectorCount, 0);
+
+        for (int i = 0; i < assignments.length; i++) {
+            int c = assignments[i];
+            assignmentsByCluster[c][centroidVectorCount[c]++] = i;
+            // if soar assignments are present, add them to the cluster as well
+            if (overspillAssignments.length > i) {
+                int s = overspillAssignments[i];
+                if (s != -1) {
+                    assignmentsByCluster[s][centroidVectorCount[s]] = i;
+                    isOverspillByCluster[s][centroidVectorCount[s]++] = true;
+                }
+            }
+        }
+        // now we can read the quantized vectors from the temporary file
+        try (IndexInput quantizedVectorsInput = mergeState.segmentInfo.dir.openInput(quantizedVectorsTempName, IOContext.DEFAULT)) {
+            final long[] offsets = new long[centroidSupplier.size()];
+            OffHeapQuantizedVectors offHeapQuantizedVectors = new OffHeapQuantizedVectors(
+                quantizedVectorsInput,
+                fieldInfo.getVectorDimension()
+            );
+            DocIdsWriter docIdsWriter = new DocIdsWriter();
+            DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(ES91OSQVectorsScorer.BULK_SIZE, postingsOutput);
+            for (int c = 0; c < centroidSupplier.size(); c++) {
+                float[] centroid = centroidSupplier.centroid(c);
+                int[] cluster = assignmentsByCluster[c];
+                boolean[] isOverspill = isOverspillByCluster[c];
+                // TODO align???
+                offsets[c] = postingsOutput.getFilePointer();
+                int size = cluster.length;
+                postingsOutput.writeVInt(size);
+                postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid)));
+                offHeapQuantizedVectors.reset(size, ord -> isOverspill[ord], ord -> cluster[ord]);
+                // TODO we might want to consider putting the docIds in a separate file
+                // to aid with only having to fetch vectors from slower storage when they are required
+                // keeping them in the same file indicates we pull the entire file into cache
+                docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster[j]), size, postingsOutput);
+                bulkWriter.writeVectors(offHeapQuantizedVectors);
+            }
+
+            if (logger.isDebugEnabled()) {
+                printClusterQualityStatistics(assignmentsByCluster);
+            }
+            return offsets;
+        }
+    }
+
     private static void printClusterQualityStatistics(int[][] clusters) {
         float min = Float.MAX_VALUE;
         float max = Float.MIN_VALUE;
@@ -210,33 +356,7 @@ static CentroidAssignments buildCentroidAssignments(KMeansResult kMeansResult) {
         float[][] centroids = kMeansResult.centroids();
         int[] assignments = kMeansResult.assignments();
         int[] soarAssignments = kMeansResult.soarAssignments();
-        int[] centroidVectorCount = new int[centroids.length];
-        for (int i = 0; i < assignments.length; i++) {
-            centroidVectorCount[assignments[i]]++;
-            // if soar assignments are present, count them as well
-            if (soarAssignments.length > i && soarAssignments[i] != -1) {
-                centroidVectorCount[soarAssignments[i]]++;
-            }
-        }
-
-        int[][] assignmentsByCluster = new int[centroids.length][];
-        for (int c = 0; c < centroids.length; c++) {
-            assignmentsByCluster[c] = new int[centroidVectorCount[c]];
-        }
-        Arrays.fill(centroidVectorCount, 0);
-
-        for (int i = 0; i < assignments.length; i++) {
-            int c = assignments[i];
-            assignmentsByCluster[c][centroidVectorCount[c]++] = i;
-            // if soar assignments are present, add them to the cluster as well
-            if (soarAssignments.length > i) {
-                int s = soarAssignments[i];
-                if (s != -1) {
-                    assignmentsByCluster[s][centroidVectorCount[s]++] = i;
-                }
-            }
-        }
-        return new CentroidAssignments(centroids, assignmentsByCluster);
+        return new CentroidAssignments(centroids, assignments, soarAssignments);
     }
 
     static void writeQuantizedValue(IndexOutput indexOutput, byte[] binaryValue, OptimizedScalarQuantizer.QuantizationResult corrections)
@@ -281,4 +401,132 @@ public float[] centroid(int centroidOrdinal) throws IOException {
             return scratch;
         }
     }
+
+    interface QuantizedVectorValues {
+        int count();
+
+        byte[] next() throws IOException;
+
+        OptimizedScalarQuantizer.QuantizationResult getCorrections() throws IOException;
+    }
+
+    interface IntToBooleanFunction {
+        boolean apply(int ord);
+    }
+
+    static class OnHeapQuantizedVectors implements QuantizedVectorValues {
+        private final FloatVectorValues vectorValues;
+        private final OptimizedScalarQuantizer quantizer;
+        private final byte[] quantizedVector;
+        private final int[] quantizedVectorScratch;
+        private OptimizedScalarQuantizer.QuantizationResult corrections;
+        private float[] currentCentroid;
+        private IntToIntFunction ordTransformer = null;
+        private int currOrd = -1;
+        private int count;
+
+        OnHeapQuantizedVectors(FloatVectorValues vectorValues, int dimension, OptimizedScalarQuantizer quantizer) {
+            this.vectorValues = vectorValues;
+            this.quantizer = quantizer;
+            this.quantizedVector = new byte[BQVectorUtils.discretize(dimension, 64) / 8];
+            this.quantizedVectorScratch = new int[dimension];
+            this.corrections = null;
+        }
+
+        private void reset(float[] centroid, int count, IntToIntFunction ordTransformer) {
+            this.currentCentroid = centroid;
+            this.ordTransformer = ordTransformer;
+            this.currOrd = -1;
+            this.count = count;
+        }
+
+        @Override
+        public int count() {
+            return count;
+        }
+
+        @Override
+        public byte[] next() throws IOException {
+            if (currOrd >= count() - 1) {
+                throw new IllegalStateException("No more vectors to read, current ord: " + currOrd + ", count: " + count());
+            }
+            currOrd++;
+            int ord = ordTransformer.apply(currOrd);
+            float[] vector = vectorValues.vectorValue(ord);
+            corrections = quantizer.scalarQuantize(vector, quantizedVectorScratch, (byte) 1, currentCentroid);
+            BQVectorUtils.packAsBinary(quantizedVectorScratch, quantizedVector);
+            return quantizedVector;
+        }
+
+        @Override
+        public OptimizedScalarQuantizer.QuantizationResult getCorrections() throws IOException {
+            if (currOrd == -1) {
+                throw new IllegalStateException("No vector read yet, call next first");
+            }
+            return corrections;
+        }
+    }
+
+    static class OffHeapQuantizedVectors implements QuantizedVectorValues {
+        private final IndexInput quantizedVectorsInput;
+        private final byte[] binaryScratch;
+        private final float[] corrections = new float[3];
+
+        private final int vectorByteSize;
+        private short bitSum;
+        private int currOrd = -1;
+        private int count;
+        private IntToBooleanFunction isOverspill = null;
+        private IntToIntFunction ordTransformer = null;
+
+        OffHeapQuantizedVectors(IndexInput quantizedVectorsInput, int dimension) {
+            this.quantizedVectorsInput = quantizedVectorsInput;
+            this.binaryScratch = new byte[BQVectorUtils.discretize(dimension, 64) / 8];
+            this.vectorByteSize = (binaryScratch.length + 3 * Float.BYTES + Short.BYTES);
+        }
+
+        private void reset(int count, IntToBooleanFunction isOverspill, IntToIntFunction ordTransformer) {
+            this.count = count;
+            this.isOverspill = isOverspill;
+            this.ordTransformer = ordTransformer;
+            this.currOrd = -1;
+        }
+
+        @Override
+        public int count() {
+            return count;
+        }
+
+        @Override
+        public byte[] next() throws IOException {
+            if (currOrd >= count - 1) {
+                throw new IllegalStateException("No more vectors to read, current ord: " + currOrd + ", count: " + count);
+            }
+            currOrd++;
+            int ord = ordTransformer.apply(currOrd);
+            boolean isOverspill = this.isOverspill.apply(currOrd);
+            return getVector(ord, isOverspill);
+        }
+
+        @Override
+        public OptimizedScalarQuantizer.QuantizationResult getCorrections() throws IOException {
+            if (currOrd == -1) {
+                throw new IllegalStateException("No vector read yet, call readQuantizedVector first");
+            }
+            return new OptimizedScalarQuantizer.QuantizationResult(corrections[0], corrections[1], corrections[2], bitSum);
+        }
+
+        byte[] getVector(int ord, boolean isOverspill) throws IOException {
+            readQuantizedVector(ord, isOverspill);
+            return binaryScratch;
+        }
+
+        public void readQuantizedVector(int ord, boolean isOverspill) throws IOException {
+            long offset = (long) ord * (vectorByteSize * 2L) + (isOverspill ? vectorByteSize : 0);
+            quantizedVectorsInput.seek(offset);
+            quantizedVectorsInput.readBytes(binaryScratch, 0, binaryScratch.length);
+            quantizedVectorsInput.readFloats(corrections, 0, 3);
+            bitSum = quantizedVectorsInput.readShort();
+        }
+    }
 }
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DiskBBQBulkWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DiskBBQBulkWriter.java
index 6974cd50d4abc..662878270ea09 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DiskBBQBulkWriter.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DiskBBQBulkWriter.java
@@ -9,34 +9,25 @@
 
 package org.elasticsearch.index.codec.vectors;
 
-import org.apache.lucene.index.FloatVectorValues;
 import org.apache.lucene.store.IndexOutput;
-import org.apache.lucene.util.hnsw.IntToIntFunction;
 
 import java.io.IOException;
 
-import static org.elasticsearch.index.codec.vectors.BQVectorUtils.discretize;
-import static org.elasticsearch.index.codec.vectors.BQVectorUtils.packAsBinary;
-
 /**
  * Base class for bulk writers that write vectors to disk using the BBQ encoding.
  * This class provides the structure for writing vectors in bulk, with specific
  * implementations for different bit sizes strategies.
  */
-public abstract class DiskBBQBulkWriter {
+abstract class DiskBBQBulkWriter {
     protected final int bulkSize;
-    protected final OptimizedScalarQuantizer quantizer;
     protected final IndexOutput out;
-    protected final FloatVectorValues fvv;
 
-    protected DiskBBQBulkWriter(int bulkSize, OptimizedScalarQuantizer quantizer, FloatVectorValues fvv, IndexOutput out) {
+    protected DiskBBQBulkWriter(int bulkSize, IndexOutput out) {
         this.bulkSize = bulkSize;
-        this.quantizer = quantizer;
         this.out = out;
-        this.fvv = fvv;
     }
 
-    public abstract void writeOrds(IntToIntFunction ords, int count, float[] centroid) throws IOException;
+    abstract void writeVectors(DefaultIVFVectorsWriter.QuantizedVectorValues qvv) throws IOException;
 
     private static void writeCorrections(OptimizedScalarQuantizer.QuantizationResult[] corrections, IndexOutput out) throws IOException {
         for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) {
@@ -64,39 +55,31 @@ private static void writeCorrection(OptimizedScalarQuantizer.QuantizationResult
         out.writeShort((short) targetComponentSum);
     }
 
-    public static class OneBitDiskBBQBulkWriter extends DiskBBQBulkWriter {
-        private final byte[] binarized;
-        private final int[] initQuantized;
+    static class OneBitDiskBBQBulkWriter extends DiskBBQBulkWriter {
         private final OptimizedScalarQuantizer.QuantizationResult[] corrections;
 
-        public OneBitDiskBBQBulkWriter(int bulkSize, OptimizedScalarQuantizer quantizer, FloatVectorValues fvv, IndexOutput out) {
-            super(bulkSize, quantizer, fvv, out);
-            this.binarized = new byte[discretize(fvv.dimension(), 64) / 8];
-            this.initQuantized = new int[fvv.dimension()];
+        OneBitDiskBBQBulkWriter(int bulkSize, IndexOutput out) {
+            super(bulkSize, out);
            this.corrections = new OptimizedScalarQuantizer.QuantizationResult[bulkSize];
         }
 
         @Override
-        public void writeOrds(IntToIntFunction ords, int count, float[] centroid) throws IOException {
-            int limit = count - bulkSize + 1;
+        void writeVectors(DefaultIVFVectorsWriter.QuantizedVectorValues qvv) throws IOException {
+            int limit = qvv.count() - bulkSize + 1;
             int i = 0;
             for (; i < limit; i += bulkSize) {
                 for (int j = 0; j < bulkSize; j++) {
-                    int ord = ords.apply(i + j);
-                    float[] fv = fvv.vectorValue(ord);
-                    corrections[j] = quantizer.scalarQuantize(fv, initQuantized, (byte) 1, centroid);
-                    packAsBinary(initQuantized, binarized);
-                    out.writeBytes(binarized, binarized.length);
+                    byte[] qv = qvv.next();
+                    corrections[j] = qvv.getCorrections();
+                    out.writeBytes(qv, qv.length);
                 }
                 writeCorrections(corrections, out);
             }
             // write tail
-            for (; i < count; ++i) {
-                int ord = ords.apply(i);
-                float[] fv = fvv.vectorValue(ord);
-                OptimizedScalarQuantizer.QuantizationResult correction = quantizer.scalarQuantize(fv, initQuantized, (byte) 1, centroid);
-                packAsBinary(initQuantized, binarized);
-                out.writeBytes(binarized, binarized.length);
+            for (; i < qvv.count(); ++i) {
+                byte[] qv = qvv.next();
+                OptimizedScalarQuantizer.QuantizationResult correction = qvv.getCorrections();
+                out.writeBytes(qv, qv.length);
                 writeCorrection(correction, out);
             }
         }
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java
index e6da0ae1caff0..be7a60a3db893 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java
@@ -139,7 +139,18 @@ abstract long[] buildAndWritePostingsLists(
         CentroidSupplier centroidSupplier,
         FloatVectorValues floatVectorValues,
         IndexOutput postingsOutput,
-        int[][] assignmentsByCluster
+        int[] assignments,
+        int[] overspillAssignments
+    ) throws IOException;
+
+    abstract long[] buildAndWritePostingsLists(
+        FieldInfo fieldInfo,
+        CentroidSupplier centroidSupplier,
+        FloatVectorValues floatVectorValues,
+        IndexOutput postingsOutput,
+        MergeState mergeState,
+        int[] assignments,
+        int[] overspillAssignments
     ) throws IOException;
 
     abstract CentroidSupplier createCentroidSupplier(
@@ -174,7 +185,8 @@ public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException {
                 centroidSupplier,
                 floatVectorValues,
                 ivfClusters,
-                centroidAssignments.assignmentsByCluster()
+                centroidAssignments.assignments(),
+                centroidAssignments.overspillAssignments()
             );
             // write posting lists
             writeMeta(fieldWriter.fieldInfo, centroidOffset, centroidLength, offsets, globalCentroid);
@@ -284,7 +296,8 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws
         final long centroidOffset;
         final long centroidLength;
         final int numCentroids;
-        final int[][] assignmentsByCluster;
+        final int[] assignments;
+        final int[] overspillAssignments;
         final float[] calculatedGlobalCentroid = new float[fieldInfo.getVectorDimension()];
         String centroidTempName = null;
         IndexOutput centroidTemp = null;
@@ -300,7 +313,8 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws
                 calculatedGlobalCentroid
             );
             numCentroids = centroidAssignments.numCentroids();
-            assignmentsByCluster = centroidAssignments.assignmentsByCluster();
+            assignments = centroidAssignments.assignments();
+            overspillAssignments = centroidAssignments.overspillAssignments();
             success = true;
         } finally {
             if (success == false && centroidTempName != null) {
@@ -337,7 +351,9 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws
                 centroidSupplier,
                 floatVectorValues,
                 ivfClusters,
-                assignmentsByCluster
+                mergeState,
+                assignments,
+                overspillAssignments
             );
             assert offsets.length == centroidSupplier.size();
             writeMeta(fieldInfo, centroidOffset, centroidLength, offsets, calculatedGlobalCentroid);