From 5dcf12d06e672e69e989ce1ef8d08ead24fd9c7c Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Mon, 7 Jul 2025 16:44:39 -0400 Subject: [PATCH 1/5] Make postings list building more IO friendly --- .../codec/vectors/CentroidAssignments.java | 8 +- .../vectors/DefaultIVFVectorsWriter.java | 188 +++++++++++++++--- .../index/codec/vectors/IVFVectorsWriter.java | 26 ++- 3 files changed, 185 insertions(+), 37 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/CentroidAssignments.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/CentroidAssignments.java index b5e1276d0747a..b76c78f7f4284 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/CentroidAssignments.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/CentroidAssignments.java @@ -9,10 +9,10 @@ package org.elasticsearch.index.codec.vectors; -record CentroidAssignments(int numCentroids, float[][] centroids, int[][] assignmentsByCluster) { +record CentroidAssignments(int numCentroids, float[][] centroids, int[] assignments, int[] overspillAssignments) { - CentroidAssignments(float[][] centroids, int[][] assignmentsByCluster) { - this(centroids.length, centroids, assignmentsByCluster); - assert centroids.length == assignmentsByCluster.length; + CentroidAssignments(float[][] centroids, int[] assignments, int[] overspillAssignments) { + this(centroids.length, centroids, assignments, overspillAssignments); + assert assignments.length == overspillAssignments.length : "assignments and overspillAssignments must have the same length"; } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java index 84abb1bea543f..32b69246d71da 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java @@ -14,6 +14,7 @@ import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.index.MergeState; import org.apache.lucene.index.SegmentWriteState; +import org.apache.lucene.store.IOContext; import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.IndexOutput; import org.apache.lucene.util.VectorUtil; @@ -49,8 +50,35 @@ long[] buildAndWritePostingsLists( CentroidSupplier centroidSupplier, FloatVectorValues floatVectorValues, IndexOutput postingsOutput, - int[][] assignmentsByCluster + int[] assignments, + int[] overspillAssignments ) throws IOException { + int[] centroidVectorCount = new int[centroidSupplier.size()]; + for (int i = 0; i < assignments.length; i++) { + centroidVectorCount[assignments[i]]++; + // if soar assignments are present, count them as well + if (overspillAssignments.length > i && overspillAssignments[i] != -1) { + centroidVectorCount[overspillAssignments[i]]++; + } + } + + int[][] assignmentsByCluster = new int[centroidSupplier.size()][]; + for (int c = 0; c < centroidSupplier.size(); c++) { + assignmentsByCluster[c] = new int[centroidVectorCount[c]]; + } + Arrays.fill(centroidVectorCount, 0); + + for (int i = 0; i < assignments.length; i++) { + int c = assignments[i]; + assignmentsByCluster[c][centroidVectorCount[c]++] = i; + // if soar assignments are present, add them to the cluster as well + if (overspillAssignments.length > i) { + int s = overspillAssignments[i]; + if (s != -1) { + 
assignmentsByCluster[s][centroidVectorCount[s]++] = i; + } + } + } // write the posting lists final long[] offsets = new long[centroidSupplier.size()]; OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); @@ -84,6 +112,92 @@ long[] buildAndWritePostingsLists( return offsets; } + @Override + long[] buildAndWritePostingsLists( + FieldInfo fieldInfo, + CentroidSupplier centroidSupplier, + FloatVectorValues floatVectorValues, + IndexOutput postingsOutput, + MergeState mergeState, + int[] assignments, + int[] overspillAssignments + ) throws IOException { + // first, quantize all the vectors into a temporary file + String quantizedVectorsTempName = null; + IndexOutput quantizedVectorsTemp = null; + boolean success = false; + try { + quantizedVectorsTemp = mergeState.segmentInfo.dir.createTempOutput(mergeState.segmentInfo.name, "qvec_", IOContext.DEFAULT); + quantizedVectorsTempName = quantizedVectorsTemp.getName(); + OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); + int[] quantized = new int[fieldInfo.getVectorDimension()]; + byte[] binary = new byte[BQVectorUtils.discretize(fieldInfo.getVectorDimension(), 64) / 8]; + for (int i = 0; i < assignments.length; i++) { + int c = assignments[i]; + float[] centroid = centroidSupplier.centroid(c); + float[] vector = floatVectorValues.vectorValue(i); + OptimizedScalarQuantizer.QuantizationResult result = quantizer.scalarQuantize(vector, quantized, (byte) 1, centroid); + BQVectorUtils.packAsBinary(quantized, binary); + writeQuantizedValue(quantizedVectorsTemp, binary, result); + boolean overspill = overspillAssignments.length > i && overspillAssignments[i] != -1; + if (overspill) { + int s = overspillAssignments[i]; + // write the overspill vector as well + result = quantizer.scalarQuantize(vector, quantized, (byte) 1, centroidSupplier.centroid(s)); + BQVectorUtils.packAsBinary(quantized, binary); + writeQuantizedValue(quantizedVectorsTemp, binary, result); + } else { + // write a zero vector for the overspill + Arrays.fill(binary, (byte) 0); + OptimizedScalarQuantizer.QuantizationResult zeroResult = new OptimizedScalarQuantizer.QuantizationResult(0f, 0f, 0f, 0); + writeQuantizedValue(quantizedVectorsTemp, binary, zeroResult); + } + } + // close the temporary file so we can read it later + quantizedVectorsTemp.close(); + success = true; + } finally { + if (success == false && quantizedVectorsTemp != null) { + mergeState.segmentInfo.dir.deleteFile(quantizedVectorsTemp.getName()); + } + } + // now we can read the quantized vectors from the temporary file + try (IndexInput quantizedVectorsInput = mergeState.segmentInfo.dir.openInput(quantizedVectorsTempName, IOContext.DEFAULT)) { + final long[] offsets = new long[centroidSupplier.size()]; + OffHeapQuantizedVectors offHeapQuantizedVectors = new OffHeapQuantizedVectors( + quantizedVectorsInput, + fieldInfo.getVectorDimension() + ); + DocIdsWriter docIdsWriter = new DocIdsWriter(); + DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter( + ES91OSQVectorsScorer.BULK_SIZE, + quantizer, + floatVectorValues, + postingsOutput + ); + for (int c = 0; c < centroidSupplier.size(); c++) { + float[] centroid = centroidSupplier.centroid(c); + // TODO: add back in sorting vectors by distance to centroid + int[] cluster = assignmentsByCluster[c]; + // TODO align??? 
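+ // posting list layout: vint vector count, the centroid's self dot-product,
+ // the cluster's doc ids, then the quantized vectors with their correction terms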
+ offsets[c] = postingsOutput.getFilePointer(); + int size = cluster.length; + postingsOutput.writeVInt(size); + postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid))); + // TODO we might want to consider putting the docIds in a separate file + // to aid with only having to fetch vectors from slower storage when they are required + // keeping them in the same file indicates we pull the entire file into cache + docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster[j]), size, postingsOutput); + bulkWriter.writeOrds(j -> cluster[j], cluster.length, centroid); + } + + if (logger.isDebugEnabled()) { + printClusterQualityStatistics(assignmentsByCluster); + } + return offsets; + } + } + private static void printClusterQualityStatistics(int[][] clusters) { float min = Float.MAX_VALUE; float max = Float.MIN_VALUE; @@ -210,33 +324,7 @@ static CentroidAssignments buildCentroidAssignments(KMeansResult kMeansResult) { float[][] centroids = kMeansResult.centroids(); int[] assignments = kMeansResult.assignments(); int[] soarAssignments = kMeansResult.soarAssignments(); - int[] centroidVectorCount = new int[centroids.length]; - for (int i = 0; i < assignments.length; i++) { - centroidVectorCount[assignments[i]]++; - // if soar assignments are present, count them as well - if (soarAssignments.length > i && soarAssignments[i] != -1) { - centroidVectorCount[soarAssignments[i]]++; - } - } - - int[][] assignmentsByCluster = new int[centroids.length][]; - for (int c = 0; c < centroids.length; c++) { - assignmentsByCluster[c] = new int[centroidVectorCount[c]]; - } - Arrays.fill(centroidVectorCount, 0); - - for (int i = 0; i < assignments.length; i++) { - int c = assignments[i]; - assignmentsByCluster[c][centroidVectorCount[c]++] = i; - // if soar assignments are present, add them to the cluster as well - if (soarAssignments.length > i) { - int s = soarAssignments[i]; - if (s != -1) { - assignmentsByCluster[s][centroidVectorCount[s]++] = i; - } - } - } - return new CentroidAssignments(centroids, assignmentsByCluster); + return new CentroidAssignments(centroids, assignments, soarAssignments); } static void writeQuantizedValue(IndexOutput indexOutput, byte[] binaryValue, OptimizedScalarQuantizer.QuantizationResult corrections) @@ -281,4 +369,48 @@ public float[] centroid(int centroidOrdinal) throws IOException { return scratch; } } + + static class OffHeapQuantizedVectors { + private final IndexInput quantizedVectorsInput; + private final byte[] binaryScratch; + private final float[] corrections = new float[3]; + + private final int vectorByteSize; + private short bitSum; + private int currOrd = -1; + private boolean isOverspill = false; + + OffHeapQuantizedVectors(IndexInput quantizedVectorsInput, int dimension) { + this.quantizedVectorsInput = quantizedVectorsInput; + this.binaryScratch = new byte[BQVectorUtils.discretize(dimension, 64) / 8]; + this.vectorByteSize = (binaryScratch.length + 3 * Float.BYTES + Short.BYTES); + } + + byte[] getVector(int ord, boolean isOverspill) throws IOException { + readQuantizedVector(ord, isOverspill); + return binaryScratch; + } + + OptimizedScalarQuantizer.QuantizationResult getCorrections() throws IOException { + if (currOrd == -1) { + throw new IllegalStateException("No vector read yet, call readQuantizedVector first"); + } + return new OptimizedScalarQuantizer.QuantizationResult(corrections[0], corrections[1], corrections[2], bitSum); + } + + public void readQuantizedVector(int ord, boolean isOverspill) throws IOException { + if 
(ord == currOrd && isOverspill == this.isOverspill) { + return; // no need to read again + } + long offset = (long) ord * (vectorByteSize * 2) + (isOverspill ? vectorByteSize : 0); + quantizedVectorsInput.seek(offset); + quantizedVectorsInput.readBytes(binaryScratch, 0, binaryScratch.length); + quantizedVectorsInput.readFloats(corrections, 0, 3); + bitSum = quantizedVectorsInput.readShort(); + if (ord != currOrd) { + currOrd = ord; + } + this.isOverspill = isOverspill; + } + } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java index e6da0ae1caff0..be7a60a3db893 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/IVFVectorsWriter.java @@ -139,7 +139,18 @@ abstract long[] buildAndWritePostingsLists( CentroidSupplier centroidSupplier, FloatVectorValues floatVectorValues, IndexOutput postingsOutput, - int[][] assignmentsByCluster + int[] assignments, + int[] overspillAssignments + ) throws IOException; + + abstract long[] buildAndWritePostingsLists( + FieldInfo fieldInfo, + CentroidSupplier centroidSupplier, + FloatVectorValues floatVectorValues, + IndexOutput postingsOutput, + MergeState mergeState, + int[] assignments, + int[] overspillAssignments ) throws IOException; abstract CentroidSupplier createCentroidSupplier( @@ -174,7 +185,8 @@ public final void flush(int maxDoc, Sorter.DocMap sortMap) throws IOException { centroidSupplier, floatVectorValues, ivfClusters, - centroidAssignments.assignmentsByCluster() + centroidAssignments.assignments(), + centroidAssignments.overspillAssignments() ); // write posting lists writeMeta(fieldWriter.fieldInfo, centroidOffset, centroidLength, offsets, globalCentroid); @@ -284,7 +296,8 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws final long centroidOffset; final long centroidLength; final int numCentroids; - final int[][] assignmentsByCluster; + final int[] assignments; + final int[] overspillAssignments; final float[] calculatedGlobalCentroid = new float[fieldInfo.getVectorDimension()]; String centroidTempName = null; IndexOutput centroidTemp = null; @@ -300,7 +313,8 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws calculatedGlobalCentroid ); numCentroids = centroidAssignments.numCentroids(); - assignmentsByCluster = centroidAssignments.assignmentsByCluster(); + assignments = centroidAssignments.assignments(); + overspillAssignments = centroidAssignments.overspillAssignments(); success = true; } finally { if (success == false && centroidTempName != null) { @@ -337,7 +351,9 @@ private void mergeOneFieldIVF(FieldInfo fieldInfo, MergeState mergeState) throws centroidSupplier, floatVectorValues, ivfClusters, - assignmentsByCluster + mergeState, + assignments, + overspillAssignments ); assert offsets.length == centroidSupplier.size(); writeMeta(fieldInfo, centroidOffset, centroidLength, offsets, calculatedGlobalCentroid); From 4fcc6f86393b61381db657edb0ace1fb90e70e09 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Tue, 8 Jul 2025 09:50:47 -0400 Subject: [PATCH 2/5] iter --- .../vectors/DefaultIVFVectorsWriter.java | 167 +++++++++++++++--- .../codec/vectors/DiskBBQBulkWriter.java | 47 ++--- 2 files changed, 153 insertions(+), 61 deletions(-) diff --git 
a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java index 32b69246d71da..8062c0632f0ff 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java @@ -18,6 +18,7 @@ import org.apache.lucene.store.IndexInput; import org.apache.lucene.store.IndexOutput; import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.hnsw.IntToIntFunction; import org.elasticsearch.index.codec.vectors.cluster.HierarchicalKMeans; import org.elasticsearch.index.codec.vectors.cluster.KMeansResult; import org.elasticsearch.logging.LogManager; @@ -81,28 +82,27 @@ long[] buildAndWritePostingsLists( } // write the posting lists final long[] offsets = new long[centroidSupplier.size()]; - OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()); DocIdsWriter docIdsWriter = new DocIdsWriter(); - DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter( - ES91OSQVectorsScorer.BULK_SIZE, - quantizer, + DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(ES91OSQVectorsScorer.BULK_SIZE, postingsOutput); + OnHeapQuantizedVectors onHeapQuantizedVectors = new OnHeapQuantizedVectors( floatVectorValues, - postingsOutput + fieldInfo.getVectorDimension(), + new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction()) ); for (int c = 0; c < centroidSupplier.size(); c++) { float[] centroid = centroidSupplier.centroid(c); - // TODO: add back in sorting vectors by distance to centroid int[] cluster = assignmentsByCluster[c]; // TODO align??? 
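// capture where this centroid's posting list starts; these offsets are later written to the meta file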
offsets[c] = postingsOutput.getFilePointer(); int size = cluster.length; postingsOutput.writeVInt(size); postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid))); + onHeapQuantizedVectors.reset(centroid, size, ord -> cluster[ord]); // TODO we might want to consider putting the docIds in a separate file // to aid with only having to fetch vectors from slower storage when they are required // keeping them in the same file indicates we pull the entire file into cache docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster[j]), size, postingsOutput); - bulkWriter.writeOrds(j -> cluster[j], cluster.length, centroid); + bulkWriter.writeVectors(onHeapQuantizedVectors); } if (logger.isDebugEnabled()) { @@ -161,6 +161,35 @@ long[] buildAndWritePostingsLists( mergeState.segmentInfo.dir.deleteFile(quantizedVectorsTemp.getName()); } } + int[] centroidVectorCount = new int[centroidSupplier.size()]; + for (int i = 0; i < assignments.length; i++) { + centroidVectorCount[assignments[i]]++; + // if soar assignments are present, count them as well + if (overspillAssignments.length > i && overspillAssignments[i] != -1) { + centroidVectorCount[overspillAssignments[i]]++; + } + } + + int[][] assignmentsByCluster = new int[centroidSupplier.size()][]; + boolean[][] isOverspillByCluster = new boolean[centroidSupplier.size()][]; + for (int c = 0; c < centroidSupplier.size(); c++) { + assignmentsByCluster[c] = new int[centroidVectorCount[c]]; + isOverspillByCluster[c] = new boolean[centroidVectorCount[c]]; + } + Arrays.fill(centroidVectorCount, 0); + + for (int i = 0; i < assignments.length; i++) { + int c = assignments[i]; + assignmentsByCluster[c][centroidVectorCount[c]++] = i; + // if soar assignments are present, add them to the cluster as well + if (overspillAssignments.length > i) { + int s = overspillAssignments[i]; + if (s != -1) { + assignmentsByCluster[s][centroidVectorCount[s]] = i; + isOverspillByCluster[s][centroidVectorCount[s]++] = true; + } + } + } // now we can read the quantized vectors from the temporary file try (IndexInput quantizedVectorsInput = mergeState.segmentInfo.dir.openInput(quantizedVectorsTempName, IOContext.DEFAULT)) { final long[] offsets = new long[centroidSupplier.size()]; @@ -169,26 +198,22 @@ long[] buildAndWritePostingsLists( fieldInfo.getVectorDimension() ); DocIdsWriter docIdsWriter = new DocIdsWriter(); - DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter( - ES91OSQVectorsScorer.BULK_SIZE, - quantizer, - floatVectorValues, - postingsOutput - ); + DiskBBQBulkWriter bulkWriter = new DiskBBQBulkWriter.OneBitDiskBBQBulkWriter(ES91OSQVectorsScorer.BULK_SIZE, postingsOutput); for (int c = 0; c < centroidSupplier.size(); c++) { float[] centroid = centroidSupplier.centroid(c); - // TODO: add back in sorting vectors by distance to centroid int[] cluster = assignmentsByCluster[c]; + boolean[] isOverspill = isOverspillByCluster[c]; // TODO align??? 
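+ // quantized vectors are re-read from the temp file by (ordinal, overspill) slot
+ // rather than re-quantizing the raw floats for every posting list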
offsets[c] = postingsOutput.getFilePointer(); int size = cluster.length; postingsOutput.writeVInt(size); postingsOutput.writeInt(Float.floatToIntBits(VectorUtil.dotProduct(centroid, centroid))); + offHeapQuantizedVectors.reset(size, ord -> isOverspill[ord], ord -> cluster[ord]); // TODO we might want to consider putting the docIds in a separate file // to aid with only having to fetch vectors from slower storage when they are required // keeping them in the same file indicates we pull the entire file into cache docIdsWriter.writeDocIds(j -> floatVectorValues.ordToDoc(cluster[j]), size, postingsOutput); - bulkWriter.writeOrds(j -> cluster[j], cluster.length, centroid); + bulkWriter.writeVectors(offHeapQuantizedVectors); } if (logger.isDebugEnabled()) { @@ -370,7 +395,72 @@ public float[] centroid(int centroidOrdinal) throws IOException { } } - static class OffHeapQuantizedVectors { + interface QuantizedVectorValues { + int count(); + + byte[] next() throws IOException; + + OptimizedScalarQuantizer.QuantizationResult getCorrections() throws IOException; + } + + interface IntToBooleanFunction { + boolean apply(int ord); + } + + static class OnHeapQuantizedVectors implements QuantizedVectorValues { + private final FloatVectorValues vectorValues; + private final OptimizedScalarQuantizer quantizer; + private final byte[] quantizedVector; + private final int[] quantizedVectorScratch; + private OptimizedScalarQuantizer.QuantizationResult corrections; + private float[] currentCentroid; + private IntToIntFunction ordTransformer = null; + private int currOrd = -1; + private int count; + + OnHeapQuantizedVectors(FloatVectorValues vectorValues, int dimension, OptimizedScalarQuantizer quantizer) { + this.vectorValues = vectorValues; + this.quantizer = quantizer; + this.quantizedVector = new byte[BQVectorUtils.discretize(dimension, 64) / 8]; + this.quantizedVectorScratch = new int[dimension]; + this.corrections = null; + } + + private void reset(float[] centroid, int count, IntToIntFunction ordTransformer) { + this.currentCentroid = centroid; + this.ordTransformer = ordTransformer; + this.currOrd = -1; + this.count = count; + } + + @Override + public int count() { + return count; + } + + @Override + public byte[] next() throws IOException { + if (currOrd >= count() - 1) { + throw new IllegalStateException("No more vectors to read, current ord: " + currOrd + ", count: " + count()); + } + currOrd++; + int ord = ordTransformer.apply(currOrd); + float[] vector = vectorValues.vectorValue(ord); + corrections = quantizer.scalarQuantize(vector, quantizedVectorScratch, (byte) 1, currentCentroid); + BQVectorUtils.packAsBinary(quantizedVectorScratch, quantizedVector); + return quantizedVector; + } + + @Override + public OptimizedScalarQuantizer.QuantizationResult getCorrections() throws IOException { + if (currOrd == -1) { + throw new IllegalStateException("No vector read yet, call next first"); + } + return corrections; + } + } + + static class OffHeapQuantizedVectors implements QuantizedVectorValues { private final IndexInput quantizedVectorsInput; private final byte[] binaryScratch; private final float[] corrections = new float[3]; @@ -378,7 +468,9 @@ static class OffHeapQuantizedVectors { private final int vectorByteSize; private short bitSum; private int currOrd = -1; - private boolean isOverspill = false; + private int count; + private IntToBooleanFunction isOverspill = null; + private IntToIntFunction ordTransformer = null; OffHeapQuantizedVectors(IndexInput quantizedVectorsInput, int dimension) { 
this.quantizedVectorsInput = quantizedVectorsInput; @@ -386,31 +478,48 @@ static class OffHeapQuantizedVectors { this.vectorByteSize = (binaryScratch.length + 3 * Float.BYTES + Short.BYTES); } - byte[] getVector(int ord, boolean isOverspill) throws IOException { - readQuantizedVector(ord, isOverspill); - return binaryScratch; + private void reset(int count, IntToBooleanFunction isOverspill, IntToIntFunction ordTransformer) { + this.count = count; + this.isOverspill = isOverspill; + this.ordTransformer = ordTransformer; + this.currOrd = -1; } - OptimizedScalarQuantizer.QuantizationResult getCorrections() throws IOException { + @Override + public int count() { + return count; + } + + @Override + public byte[] next() throws IOException { + if (currOrd >= count - 1) { + throw new IllegalStateException("No more vectors to read, current ord: " + currOrd + ", count: " + count); + } + currOrd++; + int ord = ordTransformer.apply(currOrd); + boolean isOverspill = this.isOverspill.apply(currOrd); + return getVector(ord, isOverspill); + } + + @Override + public OptimizedScalarQuantizer.QuantizationResult getCorrections() throws IOException { if (currOrd == -1) { throw new IllegalStateException("No vector read yet, call readQuantizedVector first"); } return new OptimizedScalarQuantizer.QuantizationResult(corrections[0], corrections[1], corrections[2], bitSum); } + byte[] getVector(int ord, boolean isOverspill) throws IOException { + readQuantizedVector(ord, isOverspill); + return binaryScratch; + } + public void readQuantizedVector(int ord, boolean isOverspill) throws IOException { - if (ord == currOrd && isOverspill == this.isOverspill) { - return; // no need to read again - } - long offset = (long) ord * (vectorByteSize * 2) + (isOverspill ? vectorByteSize : 0); + long offset = (long) ord * (vectorByteSize * 2L) + (isOverspill ? vectorByteSize : 0); quantizedVectorsInput.seek(offset); quantizedVectorsInput.readBytes(binaryScratch, 0, binaryScratch.length); quantizedVectorsInput.readFloats(corrections, 0, 3); bitSum = quantizedVectorsInput.readShort(); - if (ord != currOrd) { - currOrd = ord; - } - this.isOverspill = isOverspill; } } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DiskBBQBulkWriter.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/DiskBBQBulkWriter.java index 6974cd50d4abc..662878270ea09 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DiskBBQBulkWriter.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DiskBBQBulkWriter.java @@ -9,34 +9,25 @@ package org.elasticsearch.index.codec.vectors; -import org.apache.lucene.index.FloatVectorValues; import org.apache.lucene.store.IndexOutput; -import org.apache.lucene.util.hnsw.IntToIntFunction; import java.io.IOException; -import static org.elasticsearch.index.codec.vectors.BQVectorUtils.discretize; -import static org.elasticsearch.index.codec.vectors.BQVectorUtils.packAsBinary; - /** * Base class for bulk writers that write vectors to disk using the BBQ encoding. * This class provides the structure for writing vectors in bulk, with specific * implementations for different bit sizes strategies. 
*/ -public abstract class DiskBBQBulkWriter { +abstract class DiskBBQBulkWriter { protected final int bulkSize; - protected final OptimizedScalarQuantizer quantizer; protected final IndexOutput out; - protected final FloatVectorValues fvv; - protected DiskBBQBulkWriter(int bulkSize, OptimizedScalarQuantizer quantizer, FloatVectorValues fvv, IndexOutput out) { + protected DiskBBQBulkWriter(int bulkSize, IndexOutput out) { this.bulkSize = bulkSize; - this.quantizer = quantizer; this.out = out; - this.fvv = fvv; } - public abstract void writeOrds(IntToIntFunction ords, int count, float[] centroid) throws IOException; + abstract void writeVectors(DefaultIVFVectorsWriter.QuantizedVectorValues qvv) throws IOException; private static void writeCorrections(OptimizedScalarQuantizer.QuantizationResult[] corrections, IndexOutput out) throws IOException { for (OptimizedScalarQuantizer.QuantizationResult correction : corrections) { @@ -64,39 +55,31 @@ private static void writeCorrection(OptimizedScalarQuantizer.QuantizationResult out.writeShort((short) targetComponentSum); } - public static class OneBitDiskBBQBulkWriter extends DiskBBQBulkWriter { - private final byte[] binarized; - private final int[] initQuantized; + static class OneBitDiskBBQBulkWriter extends DiskBBQBulkWriter { private final OptimizedScalarQuantizer.QuantizationResult[] corrections; - public OneBitDiskBBQBulkWriter(int bulkSize, OptimizedScalarQuantizer quantizer, FloatVectorValues fvv, IndexOutput out) { - super(bulkSize, quantizer, fvv, out); - this.binarized = new byte[discretize(fvv.dimension(), 64) / 8]; - this.initQuantized = new int[fvv.dimension()]; + OneBitDiskBBQBulkWriter(int bulkSize, IndexOutput out) { + super(bulkSize, out); this.corrections = new OptimizedScalarQuantizer.QuantizationResult[bulkSize]; } @Override - public void writeOrds(IntToIntFunction ords, int count, float[] centroid) throws IOException { - int limit = count - bulkSize + 1; + void writeVectors(DefaultIVFVectorsWriter.QuantizedVectorValues qvv) throws IOException { + int limit = qvv.count() - bulkSize + 1; int i = 0; for (; i < limit; i += bulkSize) { for (int j = 0; j < bulkSize; j++) { - int ord = ords.apply(i + j); - float[] fv = fvv.vectorValue(ord); - corrections[j] = quantizer.scalarQuantize(fv, initQuantized, (byte) 1, centroid); - packAsBinary(initQuantized, binarized); - out.writeBytes(binarized, binarized.length); + byte[] qv = qvv.next(); + corrections[j] = qvv.getCorrections(); + out.writeBytes(qv, qv.length); } writeCorrections(corrections, out); } // write tail - for (; i < count; ++i) { - int ord = ords.apply(i); - float[] fv = fvv.vectorValue(ord); - OptimizedScalarQuantizer.QuantizationResult correction = quantizer.scalarQuantize(fv, initQuantized, (byte) 1, centroid); - packAsBinary(initQuantized, binarized); - out.writeBytes(binarized, binarized.length); + for (; i < qvv.count(); ++i) { + byte[] qv = qvv.next(); + OptimizedScalarQuantizer.QuantizationResult correction = qvv.getCorrections(); + out.writeBytes(qv, qv.length); writeCorrection(correction, out); } } From 4422f85cfaf0c667a35aaad90a63826f6b89f499 Mon Sep 17 00:00:00 2001 From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com> Date: Tue, 8 Jul 2025 14:41:53 -0400 Subject: [PATCH 3/5] iter --- .../index/codec/vectors/DefaultIVFVectorsWriter.java | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java 
b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java
index 8062c0632f0ff..e94b728f934e4 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java
@@ -132,18 +132,25 @@ long[] buildAndWritePostingsLists(
             OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
             int[] quantized = new int[fieldInfo.getVectorDimension()];
             byte[] binary = new byte[BQVectorUtils.discretize(fieldInfo.getVectorDimension(), 64) / 8];
+            float[] overspillScratch = new float[fieldInfo.getVectorDimension()];
             for (int i = 0; i < assignments.length; i++) {
                 int c = assignments[i];
                 float[] centroid = centroidSupplier.centroid(c);
                 float[] vector = floatVectorValues.vectorValue(i);
+                boolean overspill = overspillAssignments.length > i && overspillAssignments[i] != -1;
+                // if overspilling, we quantize twice, and quantization mutates the in-memory representation of the vector,
+                // so make a copy of the vector to avoid mutating it
+                if (overspill) {
+                    System.arraycopy(vector, 0, overspillScratch, 0, fieldInfo.getVectorDimension());
+                }
+
                 OptimizedScalarQuantizer.QuantizationResult result = quantizer.scalarQuantize(vector, quantized, (byte) 1, centroid);
                 BQVectorUtils.packAsBinary(quantized, binary);
                 writeQuantizedValue(quantizedVectorsTemp, binary, result);
-                boolean overspill = overspillAssignments.length > i && overspillAssignments[i] != -1;
                 if (overspill) {
                     int s = overspillAssignments[i];
                     // write the overspill vector as well
-                    result = quantizer.scalarQuantize(vector, quantized, (byte) 1, centroidSupplier.centroid(s));
+                    result = quantizer.scalarQuantize(overspillScratch, quantized, (byte) 1, centroidSupplier.centroid(s));
                     BQVectorUtils.packAsBinary(quantized, binary);
                     writeQuantizedValue(quantizedVectorsTemp, binary, result);
                 } else {

From fbfc9d369c4dcb14812089c1f23ad3e2de143513 Mon Sep 17 00:00:00 2001
From: Benjamin Trent <4357155+benwtrent@users.noreply.github.com>
Date: Tue, 15 Jul 2025 16:24:27 -0400
Subject: [PATCH 4/5] fixing assertion

---
 .../elasticsearch/index/codec/vectors/CentroidAssignments.java | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/CentroidAssignments.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/CentroidAssignments.java
index b76c78f7f4284..72f6494503dc2 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/CentroidAssignments.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/CentroidAssignments.java
@@ -13,6 +13,7 @@ record CentroidAssignments(int numCentroids, float[][] centroids, int[] assignme
     CentroidAssignments(float[][] centroids, int[] assignments, int[] overspillAssignments) {
         this(centroids.length, centroids, assignments, overspillAssignments);
-        assert assignments.length == overspillAssignments.length : "assignments and overspillAssignments must have the same length";
+        assert assignments.length == overspillAssignments.length
+            || overspillAssignments.length == 0 : "assignments and overspillAssignments must have the same length, or overspillAssignments must be empty";
     }
 }

From da2aec6c07804a624d26ed1c2a7b8f52f7ebf7ce Mon Sep 17 00:00:00 2001
From: elasticsearchmachine
Date: Tue, 15 Jul 2025 20:34:12 +0000
Subject: [PATCH 5/5] [CI] Auto commit changes from spotless

---
 .../index/codec/vectors/CentroidAssignments.java | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/CentroidAssignments.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/CentroidAssignments.java
index 72f6494503dc2..e92ece41077a6 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/CentroidAssignments.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/CentroidAssignments.java
@@ -13,7 +13,7 @@ record CentroidAssignments(int numCentroids, float[][] centroids, int[] assignme
     CentroidAssignments(float[][] centroids, int[] assignments, int[] overspillAssignments) {
         this(centroids.length, centroids, assignments, overspillAssignments);
-        assert assignments.length == overspillAssignments.length
-            || overspillAssignments.length == 0 : "assignments and overspillAssignments must have the same length, or overspillAssignments must be empty";
+        assert assignments.length == overspillAssignments.length || overspillAssignments.length == 0
+            : "assignments and overspillAssignments must have the same length, or overspillAssignments must be empty";
     }
 }
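
A note on the merge-time temp file, since the offset arithmetic is easy to miss: below is a minimal sketch (illustrative names, not the actual Elasticsearch classes) of the fixed-stride record layout that OffHeapQuantizedVectors reads, assuming one-bit quantization as in the patches above. Each ordinal owns two consecutive slots, primary then overspill (zero-filled when the vector has no SOAR assignment), so any record is one seek and one fixed-size read away.

// Illustrative sketch of the qvec_ temp-file layout; class and method names are hypothetical.
final class QuantizedTempLayout {
    final int packedBytes; // BQVectorUtils.discretize(dim, 64) / 8, i.e. dim rounded up to whole 64-bit words
    final int slotBytes;   // packed one-bit vector + 3 float corrections + 1 short bit-sum

    QuantizedTempLayout(int dim) {
        this.packedBytes = ((dim + 63) / 64) * 8; // bytes per packed vector
        this.slotBytes = packedBytes + 3 * Float.BYTES + Short.BYTES;
    }

    // Two consecutive slots per ordinal (primary, then overspill) give the file a
    // fixed stride; this mirrors the offset math in readQuantizedVector above.
    long offset(int ord, boolean overspill) {
        return (long) ord * (2L * slotBytes) + (overspill ? slotBytes : 0);
    }
}

Writing a posting list then needs only the per-cluster ordinals and overspill flags: seek to offset(ord, overspill), read slotBytes, and append to the postings output, which is what keeps the raw float vector reads out of the merge's per-centroid loop.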