diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java
index 50b8e18c3d224..5778c26e16e56 100644
--- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java
+++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java
@@ -237,20 +237,24 @@ public static void subtract(float[] v1, float[] v2, float[] result) {
     }
 
     /**
-     * calculates the spill-over score for a vector and a centroid, given its residual with
-     * its actually nearest centroid
+     * calculates the soar distance for a vector and a centroid: the squared distance between
+     * {@code v1} and {@code centroid} plus {@code soarLambda * proj * proj / rnorm}, where
+     * {@code proj} is the dot product of {@code (v1 - centroid)} with {@code originalResidual}
      * @param v1 the vector
      * @param centroid the centroid
      * @param originalResidual the residual with the actually nearest centroid
-     * @return the spill-over score (soar)
+     * @param soarLambda the lambda parameter, weighting the projection term
+     * @param rnorm distance to the nearest centroid; divides the projection term, so it should be non-zero
+     * @return the soar distance
+     * @throws IllegalArgumentException if the three arrays do not all have the same length
      */
-    public static float soarResidual(float[] v1, float[] centroid, float[] originalResidual) {
+    public static float soarDistance(float[] v1, float[] centroid, float[] originalResidual, float soarLambda, float rnorm) {
         if (v1.length != centroid.length) {
             throw new IllegalArgumentException("vector dimensions differ: " + v1.length + "!=" + centroid.length);
         }
         if (originalResidual.length != v1.length) {
             throw new IllegalArgumentException("vector dimensions differ: " + originalResidual.length + "!=" + v1.length);
         }
-        return IMPL.soarResidual(v1, centroid, originalResidual);
+        return IMPL.soarDistance(v1, centroid, originalResidual, soarLambda, rnorm);
     }
 }
diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java
index 005a61d323e25..022f189a2e041 100644
--- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java
+++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java
@@ -11,6 +11,7 @@
 
 import org.apache.lucene.util.BitUtil;
 import org.apache.lucene.util.Constants;
