@@ -132,7 +132,7 @@ public void centerAndCalculateOSQStatsEuclidean(float[] vector, float[] centroid
132132 FloatVector centeredVec = v .sub (c );
133133 FloatVector deltaVec = centeredVec .sub (vecMeanVec );
134134 norm2Vec = fma (centeredVec , centeredVec , norm2Vec );
135- vecMeanVec = vecMeanVec .add (deltaVec .div ( count ));
135+ vecMeanVec = vecMeanVec .add (deltaVec .mul ( 1f / count ));
136136 FloatVector delta2Vec = centeredVec .sub (vecMeanVec );
137137 m2Vec = fma (deltaVec , delta2Vec , m2Vec );
138138 minVec = minVec .min (centeredVec );
@@ -214,7 +214,7 @@ public void centerAndCalculateOSQStatsDp(float[] vector, float[] centroid, float
214214 FloatVector centeredVec = v .sub (c );
215215 FloatVector deltaVec = centeredVec .sub (vecMeanVec );
216216 norm2Vec = fma (centeredVec , centeredVec , norm2Vec );
217- vecMeanVec = vecMeanVec .add (deltaVec .div ( count ));
217+ vecMeanVec = vecMeanVec .add (deltaVec .mul ( 1f / count ));
218218 FloatVector delta2Vec = centeredVec .sub (vecMeanVec );
219219 m2Vec = fma (deltaVec , delta2Vec , m2Vec );
220220 minVec = minVec .min (centeredVec );
@@ -278,6 +278,7 @@ public void calculateOSQGridPoints(float[] target, int[] quantize, int points, f
278278 float dbb = 0 ;
279279 float dax = 0 ;
280280 float dbx = 0 ;
281+ float invPmOnes = 1f / (points - 1f );
281282 // if the array size is large (> 2x platform vector size), it's worth the overhead to vectorize
282283 if (target .length > 2 * FLOAT_SPECIES .length ()) {
283284 FloatVector daaVec = FloatVector .zero (FLOAT_SPECIES );
@@ -286,11 +287,11 @@ public void calculateOSQGridPoints(float[] target, int[] quantize, int points, f
286287 FloatVector daxVec = FloatVector .zero (FLOAT_SPECIES );
287288 FloatVector dbxVec = FloatVector .zero (FLOAT_SPECIES );
288289 FloatVector ones = FloatVector .broadcast (FLOAT_SPECIES , 1f );
289- FloatVector pmOnes = FloatVector .broadcast (FLOAT_SPECIES , points - 1f );
290+ FloatVector invPmOnesVec = FloatVector .broadcast (FLOAT_SPECIES , invPmOnes );
290291 for (; i < FLOAT_SPECIES .loopBound (target .length ); i += FLOAT_SPECIES .length ()) {
291292 FloatVector v = FloatVector .fromArray (FLOAT_SPECIES , target , i );
292293 FloatVector oVec = IntVector .fromArray (INTEGER_SPECIES , quantize , i ).convert (VectorOperators .I2F , 0 ).reinterpretAsFloats ();
293- FloatVector sVec = oVec .div ( pmOnes );
294+ FloatVector sVec = oVec .mul ( invPmOnesVec );
294295 FloatVector smVec = ones .sub (sVec );
295296 daaVec = fma (smVec , smVec , daaVec );
296297 dabVec = fma (smVec , sVec , dabVec );
@@ -307,7 +308,7 @@ public void calculateOSQGridPoints(float[] target, int[] quantize, int points, f
307308
308309 for (; i < target .length ; i ++) {
309310 float k = quantize [i ];
310- float s = k / ( points - 1 ) ;
311+ float s = k * invPmOnes ;
311312 float ms = 1f - s ;
312313 daa = fma (ms , ms , daa );
313314 dab = fma (ms , s , dab );
@@ -798,25 +799,25 @@ public static float ipFloatByteImpl(float[] q, byte[] d) {
798799 @ Override
799800 public int quantizeVectorWithIntervals (float [] vector , int [] destination , float lowInterval , float upperInterval , byte bits ) {
800801 float nSteps = ((1 << bits ) - 1 );
801- float step = (upperInterval - lowInterval ) / nSteps ;
802+ float invStep = nSteps / (upperInterval - lowInterval );
802803 int sumQuery = 0 ;
803804 int i = 0 ;
804805 if (vector .length > 2 * FLOAT_SPECIES .length ()) {
805806 int limit = FLOAT_SPECIES .loopBound (vector .length );
806807 FloatVector lowVec = FloatVector .broadcast (FLOAT_SPECIES , lowInterval );
807808 FloatVector upperVec = FloatVector .broadcast (FLOAT_SPECIES , upperInterval );
808- FloatVector stepVec = FloatVector .broadcast (FLOAT_SPECIES , step );
809+ FloatVector invStepVec = FloatVector .broadcast (FLOAT_SPECIES , invStep );
809810 for (; i < limit ; i += FLOAT_SPECIES .length ()) {
810811 FloatVector v = FloatVector .fromArray (FLOAT_SPECIES , vector , i );
811812 FloatVector xi = v .max (lowVec ).min (upperVec ); // clamp
812- IntVector assignment = xi .sub (lowVec ).div ( stepVec ).add (0.5f ).convert (VectorOperators .F2I , 0 ).reinterpretAsInts (); // round
813+ IntVector assignment = xi .sub (lowVec ).mul ( invStepVec ).add (0.5f ).convert (VectorOperators .F2I , 0 ).reinterpretAsInts (); // round
813814 sumQuery += assignment .reduceLanes (ADD );
814815 assignment .intoArray (destination , i );
815816 }
816817 }
817818 for (; i < vector .length ; i ++) {
818819 float xi = Math .min (Math .max (vector [i ], lowInterval ), upperInterval );
819- int assignment = Math .round ((xi - lowInterval ) / step );
820+ int assignment = Math .round ((xi - lowInterval ) * invStep );
820821 sumQuery += assignment ;
821822 destination [i ] = assignment ;
822823 }
0 commit comments