Skip to content

Commit 4468239

Browse files
authored
Speed up OptimizedScalarQuantizer (elastic#131599)
Use the destination array to store the quantized values during the loss computation, and pass that array to the method that computes the grid points.
1 parent f393dba commit 4468239

File tree

9 files changed

+147
-139
lines changed

9 files changed

+147
-139
lines changed

benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OptimizedScalarQuantizerBenchmark.java

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@ public class OptimizedScalarQuantizerBenchmark {
4343

4444
float[] vector;
4545
float[] centroid;
46-
byte[] legacyDestination;
4746
int[] destination;
4847

4948
@Param({ "1", "4", "7" })
@@ -55,7 +54,6 @@ public class OptimizedScalarQuantizerBenchmark {
5554
public void init() {
5655
ThreadLocalRandom random = ThreadLocalRandom.current();
5756
// random byte arrays for binary methods
58-
legacyDestination = new byte[dims];
5957
destination = new int[dims];
6058
vector = new float[dims];
6159
centroid = new float[dims];
@@ -66,16 +64,9 @@ public void init() {
6664
}
6765

6866
@Benchmark
69-
public byte[] scalar() {
70-
osq.legacyScalarQuantize(vector, legacyDestination, bits, centroid);
71-
return legacyDestination;
72-
}
73-
74-
@Benchmark
75-
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
76-
public byte[] legacyVector() {
77-
osq.legacyScalarQuantize(vector, legacyDestination, bits, centroid);
78-
return legacyDestination;
67+
public int[] scalar() {
68+
osq.scalarQuantize(vector, destination, bits, centroid);
69+
return destination;
7970
}
8071

8172
@Benchmark

docs/changelog/131599.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 131599
2+
summary: Speed up `OptimizedScalarQuantizer`
3+
area: Vector Search
4+
type: enhancement
5+
issues: []

libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -158,31 +158,41 @@ static int andBitCountLong(byte[] a, byte[] b) {
158158
/**
159159
* Calculate the loss for optimized-scalar quantization for the given parameters
160160
* @param target The vector being quantized, assumed to be centered
161-
* @param interval The interval for which to calculate the loss
161+
* @param lowerInterval The lower interval value for which to calculate the loss
162+
* @param upperInterval The upper interval value for which to calculate the loss
162163
* @param points the quantization points
163164
* @param norm2 The norm squared of the target vector
164165
* @param lambda The lambda parameter for controlling anisotropic loss calculation
166+
* @param quantize array to store the computed quantized vector.
167+
*
165168
* @return The loss for the given parameters
166169
*/
167-
public static float calculateOSQLoss(float[] target, float[] interval, int points, float norm2, float lambda) {
168-
assert interval.length == 2;
169-
float step = ((interval[1] - interval[0]) / (points - 1.0F));
170+
public static float calculateOSQLoss(
171+
float[] target,
172+
float lowerInterval,
173+
float upperInterval,
174+
int points,
175+
float norm2,
176+
float lambda,
177+
int[] quantize
178+
) {
179+
assert upperInterval >= lowerInterval;
180+
float step = ((upperInterval - lowerInterval) / (points - 1.0F));
170181
float invStep = 1f / step;
171-
return IMPL.calculateOSQLoss(target, interval, step, invStep, norm2, lambda);
182+
return IMPL.calculateOSQLoss(target, lowerInterval, upperInterval, step, invStep, norm2, lambda, quantize);
172183
}
173184

174185
/**
175186
* Calculate the grid points for optimized-scalar quantization
176187
* @param target The vector being quantized, assumed to be centered
177-
* @param interval The interval for which to calculate the grid points
188+
* @param quantize The quantized vector, which must be at least as long as the target vector
178189
* @param points the quantization points
179190
* @param pts The array to store the grid points, must be of length 5
180191
*/
181-
public static void calculateOSQGridPoints(float[] target, float[] interval, int points, float[] pts) {
182-
assert interval.length == 2;
192+
public static void calculateOSQGridPoints(float[] target, int[] quantize, int points, float[] pts) {
193+
assert target.length <= quantize.length;
183194
assert pts.length == 5;
184-
float invStep = (points - 1.0F) / (interval[1] - interval[0]);
185-
IMPL.calculateOSQGridPoints(target, interval, points, invStep, pts);
195+
IMPL.calculateOSQGridPoints(target, quantize, points, pts);
186196
}
187197

188198
/**

libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,25 @@ public float ipFloatByte(float[] q, byte[] d) {
4646
}
4747

4848
@Override
49-
public float calculateOSQLoss(float[] target, float[] interval, float step, float invStep, float norm2, float lambda) {
50-
float a = interval[0];
51-
float b = interval[1];
49+
public float calculateOSQLoss(
50+
float[] target,
51+
float low,
52+
float high,
53+
float step,
54+
float invStep,
55+
float norm2,
56+
float lambda,
57+
int[] quantize
58+
) {
59+
float a = low;
60+
float b = high;
5261
float xe = 0f;
5362
float e = 0f;
54-
for (float xi : target) {
63+
for (int i = 0; i < target.length; ++i) {
64+
float xi = target[i];
5565
// this is quantizing and then dequantizing the vector
56-
float xiq = fma(step, Math.round((Math.min(Math.max(xi, a), b) - a) * invStep), a);
66+
quantize[i] = Math.round((Math.min(Math.max(xi, a), b) - a) * invStep);
67+
float xiq = fma(step, quantize[i], a);
5768
// how much does the de-quantized value differ from the original value
5869
float xiiq = xi - xiq;
5970
e = fma(xiiq, xiiq, e);
@@ -63,16 +74,15 @@ public float calculateOSQLoss(float[] target, float[] interval, float step, floa
6374
}
6475

6576
@Override
66-
public void calculateOSQGridPoints(float[] target, float[] interval, int points, float invStep, float[] pts) {
67-
float a = interval[0];
68-
float b = interval[1];
77+
public void calculateOSQGridPoints(float[] target, int[] quantize, int points, float[] pts) {
6978
float daa = 0;
7079
float dab = 0;
7180
float dbb = 0;
7281
float dax = 0;
7382
float dbx = 0;
74-
for (float v : target) {
75-
float k = Math.round((Math.min(Math.max(v, a), b) - a) * invStep);
83+
for (int i = 0; i < target.length; ++i) {
84+
float v = target[i];
85+
float k = quantize[i];
7686
float s = k / (points - 1);
7787
float ms = 1f - s;
7888
daa = fma(ms, ms, daa);
@@ -273,11 +283,11 @@ public static float ipFloatByteImpl(float[] q, byte[] d) {
273283
@Override
274284
public int quantizeVectorWithIntervals(float[] vector, int[] destination, float lowInterval, float upperInterval, byte bits) {
275285
float nSteps = ((1 << bits) - 1);
276-
float step = (upperInterval - lowInterval) / nSteps;
286+
float invStep = nSteps / (upperInterval - lowInterval);
277287
int sumQuery = 0;
278288
for (int h = 0; h < vector.length; h++) {
279289
float xi = Math.min(Math.max(vector[h], lowInterval), upperInterval);
280-
int assignment = Math.round((xi - lowInterval) / step);
290+
int assignment = Math.round((xi - lowInterval) * invStep);
281291
sumQuery += assignment;
282292
destination[h] = assignment;
283293
}

libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,18 @@ public interface ESVectorUtilSupport {
2929

3030
float ipFloatByte(float[] q, byte[] d);
3131

32-
float calculateOSQLoss(float[] target, float[] interval, float step, float invStep, float norm2, float lambda);
33-
34-
void calculateOSQGridPoints(float[] target, float[] interval, int points, float invStep, float[] pts);
32+
float calculateOSQLoss(
33+
float[] target,
34+
float lowerInterval,
35+
float upperInterval,
36+
float step,
37+
float invStep,
38+
float norm2,
39+
float lambda,
40+
int[] quantize
41+
);
42+
43+
void calculateOSQGridPoints(float[] target, int[] quantize, int points, float[] pts);
3544

3645
void centerAndCalculateOSQStatsEuclidean(float[] target, float[] centroid, float[] centered, float[] stats);
3746

libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import jdk.incubator.vector.FloatVector;
1414
import jdk.incubator.vector.IntVector;
1515
import jdk.incubator.vector.LongVector;
16-
import jdk.incubator.vector.Vector;
1716
import jdk.incubator.vector.VectorMask;
1817
import jdk.incubator.vector.VectorOperators;
1918
import jdk.incubator.vector.VectorShape;
@@ -31,13 +30,15 @@ public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport {
3130
static final int VECTOR_BITSIZE;
3231

3332
private static final VectorSpecies<Float> FLOAT_SPECIES;
33+
private static final VectorSpecies<Integer> INTEGER_SPECIES;
3434
/** Whether integer vectors can be trusted to actually be fast. */
3535
static final boolean HAS_FAST_INTEGER_VECTORS;
3636

3737
static {
3838
// default to platform supported bitsize
3939
VECTOR_BITSIZE = VectorShape.preferredShape().vectorBitSize();
4040
FLOAT_SPECIES = VectorSpecies.of(float.class, VectorShape.forBitSize(VECTOR_BITSIZE));
41+
INTEGER_SPECIES = VectorSpecies.of(int.class, VectorShape.forBitSize(VECTOR_BITSIZE));
4142

4243
// hotspot misses some SSE intrinsics, workaround it
4344
// to be fair, they do document this thing only works well with AVX2/AVX3 and Neon
@@ -270,36 +271,26 @@ public void centerAndCalculateOSQStatsDp(float[] vector, float[] centroid, float
270271
}
271272

272273
@Override
273-
public void calculateOSQGridPoints(float[] target, float[] interval, int points, float invStep, float[] pts) {
274-
float a = interval[0];
275-
float b = interval[1];
274+
public void calculateOSQGridPoints(float[] target, int[] quantize, int points, float[] pts) {
276275
int i = 0;
277276
float daa = 0;
278277
float dab = 0;
279278
float dbb = 0;
280279
float dax = 0;
281280
float dbx = 0;
282-
283-
FloatVector daaVec = FloatVector.zero(FLOAT_SPECIES);
284-
FloatVector dabVec = FloatVector.zero(FLOAT_SPECIES);
285-
FloatVector dbbVec = FloatVector.zero(FLOAT_SPECIES);
286-
FloatVector daxVec = FloatVector.zero(FLOAT_SPECIES);
287-
FloatVector dbxVec = FloatVector.zero(FLOAT_SPECIES);
288-
289281
// if the array size is large (> 2x platform vector size), it's worth the overhead to vectorize
290282
if (target.length > 2 * FLOAT_SPECIES.length()) {
283+
FloatVector daaVec = FloatVector.zero(FLOAT_SPECIES);
284+
FloatVector dabVec = FloatVector.zero(FLOAT_SPECIES);
285+
FloatVector dbbVec = FloatVector.zero(FLOAT_SPECIES);
286+
FloatVector daxVec = FloatVector.zero(FLOAT_SPECIES);
287+
FloatVector dbxVec = FloatVector.zero(FLOAT_SPECIES);
291288
FloatVector ones = FloatVector.broadcast(FLOAT_SPECIES, 1f);
292289
FloatVector pmOnes = FloatVector.broadcast(FLOAT_SPECIES, points - 1f);
293290
for (; i < FLOAT_SPECIES.loopBound(target.length); i += FLOAT_SPECIES.length()) {
294291
FloatVector v = FloatVector.fromArray(FLOAT_SPECIES, target, i);
295-
FloatVector vClamped = v.max(a).min(b);
296-
Vector<Integer> xiqint = vClamped.sub(a)
297-
.mul(invStep)
298-
// round
299-
.add(0.5f)
300-
.convert(VectorOperators.F2I, 0);
301-
FloatVector kVec = xiqint.convert(VectorOperators.I2F, 0).reinterpretAsFloats();
302-
FloatVector sVec = kVec.div(pmOnes);
292+
FloatVector oVec = IntVector.fromArray(INTEGER_SPECIES, quantize, i).convert(VectorOperators.I2F, 0).reinterpretAsFloats();
293+
FloatVector sVec = oVec.div(pmOnes);
303294
FloatVector smVec = ones.sub(sVec);
304295
daaVec = fma(smVec, smVec, daaVec);
305296
dabVec = fma(smVec, sVec, dabVec);
@@ -315,7 +306,7 @@ public void calculateOSQGridPoints(float[] target, float[] interval, int points,
315306
}
316307

317308
for (; i < target.length; i++) {
318-
float k = Math.round((Math.min(Math.max(target[i], a), b) - a) * invStep);
309+
float k = quantize[i];
319310
float s = k / (points - 1);
320311
float ms = 1f - s;
321312
daa = fma(ms, ms, daa);
@@ -333,9 +324,18 @@ public void calculateOSQGridPoints(float[] target, float[] interval, int points,
333324
}
334325

335326
@Override
336-
public float calculateOSQLoss(float[] target, float[] interval, float step, float invStep, float norm2, float lambda) {
337-
float a = interval[0];
338-
float b = interval[1];
327+
public float calculateOSQLoss(
328+
float[] target,
329+
float lowerInterval,
330+
float upperInterval,
331+
float step,
332+
float invStep,
333+
float norm2,
334+
float lambda,
335+
int[] quantize
336+
) {
337+
float a = lowerInterval;
338+
float b = upperInterval;
339339
float xe = 0f;
340340
float e = 0f;
341341
FloatVector xeVec = FloatVector.zero(FLOAT_SPECIES);
@@ -346,8 +346,10 @@ public float calculateOSQLoss(float[] target, float[] interval, float step, floa
346346
for (; i < FLOAT_SPECIES.loopBound(target.length); i += FLOAT_SPECIES.length()) {
347347
FloatVector v = FloatVector.fromArray(FLOAT_SPECIES, target, i);
348348
FloatVector vClamped = v.max(a).min(b);
349-
Vector<Integer> xiqint = vClamped.sub(a).mul(invStep).add(0.5f).convert(VectorOperators.F2I, 0);
350-
FloatVector xiq = xiqint.convert(VectorOperators.I2F, 0).reinterpretAsFloats().mul(step).add(a);
349+
IntVector xiqint = vClamped.sub(a).mul(invStep).add(0.5f).convert(VectorOperators.F2I, 0).reinterpretAsInts();
350+
xiqint.intoArray(quantize, i);
351+
FloatVector quantizeVec = xiqint.convert(VectorOperators.I2F, 0).reinterpretAsFloats();
352+
FloatVector xiq = quantizeVec.mul(step).add(a);
351353
FloatVector xiiq = v.sub(xiq);
352354
xeVec = fma(v, xiiq, xeVec);
353355
eVec = fma(xiiq, xiiq, eVec);
@@ -357,8 +359,9 @@ public float calculateOSQLoss(float[] target, float[] interval, float step, floa
357359
}
358360

359361
for (; i < target.length; i++) {
362+
quantize[i] = Math.round((Math.min(Math.max(target[i], a), b) - a) * invStep);
360363
// this is quantizing and then dequantizing the vector
361-
float xiq = fma(step, Math.round((Math.min(Math.max(target[i], a), b) - a) * invStep), a);
364+
float xiq = fma(step, quantize[i], a);
362365
// how much does the de-quantized value differ from the original value
363366
float xiiq = target[i] - xiq;
364367
e = fma(xiiq, xiiq, e);

libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -222,15 +222,20 @@ public void testOsqLoss() {
222222
vecVar /= size;
223223
float vecStd = (float) Math.sqrt(vecVar);
224224

225+
int[] destinationDefault = new int[size];
226+
int[] destinationPanama = new int[size];
225227
for (byte bits : new byte[] { 1, 2, 3, 4, 5, 6, 7, 8 }) {
226228
int points = 1 << bits;
227229
float[] initInterval = new float[2];
228230
OptimizedScalarQuantizer.initInterval(bits, vecStd, vecMean, min, max, initInterval);
229231
float step = ((initInterval[1] - initInterval[0]) / (points - 1f));
230232
float stepInv = 1f / step;
231-
float expected = defaultedProvider.getVectorUtilSupport().calculateOSQLoss(vector, initInterval, step, stepInv, norm2, 0.1f);
232-
float result = defOrPanamaProvider.getVectorUtilSupport().calculateOSQLoss(vector, initInterval, step, stepInv, norm2, 0.1f);
233+
float expected = defaultedProvider.getVectorUtilSupport()
234+
.calculateOSQLoss(vector, initInterval[0], initInterval[1], step, stepInv, norm2, 0.1f, destinationDefault);
235+
float result = defOrPanamaProvider.getVectorUtilSupport()
236+
.calculateOSQLoss(vector, initInterval[0], initInterval[1], step, stepInv, norm2, 0.1f, destinationPanama);
233237
assertEquals(expected, result, deltaEps);
238+
assertArrayEquals(destinationDefault, destinationPanama);
234239
}
235240
}
236241

@@ -240,6 +245,7 @@ public void testOsqGridPoints() {
240245
var vector = new float[size];
241246
var min = Float.MAX_VALUE;
242247
var max = -Float.MAX_VALUE;
248+
var norm2 = 0f;
243249
float vecMean = 0;
244250
float vecVar = 0;
245251
for (int i = 0; i < size; ++i) {
@@ -250,21 +256,29 @@ public void testOsqGridPoints() {
250256
vecMean += delta / (i + 1);
251257
float delta2 = vector[i] - vecMean;
252258
vecVar += delta * delta2;
259+
norm2 += vector[i] * vector[i];
253260
}
254261
vecVar /= size;
255262
float vecStd = (float) Math.sqrt(vecVar);
263+
int[] destinationDefault = new int[size];
264+
int[] destinationPanama = new int[size];
256265
for (byte bits : new byte[] { 1, 2, 3, 4, 5, 6, 7, 8 }) {
257266
int points = 1 << bits;
258267
float[] initInterval = new float[2];
259268
OptimizedScalarQuantizer.initInterval(bits, vecStd, vecMean, min, max, initInterval);
260269
float step = ((initInterval[1] - initInterval[0]) / (points - 1f));
261270
float stepInv = 1f / step;
262271
float[] expected = new float[5];
263-
defaultedProvider.getVectorUtilSupport().calculateOSQGridPoints(vector, initInterval, points, stepInv, expected);
272+
defaultedProvider.getVectorUtilSupport()
273+
.calculateOSQLoss(vector, initInterval[0], initInterval[1], step, stepInv, norm2, 0.1f, destinationDefault);
274+
defaultedProvider.getVectorUtilSupport().calculateOSQGridPoints(vector, destinationDefault, points, expected);
264275

265276
float[] result = new float[5];
266-
defOrPanamaProvider.getVectorUtilSupport().calculateOSQGridPoints(vector, initInterval, points, stepInv, result);
277+
defOrPanamaProvider.getVectorUtilSupport()
278+
.calculateOSQLoss(vector, initInterval[0], initInterval[1], step, stepInv, norm2, 0.1f, destinationPanama);
279+
defOrPanamaProvider.getVectorUtilSupport().calculateOSQGridPoints(vector, destinationPanama, points, result);
267280
assertArrayEquals(expected, result, deltaEps);
281+
assertArrayEquals(destinationDefault, destinationPanama);
268282
}
269283
}
270284

0 commit comments

Comments
 (0)