Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,11 @@ public void calculateOSQGridPoints(float[] target, int[] quantize, int points, f
float dbb = 0;
float dax = 0;
float dbx = 0;
float invPmOnes = 1f / (points - 1f);
for (int i = 0; i < target.length; ++i) {
float v = target[i];
float k = quantize[i];
float s = k / (points - 1);
float s = k * invPmOnes;
float ms = 1f - s;
daa = fma(ms, ms, daa);
dab = fma(ms, s, dab);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ public void centerAndCalculateOSQStatsEuclidean(float[] vector, float[] centroid
FloatVector centeredVec = v.sub(c);
FloatVector deltaVec = centeredVec.sub(vecMeanVec);
norm2Vec = fma(centeredVec, centeredVec, norm2Vec);
vecMeanVec = vecMeanVec.add(deltaVec.div(count));
vecMeanVec = vecMeanVec.add(deltaVec.mul(1f / count));
FloatVector delta2Vec = centeredVec.sub(vecMeanVec);
m2Vec = fma(deltaVec, delta2Vec, m2Vec);
minVec = minVec.min(centeredVec);
Expand Down Expand Up @@ -214,7 +214,7 @@ public void centerAndCalculateOSQStatsDp(float[] vector, float[] centroid, float
FloatVector centeredVec = v.sub(c);
FloatVector deltaVec = centeredVec.sub(vecMeanVec);
norm2Vec = fma(centeredVec, centeredVec, norm2Vec);
vecMeanVec = vecMeanVec.add(deltaVec.div(count));
vecMeanVec = vecMeanVec.add(deltaVec.mul(1f / count));
FloatVector delta2Vec = centeredVec.sub(vecMeanVec);
m2Vec = fma(deltaVec, delta2Vec, m2Vec);
minVec = minVec.min(centeredVec);
Expand Down Expand Up @@ -278,6 +278,7 @@ public void calculateOSQGridPoints(float[] target, int[] quantize, int points, f
float dbb = 0;
float dax = 0;
float dbx = 0;
float invPmOnes = 1f / (points - 1f);
// if the array size is large (> 2x platform vector size), it's worth the overhead to vectorize
if (target.length > 2 * FLOAT_SPECIES.length()) {
FloatVector daaVec = FloatVector.zero(FLOAT_SPECIES);
Expand All @@ -286,11 +287,11 @@ public void calculateOSQGridPoints(float[] target, int[] quantize, int points, f
FloatVector daxVec = FloatVector.zero(FLOAT_SPECIES);
FloatVector dbxVec = FloatVector.zero(FLOAT_SPECIES);
FloatVector ones = FloatVector.broadcast(FLOAT_SPECIES, 1f);
FloatVector pmOnes = FloatVector.broadcast(FLOAT_SPECIES, points - 1f);
FloatVector invPmOnesVec = FloatVector.broadcast(FLOAT_SPECIES, invPmOnes);
for (; i < FLOAT_SPECIES.loopBound(target.length); i += FLOAT_SPECIES.length()) {
FloatVector v = FloatVector.fromArray(FLOAT_SPECIES, target, i);
FloatVector oVec = IntVector.fromArray(INTEGER_SPECIES, quantize, i).convert(VectorOperators.I2F, 0).reinterpretAsFloats();
FloatVector sVec = oVec.div(pmOnes);
FloatVector sVec = oVec.mul(invPmOnesVec);
FloatVector smVec = ones.sub(sVec);
daaVec = fma(smVec, smVec, daaVec);
dabVec = fma(smVec, sVec, dabVec);
Expand All @@ -307,7 +308,7 @@ public void calculateOSQGridPoints(float[] target, int[] quantize, int points, f

for (; i < target.length; i++) {
float k = quantize[i];
float s = k / (points - 1);
float s = k * invPmOnes;
float ms = 1f - s;
daa = fma(ms, ms, daa);
dab = fma(ms, s, dab);
Expand Down Expand Up @@ -798,25 +799,26 @@ public static float ipFloatByteImpl(float[] q, byte[] d) {
@Override
public int quantizeVectorWithIntervals(float[] vector, int[] destination, float lowInterval, float upperInterval, byte bits) {
float nSteps = ((1 << bits) - 1);
float step = (upperInterval - lowInterval) / nSteps;
float invStep = nSteps / (upperInterval - lowInterval);
int sumQuery = 0;
int i = 0;
if (vector.length > 2 * FLOAT_SPECIES.length()) {
int limit = FLOAT_SPECIES.loopBound(vector.length);
FloatVector lowVec = FloatVector.broadcast(FLOAT_SPECIES, lowInterval);
FloatVector upperVec = FloatVector.broadcast(FLOAT_SPECIES, upperInterval);
FloatVector stepVec = FloatVector.broadcast(FLOAT_SPECIES, step);
FloatVector invStepVec = FloatVector.broadcast(FLOAT_SPECIES, invStep);
for (; i < limit; i += FLOAT_SPECIES.length()) {
FloatVector v = FloatVector.fromArray(FLOAT_SPECIES, vector, i);
FloatVector xi = v.max(lowVec).min(upperVec); // clamp
IntVector assignment = xi.sub(lowVec).div(stepVec).add(0.5f).convert(VectorOperators.F2I, 0).reinterpretAsInts(); // round
// round
IntVector assignment = xi.sub(lowVec).mul(invStepVec).add(0.5f).convert(VectorOperators.F2I, 0).reinterpretAsInts();
sumQuery += assignment.reduceLanes(ADD);
assignment.intoArray(destination, i);
}
}
for (; i < vector.length; i++) {
float xi = Math.min(Math.max(vector[i], lowInterval), upperInterval);
int assignment = Math.round((xi - lowInterval) / step);
int assignment = Math.round((xi - lowInterval) * invStep);
sumQuery += assignment;
destination[i] = assignment;
}
Expand Down