Skip to content

Commit d571325

Browse files
committed
Small optimization in OptimizedScalarQuantizer by using mul instead of div
1 parent 8d04055 commit d571325

File tree

2 files changed, with 12 additions and 10 deletions.

libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -80,10 +80,11 @@ public void calculateOSQGridPoints(float[] target, int[] quantize, int points, f
80 80
float dbb = 0;
81 81
float dax = 0;
82 82
float dbx = 0;
83+
float invPmOnes = 1f / (points - 1f);
83 84
for (int i = 0; i < target.length; ++i) {
84 85
float v = target[i];
85 86
float k = quantize[i];
86-
float s = k / (points - 1);
87+
float s = k * invPmOnes;
87 88
float ms = 1f - s;
88 89
daa = fma(ms, ms, daa);
89 90
dab = fma(ms, s, dab);

libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java

Lines changed: 10 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -132,7 +132,7 @@ public void centerAndCalculateOSQStatsEuclidean(float[] vector, float[] centroid
132 132
FloatVector centeredVec = v.sub(c);
133 133
FloatVector deltaVec = centeredVec.sub(vecMeanVec);
134 134
norm2Vec = fma(centeredVec, centeredVec, norm2Vec);
135-
vecMeanVec = vecMeanVec.add(deltaVec.div(count));
135+
vecMeanVec = vecMeanVec.add(deltaVec.mul(1f / count));
136 136
FloatVector delta2Vec = centeredVec.sub(vecMeanVec);
137 137
m2Vec = fma(deltaVec, delta2Vec, m2Vec);
138 138
minVec = minVec.min(centeredVec);
@@ -214,7 +214,7 @@ public void centerAndCalculateOSQStatsDp(float[] vector, float[] centroid, float
214 214
FloatVector centeredVec = v.sub(c);
215 215
FloatVector deltaVec = centeredVec.sub(vecMeanVec);
216 216
norm2Vec = fma(centeredVec, centeredVec, norm2Vec);
217-
vecMeanVec = vecMeanVec.add(deltaVec.div(count));
217+
vecMeanVec = vecMeanVec.add(deltaVec.mul(1f / count));
218 218
FloatVector delta2Vec = centeredVec.sub(vecMeanVec);
219 219
m2Vec = fma(deltaVec, delta2Vec, m2Vec);
220 220
minVec = minVec.min(centeredVec);
@@ -278,6 +278,7 @@ public void calculateOSQGridPoints(float[] target, int[] quantize, int points, f
278 278
float dbb = 0;
279 279
float dax = 0;
280 280
float dbx = 0;
281+
float invPmOnes = 1f / (points - 1f);
281 282
// if the array size is large (> 2x platform vector size), it's worth the overhead to vectorize
282 283
if (target.length > 2 * FLOAT_SPECIES.length()) {
283 284
FloatVector daaVec = FloatVector.zero(FLOAT_SPECIES);
@@ -286,11 +287,11 @@ public void calculateOSQGridPoints(float[] target, int[] quantize, int points, f
286 287
FloatVector daxVec = FloatVector.zero(FLOAT_SPECIES);
287 288
FloatVector dbxVec = FloatVector.zero(FLOAT_SPECIES);
288 289
FloatVector ones = FloatVector.broadcast(FLOAT_SPECIES, 1f);
289-
FloatVector pmOnes = FloatVector.broadcast(FLOAT_SPECIES, points - 1f);
290+
FloatVector invPmOnesVec = FloatVector.broadcast(FLOAT_SPECIES, invPmOnes);
290 291
for (; i < FLOAT_SPECIES.loopBound(target.length); i += FLOAT_SPECIES.length()) {
291 292
FloatVector v = FloatVector.fromArray(FLOAT_SPECIES, target, i);
292 293
FloatVector oVec = IntVector.fromArray(INTEGER_SPECIES, quantize, i).convert(VectorOperators.I2F, 0).reinterpretAsFloats();
293-
FloatVector sVec = oVec.div(pmOnes);
294+
FloatVector sVec = oVec.mul(invPmOnesVec);
294 295
FloatVector smVec = ones.sub(sVec);
295 296
daaVec = fma(smVec, smVec, daaVec);
296 297
dabVec = fma(smVec, sVec, dabVec);
@@ -307,7 +308,7 @@ public void calculateOSQGridPoints(float[] target, int[] quantize, int points, f
307 308

308 309
for (; i < target.length; i++) {
309 310
float k = quantize[i];
310-
float s = k / (points - 1);
311+
float s = k * invPmOnes;
311 312
float ms = 1f - s;
312 313
daa = fma(ms, ms, daa);
313 314
dab = fma(ms, s, dab);
@@ -798,25 +799,25 @@ public static float ipFloatByteImpl(float[] q, byte[] d) {
798 799
@Override
799 800
public int quantizeVectorWithIntervals(float[] vector, int[] destination, float lowInterval, float upperInterval, byte bits) {
800 801
float nSteps = ((1 << bits) - 1);
801-
float step = (upperInterval - lowInterval) / nSteps;
802+
float invStep = nSteps / (upperInterval - lowInterval);
802 803
int sumQuery = 0;
803 804
int i = 0;
804 805
if (vector.length > 2 * FLOAT_SPECIES.length()) {
805 806
int limit = FLOAT_SPECIES.loopBound(vector.length);
806 807
FloatVector lowVec = FloatVector.broadcast(FLOAT_SPECIES, lowInterval);
807 808
FloatVector upperVec = FloatVector.broadcast(FLOAT_SPECIES, upperInterval);
808-
FloatVector stepVec = FloatVector.broadcast(FLOAT_SPECIES, step);
809+
FloatVector invStepVec = FloatVector.broadcast(FLOAT_SPECIES, invStep);
809 810
for (; i < limit; i += FLOAT_SPECIES.length()) {
810 811
FloatVector v = FloatVector.fromArray(FLOAT_SPECIES, vector, i);
811 812
FloatVector xi = v.max(lowVec).min(upperVec); // clamp
812-
IntVector assignment = xi.sub(lowVec).div(stepVec).add(0.5f).convert(VectorOperators.F2I, 0).reinterpretAsInts(); // round
813+
IntVector assignment = xi.sub(lowVec).mul(invStepVec).add(0.5f).convert(VectorOperators.F2I, 0).reinterpretAsInts(); // round
813 814
sumQuery += assignment.reduceLanes(ADD);
814 815
assignment.intoArray(destination, i);
815 816
}
816 817
}
817 818
for (; i < vector.length; i++) {
818 819
float xi = Math.min(Math.max(vector[i], lowInterval), upperInterval);
819-
int assignment = Math.round((xi - lowInterval) / step);
820+
int assignment = Math.round((xi - lowInterval) * invStep);
820 821
sumQuery += assignment;
821 822
destination[i] = assignment;
822 823
}

0 commit comments

Comments
 (0)