Skip to content

Commit 5420431

Browse files
committed
optimize OptimizedScalarQuantizer#scalarQuantize when destination can be an integer array
1 parent ffea6ca commit 5420431

File tree

12 files changed

+187
-6
lines changed

12 files changed

+187
-6
lines changed

benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OptimizedScalarQuantizerBenchmark.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ public class OptimizedScalarQuantizerBenchmark {
4444
float[] vector;
4545
float[] centroid;
4646
byte[] destination;
47+
int[] intDestination;
4748

4849
@Param({ "1", "4", "7" })
4950
byte bits;
@@ -55,6 +56,7 @@ public void init() {
5556
ThreadLocalRandom random = ThreadLocalRandom.current();
5657
// random byte arrays for binary methods
5758
destination = new byte[dims];
59+
intDestination = new int[dims];
5860
vector = new float[dims];
5961
centroid = new float[dims];
6062
for (int i = 0; i < dims; ++i) {
@@ -75,4 +77,11 @@ public byte[] vector() {
7577
osq.scalarQuantize(vector, destination, bits, centroid);
7678
return destination;
7779
}
80+
81+
@Benchmark
82+
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
83+
public byte[] vectorToInt() {
84+
osq.scalarQuantizeToInts(vector, intDestination, bits, centroid);
85+
return destination;
86+
}
7887
}

libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,4 +258,25 @@ public static float soarDistance(float[] v1, float[] centroid, float[] originalR
258258
}
259259
return IMPL.soarDistance(v1, centroid, originalResidual, soarLambda, rnorm);
260260
}
261+
262+
/**
263+
* Optimized-scalar quantization of the provided vector to the provided destination array.
264+
*
265+
* @param vector the vector to quantize
266+
* @param destination the array to store the result
267+
* @param lowInterval the minimum value, lower values in the original array will be replaced by this value
268+
* @param upperInterval the maximum value, bigger values in the original array will be replaced by this value
269+
* @param bit the number of bits to use for quantization, must be between 1 and 8
270+
*
271+
* @return return the sum of all the elements of the resulting quantized vector.
272+
*/
273+
public static int quantizeVectorWithIntervals(float[] vector, int[] destination, float lowInterval, float upperInterval, byte bit) {
274+
if (vector.length != destination.length) {
275+
throw new IllegalArgumentException("vector dimensions differ: " + vector.length + "!=" + destination.length);
276+
}
277+
if (bit <= 0 || bit > Byte.SIZE) {
278+
throw new IllegalArgumentException("bit must be between 1 and 8, but was: " + bit);
279+
}
280+
return IMPL.quantizeVectorWithIntervals(vector, destination, lowInterval, upperInterval, bit);
281+
}
261282
}

libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,4 +269,18 @@ public static float ipFloatByteImpl(float[] q, byte[] d) {
269269
}
270270
return ret;
271271
}
272+
273+
@Override
274+
public int quantizeVectorWithIntervals(float[] vector, int[] destination, float lowInterval, float upperInterval, byte bits) {
275+
float nSteps = ((1 << bits) - 1);
276+
float step = (upperInterval - lowInterval) / nSteps;
277+
int sumQuery = 0;
278+
for (int h = 0; h < vector.length; h++) {
279+
float xi = Math.min(Math.max(vector[h], lowInterval), upperInterval);
280+
int assignment = Math.round((xi - lowInterval) / step);
281+
sumQuery += assignment;
282+
destination[h] = assignment;
283+
}
284+
return sumQuery;
285+
}
272286
}

libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,6 @@ public interface ESVectorUtilSupport {
3939

4040
float soarDistance(float[] v1, float[] centroid, float[] originalResidual, float soarLambda, float rnorm);
4141

42+
int quantizeVectorWithIntervals(float[] vector, int[] quantize, float lowInterval, float upperInterval, byte bit);
43+
4244
}

libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -791,4 +791,32 @@ public static float ipFloatByteImpl(float[] q, byte[] d) {
791791

792792
return sum;
793793
}
794+
795+
@Override
796+
public int quantizeVectorWithIntervals(float[] vector, int[] destination, float lowInterval, float upperInterval, byte bits) {
797+
float nSteps = ((1 << bits) - 1);
798+
float step = (upperInterval - lowInterval) / nSteps;
799+
int sumQuery = 0;
800+
int i = 0;
801+
if (vector.length > 2 * FLOAT_SPECIES.length()) {
802+
int limit = FLOAT_SPECIES.loopBound(vector.length);
803+
FloatVector lowVec = FloatVector.broadcast(FLOAT_SPECIES, lowInterval);
804+
FloatVector upperVec = FloatVector.broadcast(FLOAT_SPECIES, upperInterval);
805+
FloatVector stepVec = FloatVector.broadcast(FLOAT_SPECIES, step);
806+
for (; i < limit; i += FLOAT_SPECIES.length()) {
807+
FloatVector v = FloatVector.fromArray(FLOAT_SPECIES, vector, i);
808+
FloatVector xi = v.max(lowVec).min(upperVec); // clamp
809+
IntVector assignment = xi.sub(lowVec).div(stepVec).add(0.5f).convert(VectorOperators.F2I, 0).reinterpretAsInts(); // round
810+
sumQuery += assignment.reduceLanes(ADD);
811+
assignment.intoArray(destination, i);
812+
}
813+
}
814+
for (; i < vector.length; i++) {
815+
float xi = Math.min(Math.max(vector[i], lowInterval), upperInterval);
816+
int assignment = Math.round((xi - lowInterval) / step);
817+
sumQuery += assignment;
818+
destination[i] = assignment;
819+
}
820+
return sumQuery;
821+
}
794822
}

libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,29 @@ public void testSoarDistance() {
286286
assertEquals(expected, result, deltaEps);
287287
}
288288

289+
public void testQuantizeVectorWithIntervals() {
290+
int vectorSize = randomIntBetween(1, 2048);
291+
float[] vector = new float[vectorSize];
292+
293+
byte bits = (byte) randomIntBetween(1, 8);
294+
for (int i = 0; i < vectorSize; ++i) {
295+
vector[i] = random().nextFloat();
296+
}
297+
float low = random().nextFloat();
298+
float high = random().nextFloat();
299+
if (low > high) {
300+
float tmp = low;
301+
low = high;
302+
high = tmp;
303+
}
304+
byte[] quantizeExpected = new byte[vectorSize];
305+
byte[] quantizeResult = new byte[vectorSize];
306+
var expected = defaultedProvider.getVectorUtilSupport().quantizeVectorWithIntervals(vector, low, high, quantizeExpected, bits);
307+
var result = defOrPanamaProvider.getVectorUtilSupport().quantizeVectorWithIntervals(vector, low, high, quantizeResult, bits);
308+
assertArrayEquals(quantizeExpected, quantizeResult);
309+
assertEquals(expected, result, 0f);
310+
}
311+
289312
void testIpByteBinImpl(ToLongBiFunction<byte[], byte[]> ipByteBinFunc) {
290313
int iterations = atLeast(50);
291314
for (int i = 0; i < iterations; i++) {

server/src/main/java/org/elasticsearch/index/codec/vectors/BQSpaceUtils.java

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,4 +57,34 @@ public static void transposeHalfByte(byte[] q, byte[] quantQueryByte) {
5757
quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte;
5858
}
5959
}
60+
61+
/**
62+
* Same as {@link #transposeHalfByte(byte[], byte[])} but the input vector is provided as
63+
* an array of integers.This is useful when using {@link OptimizedScalarQuantizer#scalarQuantizeToInts(float[], int[], byte, float[])}
64+
* which performs better than the byte version.
65+
*
66+
* @param q the query vector, assumed to be half-byte quantized with values between 0 and 15
67+
* @param quantQueryByte the byte array to store the transposed query vector
68+
* */
69+
public static void transposeHalfByte(int[] q, byte[] quantQueryByte) {
70+
for (int i = 0; i < q.length;) {
71+
assert q[i] >= 0 && q[i] <= 15;
72+
int lowerByte = 0;
73+
int lowerMiddleByte = 0;
74+
int upperMiddleByte = 0;
75+
int upperByte = 0;
76+
for (int j = 7; j >= 0 && i < q.length; j--) {
77+
lowerByte |= (q[i] & 1) << j;
78+
lowerMiddleByte |= ((q[i] >> 1) & 1) << j;
79+
upperMiddleByte |= ((q[i] >> 2) & 1) << j;
80+
upperByte |= ((q[i] >> 3) & 1) << j;
81+
i++;
82+
}
83+
int index = ((i + 7) / 8) - 1;
84+
quantQueryByte[index] = (byte) lowerByte;
85+
quantQueryByte[index + quantQueryByte.length / 4] = (byte) lowerMiddleByte;
86+
quantQueryByte[index + quantQueryByte.length / 2] = (byte) upperMiddleByte;
87+
quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte;
88+
}
89+
}
6090
}

