diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java index d4aedaf99a1f7..12abda2506252 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java @@ -80,10 +80,11 @@ public void calculateOSQGridPoints(float[] target, int[] quantize, int points, f float dbb = 0; float dax = 0; float dbx = 0; + float invPmOnes = 1f / (points - 1f); for (int i = 0; i < target.length; ++i) { float v = target[i]; float k = quantize[i]; - float s = k / (points - 1); + float s = k * invPmOnes; float ms = 1f - s; daa = fma(ms, ms, daa); dab = fma(ms, s, dab); diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java index 1196661ae010f..2a5f633d51b78 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java @@ -132,7 +132,7 @@ public void centerAndCalculateOSQStatsEuclidean(float[] vector, float[] centroid FloatVector centeredVec = v.sub(c); FloatVector deltaVec = centeredVec.sub(vecMeanVec); norm2Vec = fma(centeredVec, centeredVec, norm2Vec); - vecMeanVec = vecMeanVec.add(deltaVec.div(count)); + vecMeanVec = vecMeanVec.add(deltaVec.mul(1f / count)); FloatVector delta2Vec = centeredVec.sub(vecMeanVec); m2Vec = fma(deltaVec, delta2Vec, m2Vec); minVec = minVec.min(centeredVec); @@ -214,7 +214,7 @@ public void centerAndCalculateOSQStatsDp(float[] vector, float[] centroid, float FloatVector centeredVec = v.sub(c); FloatVector deltaVec = centeredVec.sub(vecMeanVec); norm2Vec = fma(centeredVec, centeredVec, norm2Vec); - vecMeanVec = vecMeanVec.add(deltaVec.div(count)); + vecMeanVec = vecMeanVec.add(deltaVec.mul(1f / count)); FloatVector delta2Vec = centeredVec.sub(vecMeanVec); m2Vec = fma(deltaVec, delta2Vec, m2Vec); minVec = minVec.min(centeredVec); @@ -278,6 +278,7 @@ public void calculateOSQGridPoints(float[] target, int[] quantize, int points, f float dbb = 0; float dax = 0; float dbx = 0; + float invPmOnes = 1f / (points - 1f); // if the array size is large (> 2x platform vector size), it's worth the overhead to vectorize if (target.length > 2 * FLOAT_SPECIES.length()) { FloatVector daaVec = FloatVector.zero(FLOAT_SPECIES); @@ -286,11 +287,11 @@ public void calculateOSQGridPoints(float[] target, int[] quantize, int points, f FloatVector daxVec = FloatVector.zero(FLOAT_SPECIES); FloatVector dbxVec = FloatVector.zero(FLOAT_SPECIES); FloatVector ones = FloatVector.broadcast(FLOAT_SPECIES, 1f); - FloatVector pmOnes = FloatVector.broadcast(FLOAT_SPECIES, points - 1f); + FloatVector invPmOnesVec = FloatVector.broadcast(FLOAT_SPECIES, invPmOnes); for (; i < FLOAT_SPECIES.loopBound(target.length); i += FLOAT_SPECIES.length()) { FloatVector v = FloatVector.fromArray(FLOAT_SPECIES, target, i); FloatVector oVec = IntVector.fromArray(INTEGER_SPECIES, quantize, i).convert(VectorOperators.I2F, 0).reinterpretAsFloats(); - FloatVector sVec = oVec.div(pmOnes); + FloatVector sVec = oVec.mul(invPmOnesVec); FloatVector smVec = ones.sub(sVec); daaVec = fma(smVec, smVec, daaVec); dabVec = fma(smVec, sVec, dabVec); @@ -307,7 +308,7 @@ public void calculateOSQGridPoints(float[] target, int[] quantize, int points, f for (; i < target.length; i++) { float k = quantize[i]; - float s = k / (points - 1); + float s = k * invPmOnes; float ms = 1f - s; daa = fma(ms, ms, daa); dab = fma(ms, s, dab); @@ -798,25 +799,26 @@ public static float ipFloatByteImpl(float[] q, byte[] d) { @Override public int quantizeVectorWithIntervals(float[] vector, int[] destination, float lowInterval, float upperInterval, byte bits) { float nSteps = ((1 << bits) - 1); - float step = (upperInterval - lowInterval) / nSteps; + float invStep = nSteps / (upperInterval - lowInterval); int sumQuery = 0; int i = 0; if (vector.length > 2 * FLOAT_SPECIES.length()) { int limit = FLOAT_SPECIES.loopBound(vector.length); FloatVector lowVec = FloatVector.broadcast(FLOAT_SPECIES, lowInterval); FloatVector upperVec = FloatVector.broadcast(FLOAT_SPECIES, upperInterval); - FloatVector stepVec = FloatVector.broadcast(FLOAT_SPECIES, step); + FloatVector invStepVec = FloatVector.broadcast(FLOAT_SPECIES, invStep); for (; i < limit; i += FLOAT_SPECIES.length()) { FloatVector v = FloatVector.fromArray(FLOAT_SPECIES, vector, i); FloatVector xi = v.max(lowVec).min(upperVec); // clamp - IntVector assignment = xi.sub(lowVec).div(stepVec).add(0.5f).convert(VectorOperators.F2I, 0).reinterpretAsInts(); // round + // round + IntVector assignment = xi.sub(lowVec).mul(invStepVec).add(0.5f).convert(VectorOperators.F2I, 0).reinterpretAsInts(); sumQuery += assignment.reduceLanes(ADD); assignment.intoArray(destination, i); } } for (; i < vector.length; i++) { float xi = Math.min(Math.max(vector[i], lowInterval), upperInterval); - int assignment = Math.round((xi - lowInterval) / step); + int assignment = Math.round((xi - lowInterval) * invStep); sumQuery += assignment; destination[i] = assignment; }