Skip to content

Commit baa97c7

Browse files
authored
Do not mutate the input vector in OptimizedScalarQuantizer (elastic#134472)
1 parent ad5df9d commit baa97c7

File tree

12 files changed

+131
-68
lines changed

12 files changed

+131
-68
lines changed

benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OptimizedScalarQuantizerBenchmark.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ public class OptimizedScalarQuantizerBenchmark {
4242
int dims;
4343

4444
float[] vector;
45+
float[] scratch;
4546
float[] centroid;
4647
int[] destination;
4748

@@ -57,6 +58,7 @@ public void init() {
5758
destination = new int[dims];
5859
vector = new float[dims];
5960
centroid = new float[dims];
61+
scratch = new float[dims];
6062
for (int i = 0; i < dims; ++i) {
6163
vector[i] = random.nextFloat();
6264
centroid[i] = random.nextFloat();
@@ -65,14 +67,14 @@ public void init() {
6567

6668
@Benchmark
6769
public int[] scalar() {
68-
osq.scalarQuantize(vector, destination, bits, centroid);
70+
osq.scalarQuantize(vector, scratch, destination, bits, centroid);
6971
return destination;
7072
}
7173

7274
@Benchmark
7375
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
7476
public int[] vector() {
75-
osq.scalarQuantize(vector, destination, bits, centroid);
77+
osq.scalarQuantize(vector, scratch, destination, bits, centroid);
7678
return destination;
7779
}
7880
}

libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91Int4VectorScorerTests.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ public void testInt4ScoreBulk() throws Exception {
145145
}
146146

147147
OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(similarityFunction);
148+
float[] scratch = new float[dimensions];
148149
try (Directory dir = new MMapDirectory(createTempDir())) {
149150
try (IndexOutput out = dir.createOutput("tests.bin", IOContext.DEFAULT)) {
150151
OptimizedScalarQuantizer.QuantizationResult[] results =
@@ -157,7 +158,7 @@ public void testInt4ScoreBulk() throws Exception {
157158
if (similarityFunction != VectorSimilarityFunction.EUCLIDEAN) {
158159
VectorUtil.l2normalize(vectors[i + j]);
159160
}
160-
results[j] = quantizer.scalarQuantize(vectors[i + j].clone(), quantizedScratch, (byte) 4, centroid);
161+
results[j] = quantizer.scalarQuantize(vectors[i + j], scratch, quantizedScratch, (byte) 4, centroid);
161162
for (int k = 0; k < dimensions; k++) {
162163
quantizeVector[k] = (byte) quantizedScratch[k];
163164
}
@@ -175,7 +176,8 @@ public void testInt4ScoreBulk() throws Exception {
175176
VectorUtil.l2normalize(query);
176177
}
177178
OptimizedScalarQuantizer.QuantizationResult queryCorrections = quantizer.scalarQuantize(
178-
query.clone(),
179+
query,
180+
new float[dimensions],
179181
quantizedScratch,
180182
(byte) 4,
181183
centroid

libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES91OSQVectorScorerTests.java

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ public void testScore() throws Exception {
6464
final int numVectors = random().nextInt(10, 50);
6565
float[][] vectors = new float[numVectors][dimensions];
6666
final int[] scratch = new int[dimensions];
67+
final float[] residualScratch = new float[dimensions];
6768
final byte[] qVector = new byte[length];
6869
final float[] centroid = new float[dimensions];
6970
VectorSimilarityFunction similarityFunction = randomFrom(VectorSimilarityFunction.values());
@@ -78,7 +79,8 @@ public void testScore() throws Exception {
7879
for (float[] vector : vectors) {
7980
randomVector(vector, similarityFunction);
8081
OptimizedScalarQuantizer.QuantizationResult result = quantizer.scalarQuantize(
81-
vector.clone(),
82+
vector,
83+
residualScratch,
8284
scratch,
8385
(byte) 1,
8486
centroid
@@ -94,7 +96,8 @@ public void testScore() throws Exception {
9496
final float[] query = new float[dimensions];
9597
randomVector(query, similarityFunction);
9698
OptimizedScalarQuantizer.QuantizationResult queryCorrections = quantizer.scalarQuantize(
97-
query.clone(),
99+
query,
100+
residualScratch,
98101
scratch,
99102
(byte) 4,
100103
centroid
@@ -160,6 +163,7 @@ public void testScoreBulk() throws Exception {
160163
final int numVectors = ES91OSQVectorsScorer.BULK_SIZE * random().nextInt(1, 10);
161164
float[][] vectors = new float[numVectors][dimensions];
162165
final int[] scratch = new int[dimensions];
166+
final float[] residualScratch = new float[dimensions];
163167
final byte[] qVector = new byte[length];
164168
final float[] centroid = new float[dimensions];
165169
VectorSimilarityFunction similarityFunction = randomFrom(VectorSimilarityFunction.values());
@@ -177,7 +181,7 @@ public void testScoreBulk() throws Exception {
177181
for (int i = 0; i < limit; i += ES91OSQVectorsScorer.BULK_SIZE) {
178182
for (int j = 0; j < ES91Int4VectorsScorer.BULK_SIZE; j++) {
179183
randomVector(vectors[i + j], similarityFunction);
180-
results[j] = quantizer.scalarQuantize(vectors[i + j].clone(), scratch, (byte) 1, centroid);
184+
results[j] = quantizer.scalarQuantize(vectors[i + j], residualScratch, scratch, (byte) 1, centroid);
181185
BQVectorUtils.packAsBinary(scratch, qVector);
182186
out.writeBytes(qVector, 0, qVector.length);
183187
}
@@ -187,7 +191,8 @@ public void testScoreBulk() throws Exception {
187191
final float[] query = new float[dimensions];
188192
randomVector(query, similarityFunction);
189193
OptimizedScalarQuantizer.QuantizationResult queryCorrections = quantizer.scalarQuantize(
190-
query.clone(),
194+
query,
195+
residualScratch,
191196
scratch,
192197
(byte) 4,
193198
centroid

libs/simdvec/src/test/java/org/elasticsearch/simdvec/internal/vectorization/ES92Int7VectorScorerTests.java

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ public void testInt7Score() throws Exception {
8383
final int numVectors = random().nextInt(1, 100);
8484

8585
float[][] vectors = new float[numVectors][dimensions];
86+
final float[] residualScratch = new float[dimensions];
8687
final int[] scratch = new int[dimensions];
8788
final byte[] qVector = new byte[dimensions];
8889
final float[] centroid = new float[dimensions];
@@ -94,7 +95,8 @@ public void testInt7Score() throws Exception {
9495
for (float[] vector : vectors) {
9596
randomVector(vector, similarityFunction);
9697
OptimizedScalarQuantizer.QuantizationResult result = quantizer.scalarQuantize(
97-
vector.clone(),
98+
vector,
99+
residualScratch,
98100
scratch,
99101
(byte) 7,
100102
centroid
@@ -112,7 +114,8 @@ public void testInt7Score() throws Exception {
112114
final float[] query = new float[dimensions];
113115
randomVector(query, similarityFunction);
114116
OptimizedScalarQuantizer.QuantizationResult queryCorrections = quantizer.scalarQuantize(
115-
query.clone(),
117+
query,
118+
residualScratch,
116119
scratch,
117120
(byte) 7,
118121
centroid
@@ -170,6 +173,7 @@ public void testInt7ScoreBulk() throws Exception {
170173
final float[][] vectors = new float[numVectors][dimensions];
171174
final int[] quantizedScratch = new int[dimensions];
172175
final byte[] quantizeVector = new byte[dimensions];
176+
final float[] residualScratch = new float[dimensions];
173177
final float[] centroid = new float[dimensions];
174178
VectorSimilarityFunction similarityFunction = randomFrom(VectorSimilarityFunction.values());
175179
randomVector(centroid, similarityFunction);
@@ -182,7 +186,7 @@ public void testInt7ScoreBulk() throws Exception {
182186
for (int i = 0; i < numVectors; i += ES91Int4VectorsScorer.BULK_SIZE) {
183187
for (int j = 0; j < ES91Int4VectorsScorer.BULK_SIZE; j++) {
184188
randomVector(vectors[i + j], similarityFunction);
185-
results[j] = quantizer.scalarQuantize(vectors[i + j].clone(), quantizedScratch, (byte) 7, centroid);
189+
results[j] = quantizer.scalarQuantize(vectors[i + j], residualScratch, quantizedScratch, (byte) 7, centroid);
186190
for (int k = 0; k < dimensions; k++) {
187191
quantizeVector[k] = (byte) quantizedScratch[k];
188192
}
@@ -195,7 +199,8 @@ public void testInt7ScoreBulk() throws Exception {
195199
final byte[] quantizeQuery = new byte[dimensions];
196200
randomVector(query, similarityFunction);
197201
OptimizedScalarQuantizer.QuantizationResult queryCorrections = quantizer.scalarQuantize(
198-
query.clone(),
202+
query,
203+
residualScratch,
199204
quantizedScratch,
200205
(byte) 7,
201206
centroid

server/src/main/java/org/elasticsearch/index/codec/vectors/OptimizedScalarQuantizer.java

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -57,14 +57,20 @@ public OptimizedScalarQuantizer(VectorSimilarityFunction similarityFunction) {
5757

5858
public record QuantizationResult(float lowerInterval, float upperInterval, float additionalCorrection, int quantizedComponentSum) {}
5959

60-
public QuantizationResult[] multiScalarQuantize(float[] vector, int[][] destinations, byte[] bits, float[] centroid) {
60+
public QuantizationResult[] multiScalarQuantize(
61+
float[] vector,
62+
float[] residualDestination,
63+
int[][] destinations,
64+
byte[] bits,
65+
float[] centroid
66+
) {
6167
assert similarityFunction != COSINE || VectorUtil.isUnitVector(vector);
6268
assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid);
6369
assert bits.length == destinations.length;
6470
if (similarityFunction == EUCLIDEAN) {
65-
ESVectorUtil.centerAndCalculateOSQStatsEuclidean(vector, centroid, vector, statsScratch);
71+
ESVectorUtil.centerAndCalculateOSQStatsEuclidean(vector, centroid, residualDestination, statsScratch);
6672
} else {
67-
ESVectorUtil.centerAndCalculateOSQStatsDp(vector, centroid, vector, statsScratch);
73+
ESVectorUtil.centerAndCalculateOSQStatsDp(vector, centroid, residualDestination, statsScratch);
6874
}
6975
float vecMean = statsScratch[0];
7076
float vecVar = statsScratch[1];
@@ -78,14 +84,14 @@ public QuantizationResult[] multiScalarQuantize(float[] vector, int[][] destinat
7884
int points = (1 << bits[i]);
7985
// Linearly scale the interval to the standard deviation of the vector, ensuring we are within the min/max bounds
8086
initInterval(bits[i], vecStd, vecMean, min, max, intervalScratch);
81-
boolean hasQuantization = optimizeIntervals(intervalScratch, destinations[i], vector, norm2, points);
87+
boolean hasQuantization = optimizeIntervals(intervalScratch, destinations[i], residualDestination, norm2, points);
8288
// Now we have the optimized intervals, quantize the vector
8389
int sumQuery;
8490
if (hasQuantization) {
8591
sumQuery = getSumQuery(destinations[i]);
8692
} else {
8793
sumQuery = ESVectorUtil.quantizeVectorWithIntervals(
88-
vector,
94+
residualDestination,
8995
destinations[i],
9096
intervalScratch[0],
9197
intervalScratch[1],
@@ -102,16 +108,16 @@ public QuantizationResult[] multiScalarQuantize(float[] vector, int[][] destinat
102108
return results;
103109
}
104110

105-
public QuantizationResult scalarQuantize(float[] vector, int[] destination, byte bits, float[] centroid) {
111+
public QuantizationResult scalarQuantize(float[] vector, float[] residualDestination, int[] destination, byte bits, float[] centroid) {
106112
assert similarityFunction != COSINE || VectorUtil.isUnitVector(vector);
107113
assert similarityFunction != COSINE || VectorUtil.isUnitVector(centroid);
108114
assert vector.length <= destination.length;
109115
assert bits > 0 && bits <= 8;
110116
int points = 1 << bits;
111117
if (similarityFunction == EUCLIDEAN) {
112-
ESVectorUtil.centerAndCalculateOSQStatsEuclidean(vector, centroid, vector, statsScratch);
118+
ESVectorUtil.centerAndCalculateOSQStatsEuclidean(vector, centroid, residualDestination, statsScratch);
113119
} else {
114-
ESVectorUtil.centerAndCalculateOSQStatsDp(vector, centroid, vector, statsScratch);
120+
ESVectorUtil.centerAndCalculateOSQStatsDp(vector, centroid, residualDestination, statsScratch);
115121
}
116122
float vecMean = statsScratch[0];
117123
float vecVar = statsScratch[1];
@@ -121,13 +127,19 @@ public QuantizationResult scalarQuantize(float[] vector, int[] destination, byte
121127
float vecStd = (float) Math.sqrt(vecVar);
122128
// Linearly scale the interval to the standard deviation of the vector, ensuring we are within the min/max bounds
123129
initInterval(bits, vecStd, vecMean, min, max, intervalScratch);
124-
boolean hasQuantization = optimizeIntervals(intervalScratch, destination, vector, norm2, points);
130+
boolean hasQuantization = optimizeIntervals(intervalScratch, destination, residualDestination, norm2, points);
125131
// Now we have the optimized intervals, quantize the vector
126132
int sumQuery;
127133
if (hasQuantization) {
128134
sumQuery = getSumQuery(destination);
129135
} else {
130-
sumQuery = ESVectorUtil.quantizeVectorWithIntervals(vector, destination, intervalScratch[0], intervalScratch[1], bits);
136+
sumQuery = ESVectorUtil.quantizeVectorWithIntervals(
137+
residualDestination,
138+
destination,
139+
intervalScratch[0],
140+
intervalScratch[1],
141+
bits
142+
);
131143
}
132144
return new QuantizationResult(
133145
intervalScratch[0],

server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsReader.java

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import org.apache.lucene.index.VectorSimilarityFunction;
1616
import org.apache.lucene.search.KnnCollector;
1717
import org.apache.lucene.store.IndexInput;
18-
import org.apache.lucene.util.ArrayUtil;
1918
import org.apache.lucene.util.Bits;
2019
import org.apache.lucene.util.VectorUtil;
2120
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
@@ -94,12 +93,9 @@ CentroidIterator getCentroidIterator(
9493
final float globalCentroidDp = fieldEntry.globalCentroidDp();
9594
final OptimizedScalarQuantizer scalarQuantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
9695
final int[] scratch = new int[targetQuery.length];
97-
float[] targetQueryCopy = ArrayUtil.copyArray(targetQuery);
98-
if (fieldInfo.getVectorSimilarityFunction() == COSINE) {
99-
VectorUtil.l2normalize(targetQueryCopy);
100-
}
10196
final OptimizedScalarQuantizer.QuantizationResult queryParams = scalarQuantizer.scalarQuantize(
102-
targetQueryCopy,
97+
targetQuery,
98+
new float[targetQuery.length],
10399
scratch,
104100
(byte) 7,
105101
fieldEntry.globalCentroid()
@@ -584,11 +580,8 @@ public int visit(KnnCollector knnCollector) throws IOException {
584580

585581
private void quantizeQueryIfNecessary() {
586582
if (quantized == false) {
587-
System.arraycopy(target, 0, scratch, 0, target.length);
588-
if (fieldInfo.getVectorSimilarityFunction() == COSINE) {
589-
VectorUtil.l2normalize(scratch);
590-
}
591-
queryCorrections = quantizer.scalarQuantize(scratch, quantizationScratch, (byte) 4, centroid);
583+
assert fieldInfo.getVectorSimilarityFunction() != COSINE || VectorUtil.isUnitVector(target);
584+
queryCorrections = quantizer.scalarQuantize(target, scratch, quantizationScratch, (byte) 4, centroid);
592585
transposeHalfByte(quantizationScratch, quantizedQueryScratch);
593586
quantized = true;
594587
}

server/src/main/java/org/elasticsearch/index/codec/vectors/diskbbq/ES920DiskBBQVectorsWriter.java

Lines changed: 11 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -182,25 +182,25 @@ CentroidOffsetAndLength buildAndWritePostingsLists(
182182
OptimizedScalarQuantizer quantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
183183
int[] quantized = new int[fieldInfo.getVectorDimension()];
184184
byte[] binary = new byte[BQVectorUtils.discretize(fieldInfo.getVectorDimension(), 64) / 8];
185-
float[] overspillScratch = new float[fieldInfo.getVectorDimension()];
185+
float[] scratch = new float[fieldInfo.getVectorDimension()];
186186
for (int i = 0; i < assignments.length; i++) {
187187
int c = assignments[i];
188188
float[] centroid = centroidSupplier.centroid(c);
189189
float[] vector = floatVectorValues.vectorValue(i);
190190
boolean overspill = overspillAssignments.length > i && overspillAssignments[i] != -1;
191-
// if overspilling, this means we quantize twice, and quantization mutates the in-memory representation of the vector
192-
// so, make a copy of the vector to avoid mutating it
193-
if (overspill) {
194-
System.arraycopy(vector, 0, overspillScratch, 0, fieldInfo.getVectorDimension());
195-
}
196-
197-
OptimizedScalarQuantizer.QuantizationResult result = quantizer.scalarQuantize(vector, quantized, (byte) 1, centroid);
191+
OptimizedScalarQuantizer.QuantizationResult result = quantizer.scalarQuantize(
192+
vector,
193+
scratch,
194+
quantized,
195+
(byte) 1,
196+
centroid
197+
);
198198
BQVectorUtils.packAsBinary(quantized, binary);
199199
writeQuantizedValue(quantizedVectorsTemp, binary, result);
200200
if (overspill) {
201201
int s = overspillAssignments[i];
202202
// write the overspill vector as well
203-
result = quantizer.scalarQuantize(overspillScratch, quantized, (byte) 1, centroidSupplier.centroid(s));
203+
result = quantizer.scalarQuantize(vector, scratch, quantized, (byte) 1, centroidSupplier.centroid(s));
204204
BQVectorUtils.packAsBinary(quantized, binary);
205205
writeQuantizedValue(quantizedVectorsTemp, binary, result);
206206
} else {
@@ -629,10 +629,7 @@ public byte[] next() throws IOException {
629629
}
630630
currOrd++;
631631
float[] vector = supplier.centroid(ordTransformer.apply(currOrd));
632-
// Its possible that the vectors are on-heap and we cannot mutate them as we may quantize twice
633-
// due to overspill, so we copy the vector to a scratch array
634-
System.arraycopy(vector, 0, floatVectorScratch, 0, vector.length);
635-
corrections = quantizer.scalarQuantize(floatVectorScratch, quantizedVectorScratch, (byte) 7, centroid);
632+
corrections = quantizer.scalarQuantize(vector, floatVectorScratch, quantizedVectorScratch, (byte) 7, centroid);
636633
for (int i = 0; i < quantizedVectorScratch.length; i++) {
637634
quantizedVector[i] = (byte) quantizedVectorScratch[i];
638635
}
@@ -686,10 +683,7 @@ public byte[] next() throws IOException {
686683
currOrd++;
687684
int ord = ordTransformer.apply(currOrd);
688685
float[] vector = vectorValues.vectorValue(ord);
689-
// Its possible that the vectors are on-heap and we cannot mutate them as we may quantize twice
690-
// due to overspill, so we copy the vector to a scratch array
691-
System.arraycopy(vector, 0, floatVectorScratch, 0, vector.length);
692-
corrections = quantizer.scalarQuantize(floatVectorScratch, quantizedVectorScratch, (byte) 1, currentCentroid);
686+
corrections = quantizer.scalarQuantize(vector, floatVectorScratch, quantizedVectorScratch, (byte) 1, currentCentroid);
693687
BQVectorUtils.packAsBinary(quantizedVectorScratch, quantizedVector);
694688
return quantizedVector;
695689
}

0 commit comments

Comments
 (0)