server/src/main/java/org/elasticsearch/index/codec/vectors/BQVectorUtils.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,20 @@ public static void packAsBinary(byte[] vector, byte[] packed) {
5454
}
5555
}
5656

57+
public static void packAsBinary(int[] vector, byte[] packed) {
58+
for (int i = 0; i < vector.length;) {
59+
byte result = 0;
60+
for (int j = 7; j >= 0 && i < vector.length; j--) {
61+
assert vector[i] == 0 || vector[i] == 1;
62+
result |= (byte) ((vector[i] & 1) << j);
63+
++i;
64+
}
65+
int index = ((i + 7) / 8) - 1;
66+
assert index < packed.length;
67+
packed[index] = result;
68+
}
69+
}
70+
5771
public static int discretize(int value, int bucket) {
5872
return ((value + (bucket - 1)) / bucket) * bucket;
5973
}

server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsReader.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor {
182182
DocIdsWriter docIdsWriter = new DocIdsWriter();
183183

184184
final float[] scratch;
185-
final byte[] quantizationScratch;
185+
final int[] quantizationScratch;
186186
final byte[] quantizedQueryScratch;
187187
final OptimizedScalarQuantizer quantizer;
188188
final float[] correctiveValues = new float[3];
@@ -202,7 +202,7 @@ private static class MemorySegmentPostingsVisitor implements PostingVisitor {
202202
this.needsScoring = needsScoring;
203203

204204
scratch = new float[target.length];
205-
quantizationScratch = new byte[target.length];
205+
quantizationScratch = new int[target.length];
206206
final int discretizedDimensions = discretize(fieldInfo.getVectorDimension(), 64);
207207
quantizedQueryScratch = new byte[QUERY_BITS * discretizedDimensions / 8];
208208
quantizedByteLength = discretizedDimensions / 8 + (Float.BYTES * 3) + Short.BYTES;
@@ -344,7 +344,7 @@ private void quantizeQueryIfNecessary() {
344344
if (fieldInfo.getVectorSimilarityFunction() == COSINE) {
345345
VectorUtil.l2normalize(scratch);
346346
}
347-
queryCorrections = quantizer.scalarQuantize(scratch, quantizationScratch, (byte) 4, centroid);
347+
queryCorrections = quantizer.scalarQuantizeToInts(scratch, quantizationScratch, (byte) 4, centroid);
348348
transposeHalfByte(quantizationScratch, quantizedQueryScratch);
349349
quantized = true;
350350
}

server/src/main/java/org/elasticsearch/index/codec/vectors/DefaultIVFVectorsWriter.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ CentroidAssignments calculateAndWriteCentroids(
291291
static class BinarizedFloatVectorValues {
292292
private OptimizedScalarQuantizer.QuantizationResult corrections;
293293
private final byte[] binarized;
294-
private final byte[] initQuantized;
294+
private final int[] initQuantized;
295295
private float[] centroid;
296296
private final FloatVectorValues values;
297297
private final OptimizedScalarQuantizer quantizer;
@@ -302,7 +302,7 @@ static class BinarizedFloatVectorValues {
302302
this.values = delegate;
303303
this.quantizer = quantizer;
304304
this.binarized = new byte[discretize(delegate.dimension(), 64) / 8];
305-
this.initQuantized = new byte[delegate.dimension()];
305+
this.initQuantized = new int[delegate.dimension()];
306306
}
307307

308308
public OptimizedScalarQuantizer.QuantizationResult getCorrectiveTerms(int ord) {
@@ -323,7 +323,7 @@ public byte[] vectorValue(int ord) throws IOException {
323323
}
324324

325325
private void binarize(int ord) throws IOException {
326-
corrections = quantizer.scalarQuantize(values.vectorValue(ord), initQuantized, INDEX_BITS, centroid);
326+
corrections = quantizer.scalarQuantizeToInts(values.vectorValue(ord), initQuantized, INDEX_BITS, centroid);
327327
packAsBinary(initQuantized, binarized);
328328
}
329329
}

0 commit comments

Comments
 (0)