+import org.apache.lucene.util.VectorUtil;
 
 final class DefaultESVectorUtilSupport implements ESVectorUtilSupport {
 
@@ -139,15 +140,16 @@ public void centerAndCalculateOSQStatsDp(float[] target, float[] centroid, float
     }
 
     @Override
-    public float soarResidual(float[] v1, float[] centroid, float[] originalResidual) {
+    public float soarDistance(float[] v1, float[] centroid, float[] originalResidual, float soarLambda, float rnorm) {
         assert v1.length == centroid.length;
         assert v1.length == originalResidual.length;
+        float dsq = VectorUtil.squareDistance(v1, centroid);
         float proj = 0;
         for (int i = 0; i < v1.length; i++) {
             float djk = v1[i] - centroid[i];
             proj = fma(djk, originalResidual[i], proj);
         }
-        return proj;
+        return dsq + soarLambda * proj * proj / rnorm;
     }
 
     public static int ipByteBitImpl(byte[] q, byte[] d) {
diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java
index 809b2e8a913ef..dfd324547d84c 100644
--- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java
+++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java
@@ -37,6 +37,6 @@ public interface ESVectorUtilSupport {
 
     void centerAndCalculateOSQStatsDp(float[] target, float[] centroid, float[] centered, float[] stats);
 
-    float soarResidual(float[] v1, float[] centroid, float[] originalResidual);
+    float soarDistance(float[] v1, float[] centroid, float[] originalResidual, float soarLambda, float rnorm);
 
 }
diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java
index 1d8f59f855675..87e8a39c4842b 100644
--- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java
+++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java
@@ -368,14 +368,17 @@ public float calculateOSQLoss(float[] target, float[] interval, float step, floa
     }
 
     @Override
-    public float soarResidual(float[] v1, float[] centroid, float[] originalResidual) {
+    public float soarDistance(float[] v1, float[] centroid, float[] originalResidual, float soarLambda, float rnorm) {
         assert v1.length == centroid.length;
         assert v1.length == originalResidual.length;
         float proj = 0;
+        float dsq = 0;
         int i = 0;
         if (v1.length > 2 * FLOAT_SPECIES.length()) {
             FloatVector projVec1 = FloatVector.zero(FLOAT_SPECIES);
             FloatVector projVec2 = FloatVector.zero(FLOAT_SPECIES);
+            FloatVector acc1 = FloatVector.zero(FLOAT_SPECIES);
+            FloatVector acc2 = FloatVector.zero(FLOAT_SPECIES);
             int unrolledLimit = FLOAT_SPECIES.loopBound(v1.length) - FLOAT_SPECIES.length();
             for (; i < unrolledLimit; i += 2 * FLOAT_SPECIES.length()) {
                 // one
@@ -384,6 +387,7 @@ public float soarResidual(float[] v1, float[] centroid, float[] originalResidual
                 FloatVector originalResidualVec0 = FloatVector.fromArray(FLOAT_SPECIES, originalResidual, i);
                 FloatVector djkVec0 = v1Vec0.sub(centroidVec0);
                 projVec1 = fma(djkVec0, originalResidualVec0, projVec1);
+                acc1 = fma(djkVec0, djkVec0, acc1);
 
                 // two
                 FloatVector v1Vec1 = FloatVector.fromArray(FLOAT_SPECIES, v1, i + FLOAT_SPECIES.length());
@@ -391,6 +395,7 @@ public float soarResidual(float[] v1, float[] centroid, float[] originalResidual
                 FloatVector originalResidualVec1 = FloatVector.fromArray(FLOAT_SPECIES, originalResidual, i + FLOAT_SPECIES.length());
                 FloatVector djkVec1 = v1Vec1.sub(centroidVec1);
                 projVec2 = fma(djkVec1, originalResidualVec1, projVec2);
+                acc2 = fma(djkVec1, djkVec1, acc2);
             }
             // vector tail
             for (; i < FLOAT_SPECIES.loopBound(v1.length); i += FLOAT_SPECIES.length()) {
@@ -399,15 +404,18 @@ public float soarResidual(float[] v1, float[] centroid, float[] originalResidual
                 FloatVector originalResidualVec = FloatVector.fromArray(FLOAT_SPECIES, originalResidual, i);
                 FloatVector djkVec = v1Vec.sub(centroidVec);
                 projVec1 = fma(djkVec, originalResidualVec, projVec1);
+                acc1 = fma(djkVec, djkVec, acc1);
             }
             proj += projVec1.add(projVec2).reduceLanes(ADD);
+            dsq += acc1.add(acc2).reduceLanes(ADD);
         }
         // tail
         for (; i < v1.length; i++) {
             float djk = v1[i] - centroid[i];
             proj = fma(djk, originalResidual[i], proj);
+            dsq = fma(djk, djk, dsq);
         }
-        return proj;
+        return dsq + soarLambda * proj * proj / rnorm;
     }
 
     private static final VectorSpecies<Byte> BYTE_SPECIES_128 = ByteVector.SPECIES_128;
diff --git a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java
index abd4e3b0be045..2a83c0ec62125 100644
--- a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java
+++ b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java
@@ -268,7 +268,7 @@ public void testOsqGridPoints() {
         }
     }
 
-    public void testSoarOverspillScore() {
+    public void testSoarDistance() {
         int size = random().nextInt(128, 512);
         float deltaEps = 1e-5f * size;
         var vector = new float[size];
@@ -279,8 +279,11 @@ public void testSoarOverspillScore() {
             centroid[i] = random().nextFloat();
             preResidual[i] = random().nextFloat();
         }
-        var expected = defaultedProvider.getVectorUtilSupport().soarResidual(vector, centroid, preResidual);
-        var result = defOrPanamaProvider.getVectorUtilSupport().soarResidual(vector, centroid, preResidual);
+        float soarLambda = random().nextFloat();
+        // keep rnorm away from zero: it is a divisor, and a tiny value amplifies rounding differences beyond deltaEps
+        float rnorm = random().nextFloat(1f, 10f);
+        var expected = defaultedProvider.getVectorUtilSupport().soarDistance(vector, centroid, preResidual, soarLambda, rnorm);
+        var result = defOrPanamaProvider.getVectorUtilSupport().soarDistance(vector, centroid, preResidual, soarLambda, rnorm);
         assertEquals(expected, result, deltaEps);
     }
 
diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java
index 1551a1cfd0b6e..b1303b7124b24 100644
--- a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java
+++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java
@@ -202,7 +202,7 @@ private int[] assignSpilled(FloatVectorValues vectors, List neighborhoods
                     continue;
                 }
                 float[] neighborCentroid = centroids[neighbor];
-                float soar = distanceSoar(diffs, vector, neighborCentroid, vectorCentroidDist);
+                float soar = ESVectorUtil.soarDistance(vector, neighborCentroid, diffs, soarLambda, vectorCentroidDist);
                 if (soar < minSoar) {
                     bestAssignment = neighbor;
                     minSoar = soar;
@@ -215,13 +215,6 @@ private int[] assignSpilled(FloatVectorValues vectors, List neighborhoods
         return spilledAssignments;
     }
 
-    private float distanceSoar(float[] residual, float[] vector, float[] centroid, float rnorm) {
-        // TODO: combine these to be more efficient
-        float dsq = VectorUtil.squareDistance(vector, centroid);
-        float rproj = ESVectorUtil.soarResidual(vector, centroid, residual);
-        return dsq + soarLambda * rproj * rproj / rnorm;
-    }
-
     /**
      * cluster using a lloyd k-means algorithm that is not neighbor aware
      *