Skip to content

Commit 42ebc0a

Browse files
authored
Merge branch 'main' into SEARCH-1080-yaml-test-failure-rrf-with-pinned-retriever-as-a-sub-retriever
2 parents e8cc803 + f81d355 commit 42ebc0a

File tree

17 files changed

+227
-51
lines changed

17 files changed

+227
-51
lines changed

benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OptimizedScalarQuantizerBenchmark.java

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ public class OptimizedScalarQuantizerBenchmark {
4343

4444
float[] vector;
4545
float[] centroid;
46-
byte[] destination;
46+
byte[] legacyDestination;
47+
int[] destination;
4748

4849
@Param({ "1", "4", "7" })
4950
byte bits;
@@ -54,7 +55,8 @@ public class OptimizedScalarQuantizerBenchmark {
5455
public void init() {
5556
ThreadLocalRandom random = ThreadLocalRandom.current();
5657
// random byte arrays for binary methods
57-
destination = new byte[dims];
58+
legacyDestination = new byte[dims];
59+
destination = new int[dims];
5860
vector = new float[dims];
5961
centroid = new float[dims];
6062
for (int i = 0; i < dims; ++i) {
@@ -65,13 +67,20 @@ public void init() {
6567

6668
@Benchmark
6769
public byte[] scalar() {
68-
osq.scalarQuantize(vector, destination, bits, centroid);
69-
return destination;
70+
osq.legacyScalarQuantize(vector, legacyDestination, bits, centroid);
71+
return legacyDestination;
72+
}
73+
74+
@Benchmark
75+
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
76+
public byte[] legacyVector() {
77+
osq.legacyScalarQuantize(vector, legacyDestination, bits, centroid);
78+
return legacyDestination;
7079
}
7180

7281
@Benchmark
7382
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
74-
public byte[] vector() {
83+
public int[] vector() {
7584
osq.scalarQuantize(vector, destination, bits, centroid);
7685
return destination;
7786
}

libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,4 +258,25 @@ public static float soarDistance(float[] v1, float[] centroid, float[] originalR
258258
}
259259
return IMPL.soarDistance(v1, centroid, originalResidual, soarLambda, rnorm);
260260
}
261+
262+
/**
263+
* Optimized-scalar quantization of the provided vector to the provided destination array.
264+
*
265+
* @param vector the vector to quantize
266+
* @param destination the array to store the result
267+
* @param lowInterval the minimum value, lower values in the original array will be replaced by this value
268+
* @param upperInterval the maximum value, bigger values in the original array will be replaced by this value
269+
* @param bit the number of bits to use for quantization, must be between 1 and 8
270+
*
271+
* @return return the sum of all the elements of the resulting quantized vector.
272+
*/
273+
public static int quantizeVectorWithIntervals(float[] vector, int[] destination, float lowInterval, float upperInterval, byte bit) {
274+
if (vector.length > destination.length) {
275+
throw new IllegalArgumentException("vector dimensions differ: " + vector.length + "!=" + destination.length);
276+
}
277+
if (bit <= 0 || bit > Byte.SIZE) {
278+
throw new IllegalArgumentException("bit must be between 1 and 8, but was: " + bit);
279+
}
280+
return IMPL.quantizeVectorWithIntervals(vector, destination, lowInterval, upperInterval, bit);
281+
}
261282
}

libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,4 +269,18 @@ public static float ipFloatByteImpl(float[] q, byte[] d) {
269269
}
270270
return ret;
271271
}
272+
273+
@Override
274+
public int quantizeVectorWithIntervals(float[] vector, int[] destination, float lowInterval, float upperInterval, byte bits) {
275+
float nSteps = ((1 << bits) - 1);
276+
float step = (upperInterval - lowInterval) / nSteps;
277+
int sumQuery = 0;
278+
for (int h = 0; h < vector.length; h++) {
279+
float xi = Math.min(Math.max(vector[h], lowInterval), upperInterval);
280+
int assignment = Math.round((xi - lowInterval) / step);
281+
sumQuery += assignment;
282+
destination[h] = assignment;
283+
}
284+
return sumQuery;
285+
}
272286
}

libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,6 @@ public interface ESVectorUtilSupport {
3939

4040
float soarDistance(float[] v1, float[] centroid, float[] originalResidual, float soarLambda, float rnorm);
4141

42+
int quantizeVectorWithIntervals(float[] vector, int[] quantize, float lowInterval, float upperInterval, byte bit);
43+
4244
}

libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -791,4 +791,32 @@ public static float ipFloatByteImpl(float[] q, byte[] d) {
791791

792792
return sum;
793793
}
794+
795+
@Override
796+
public int quantizeVectorWithIntervals(float[] vector, int[] destination, float lowInterval, float upperInterval, byte bits) {
797+
float nSteps = ((1 << bits) - 1);
798+
float step = (upperInterval - lowInterval) / nSteps;
799+
int sumQuery = 0;
800+
int i = 0;
801+
if (vector.length > 2 * FLOAT_SPECIES.length()) {
802+
int limit = FLOAT_SPECIES.loopBound(vector.length);
803+
FloatVector lowVec = FloatVector.broadcast(FLOAT_SPECIES, lowInterval);
804+
FloatVector upperVec = FloatVector.broadcast(FLOAT_SPECIES, upperInterval);
805+
FloatVector stepVec = FloatVector.broadcast(FLOAT_SPECIES, step);
806+
for (; i < limit; i += FLOAT_SPECIES.length()) {
807+
FloatVector v = FloatVector.fromArray(FLOAT_SPECIES, vector, i);
808+
FloatVector xi = v.max(lowVec).min(upperVec); // clamp
809+
IntVector assignment = xi.sub(lowVec).div(stepVec).add(0.5f).convert(VectorOperators.F2I, 0).reinterpretAsInts(); // round
810+
sumQuery += assignment.reduceLanes(ADD);
811+
assignment.intoArray(destination, i);
812+
}
813+
}
814+
for (; i < vector.length; i++) {
815+
float xi = Math.min(Math.max(vector[i], lowInterval), upperInterval);
816+
int assignment = Math.round((xi - lowInterval) / step);
817+
sumQuery += assignment;
818+
destination[i] = assignment;
819+
}
820+
return sumQuery;
821+
}
794822
}

libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,29 @@ public void testSoarDistance() {
286286
assertEquals(expected, result, deltaEps);
287287
}
288288

289+
public void testQuantizeVectorWithIntervals() {
290+
int vectorSize = randomIntBetween(1, 2048);
291+
float[] vector = new float[vectorSize];
292+
293+
byte bits = (byte) randomIntBetween(1, 8);
294+
for (int i = 0; i < vectorSize; ++i) {
295+
vector[i] = random().nextFloat();
296+
}
297+
float low = random().nextFloat();
298+
float high = random().nextFloat();
299+
if (low > high) {
300+
float tmp = low;
301+
low = high;
302+
high = tmp;
303+
}
304+
int[] quantizeExpected = new int[vectorSize];
305+
int[] quantizeResult = new int[vectorSize];
306+
var expected = defaultedProvider.getVectorUtilSupport().quantizeVectorWithIntervals(vector, quantizeExpected, low, high, bits);
307+
var result = defOrPanamaProvider.getVectorUtilSupport().quantizeVectorWithIntervals(vector, quantizeResult, low, high, bits);
308+
assertArrayEquals(quantizeExpected, quantizeResult);
309+
assertEquals(expected, result, 0f);
310+
}
311+
289312
void testIpByteBinImpl(ToLongBiFunction<byte[], byte[]> ipByteBinFunc) {
290313
int iterations = atLeast(50);
291314
for (int i = 0; i < iterations; i++) {

server/src/main/java/org/elasticsearch/index/codec/vectors/BQSpaceUtils.java

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,4 +57,33 @@ public static void transposeHalfByte(byte[] q, byte[] quantQueryByte) {
5757
quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte;
5858
}
5959
}
60+
61+
/**
62+
* Same as {@link #transposeHalfByte(byte[], byte[])} but the input vector is provided as
63+
* an array of integers.
64+
*
65+
* @param q the query vector, assumed to be half-byte quantized with values between 0 and 15
66+
* @param quantQueryByte the byte array to store the transposed query vector
67+
* */
68+
public static void transposeHalfByte(int[] q, byte[] quantQueryByte) {
69+
for (int i = 0; i < q.length;) {
70+
assert q[i] >= 0 && q[i] <= 15;
71+
int lowerByte = 0;
72+
int lowerMiddleByte = 0;
73+
int upperMiddleByte = 0;
74+
int upperByte = 0;
75+
for (int j = 7; j >= 0 && i < q.length; j--) {
76+
lowerByte |= (q[i] & 1) << j;
77+
lowerMiddleByte |= ((q[i] >> 1) & 1) << j;
78+
upperMiddleByte |= ((q[i] >> 2) & 1) << j;
79+
upperByte |= ((q[i] >> 3) & 1) << j;
80+
i++;
81+
}
82+
int index = ((i + 7) / 8) - 1;
83+
quantQueryByte[index] = (byte) lowerByte;
84+
quantQueryByte[index + quantQueryByte.length / 4] = (byte) lowerMiddleByte;
85+
quantQueryByte[index + quantQueryByte.length / 2] = (byte) upperMiddleByte;
86+
quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte;
87+
}
88+
}
6089
}

server/src/main/java/org/elasticsearch/index/codec/vectors/BQVectorUtils.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ public static boolean isUnitVector(float[] v) {
4040
return Math.abs(l1norm - 1.0d) <= EPSILON;
4141
}
4242

43-
public static void packAsBinary(byte[] vector, byte[] packed) {
43+
public static void packAsBinary(int[] vector, byte[] packed) {
4444
for (int i = 0; i < vector.length;) {
4545
byte result = 0;
4646
for (int j = 7; j >= 0 && i < vector.length; j--) {

server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -52,13 +52,17 @@ CentroidQueryScorer getCentroidScorer(FieldInfo fieldInfo, int numCentroids, Ind
5252
final FieldEntry fieldEntry = fields.get(fieldInfo.number);
5353
final float globalCentroidDp = fieldEntry.globalCentroidDp();
5454
final OptimizedScalarQuantizer scalarQuantizer = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
55-
final byte[] quantized = new byte[targetQuery.length];
55+
final int[] scratch = new int[targetQuery.length];
5656
final OptimizedScalarQuantizer.QuantizationResult queryParams = scalarQuantizer.scalarQuantize(
5757
ArrayUtil.copyArray(targetQuery),
58-
quantized,
58+
scratch,
5959
(byte) 4,
6060
fieldEntry.globalCentroid()
6161
);
62+
final byte[] quantized = new byte[targetQuery.length];
63+
for (int i = 0; i < quantized.length; i++) {
64+
quantized[i] = (byte) scratch[i];
65+
}
6266
final ES91Int4VectorsScorer scorer = ESVectorUtil.getES91Int4VectorsScorer(centroids, fieldInfo.getVectorDimension());
6367
return new CentroidQueryScorer() {
6468
int currentCentroid = -1;
@@ -182,7 +186,7 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor {
182186
DocIdsWriter docIdsWriter = new DocIdsWriter();
183187

184188
final float[] scratch;
185-
final byte[] quantizationScratch;
189+
final int[] quantizationScratch;
186190
final byte[] quantizedQueryScratch;
187191
final OptimizedScalarQuantizer quantizer;
188192
final float[] correctiveValues = new float[3];
@@ -202,7 +206,7 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor {
202206
this.needsScoring = needsScoring;
203207

204208
scratch = new float[target.length];
205-
quantizationScratch = new byte[target.length];
209+
quantizationScratch = new int[target.length];
206210
final int discretizedDimensions = discretize(fieldInfo.getVectorDimension(), 64);
207211
quantizedQueryScratch = new byte[QUERY_BITS * discretizedDimensions / 8];
208212
quantizedByteLength = discretizedDimensions / 8 + (Float.BYTES * 3) + Short.BYTES;

server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,9 @@ CentroidSupplier createCentroidSupplier(IndexInput centroidsInput, int numCentro
122122
static void writeCentroids(float[][] centroids, FieldInfo fieldInfo, float[] globalCentroid, IndexOutput centroidOutput)
123123
throws IOException {
124124
final OptimizedScalarQuantizer osq = new OptimizedScalarQuantizer(fieldInfo.getVectorSimilarityFunction());
125-
byte[] quantizedScratch = new byte[fieldInfo.getVectorDimension()];
125+
int[] quantizedScratch = new int[fieldInfo.getVectorDimension()];
126126
float[] centroidScratch = new float[fieldInfo.getVectorDimension()];
127+
final byte[] quantized = new byte[fieldInfo.getVectorDimension()];
127128
// TODO do we want to store these distances as well for future use?
128129
// TODO: sort centroids by global centroid (was doing so previously here)
129130
// TODO: sorting tanks recall possibly because centroids ordinals no longer are aligned
@@ -135,7 +136,10 @@ static void writeCentroids(float[][] centroids, FieldInfo fieldInfo, float[] glo
135136
(byte) 4,
136137
globalCentroid
137138
);
138-
writeQuantizedValue(centroidOutput, quantizedScratch, result);
139+
for (int i = 0; i < quantizedScratch.length; i++) {
140+
quantized[i] = (byte) quantizedScratch[i];
141+
}
142+
writeQuantizedValue(centroidOutput, quantized, result);
139143
}
140144
final ByteBuffer buffer = ByteBuffer.allocate(fieldInfo.getVectorDimension() * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN);
141145
for (float[] centroid : centroids) {

0 commit comments

Comments
 (0)