Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -237,20 +237,21 @@ public static void subtract(float[] v1, float[] v2, float[] result) {
}

/**
* calculates the spill-over score for a vector and a centroid, given its residual with
* its actually nearest centroid
* calculates the soar distance for a vector and a centroid
* @param v1 the vector
* @param centroid the centroid
* @param originalResidual the residual with the actually nearest centroid
* @return the spill-over score (soar)
* @param soarLambda the lambda parameter
* @param rnorm distance to the nearest centroid
* @return the soar distance
*/
public static float soarResidual(float[] v1, float[] centroid, float[] originalResidual) {
public static float soarDistance(float[] v1, float[] centroid, float[] originalResidual, float soarLambda, float rnorm) {
if (v1.length != centroid.length) {
throw new IllegalArgumentException("vector dimensions differ: " + v1.length + "!=" + centroid.length);
}
if (originalResidual.length != v1.length) {
throw new IllegalArgumentException("vector dimensions differ: " + originalResidual.length + "!=" + v1.length);
}
return IMPL.soarResidual(v1, centroid, originalResidual);
return IMPL.soarDistance(v1, centroid, originalResidual, soarLambda, rnorm);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import org.apache.lucene.util.BitUtil;
import org.apache.lucene.util.Constants;
import org.apache.lucene.util.VectorUtil;

final class DefaultESVectorUtilSupport implements ESVectorUtilSupport {

Expand Down Expand Up @@ -139,15 +140,16 @@ public void centerAndCalculateOSQStatsDp(float[] target, float[] centroid, float
}

@Override
public float soarResidual(float[] v1, float[] centroid, float[] originalResidual) {
public float soarDistance(float[] v1, float[] centroid, float[] originalResidual, float soarLambda, float rnorm) {
assert v1.length == centroid.length;
assert v1.length == originalResidual.length;
float dsq = VectorUtil.squareDistance(v1, centroid);
float proj = 0;
for (int i = 0; i < v1.length; i++) {
float djk = v1[i] - centroid[i];
proj = fma(djk, originalResidual[i], proj);
}
return proj;
return dsq + soarLambda * proj * proj / rnorm;
}

public static int ipByteBitImpl(byte[] q, byte[] d) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,6 @@ public interface ESVectorUtilSupport {

void centerAndCalculateOSQStatsDp(float[] target, float[] centroid, float[] centered, float[] stats);

/**
 * Calculates the soar distance for a vector and a centroid.
 *
 * @param v1 the vector
 * @param centroid the centroid
 * @param originalResidual the residual with the actually nearest centroid
 * @param soarLambda the lambda parameter
 * @param rnorm distance to the nearest centroid
 * @return the soar distance
 */
float soarDistance(float[] v1, float[] centroid, float[] originalResidual, float soarLambda, float rnorm);

}
Original file line number Diff line number Diff line change
Expand Up @@ -368,14 +368,17 @@ public float calculateOSQLoss(float[] target, float[] interval, float step, floa
}

@Override
public float soarResidual(float[] v1, float[] centroid, float[] originalResidual) {
public float soarDistance(float[] v1, float[] centroid, float[] originalResidual, float soarLambda, float rnorm) {
assert v1.length == centroid.length;
assert v1.length == originalResidual.length;
float proj = 0;
float dsq = 0;
int i = 0;
if (v1.length > 2 * FLOAT_SPECIES.length()) {
FloatVector projVec1 = FloatVector.zero(FLOAT_SPECIES);
FloatVector projVec2 = FloatVector.zero(FLOAT_SPECIES);
FloatVector acc1 = FloatVector.zero(FLOAT_SPECIES);
FloatVector acc2 = FloatVector.zero(FLOAT_SPECIES);
int unrolledLimit = FLOAT_SPECIES.loopBound(v1.length) - FLOAT_SPECIES.length();
for (; i < unrolledLimit; i += 2 * FLOAT_SPECIES.length()) {
// one
Expand All @@ -384,13 +387,15 @@ public float soarResidual(float[] v1, float[] centroid, float[] originalResidual
FloatVector originalResidualVec0 = FloatVector.fromArray(FLOAT_SPECIES, originalResidual, i);
FloatVector djkVec0 = v1Vec0.sub(centroidVec0);
projVec1 = fma(djkVec0, originalResidualVec0, projVec1);
acc1 = fma(djkVec0, djkVec0, acc1);

// two
FloatVector v1Vec1 = FloatVector.fromArray(FLOAT_SPECIES, v1, i + FLOAT_SPECIES.length());
FloatVector centroidVec1 = FloatVector.fromArray(FLOAT_SPECIES, centroid, i + FLOAT_SPECIES.length());
FloatVector originalResidualVec1 = FloatVector.fromArray(FLOAT_SPECIES, originalResidual, i + FLOAT_SPECIES.length());
FloatVector djkVec1 = v1Vec1.sub(centroidVec1);
projVec2 = fma(djkVec1, originalResidualVec1, projVec2);
acc2 = fma(djkVec1, djkVec1, acc2);
}
// vector tail
for (; i < FLOAT_SPECIES.loopBound(v1.length); i += FLOAT_SPECIES.length()) {
Expand All @@ -399,15 +404,18 @@ public float soarResidual(float[] v1, float[] centroid, float[] originalResidual
FloatVector originalResidualVec = FloatVector.fromArray(FLOAT_SPECIES, originalResidual, i);
FloatVector djkVec = v1Vec.sub(centroidVec);
projVec1 = fma(djkVec, originalResidualVec, projVec1);
acc1 = fma(djkVec, djkVec, acc1);
}
proj += projVec1.add(projVec2).reduceLanes(ADD);
dsq += acc1.add(acc2).reduceLanes(ADD);
}
// tail
for (; i < v1.length; i++) {
float djk = v1[i] - centroid[i];
proj = fma(djk, originalResidual[i], proj);
dsq = fma(djk, djk, dsq);
}
return proj;
return dsq + soarLambda * proj * proj / rnorm;
}

private static final VectorSpecies<Byte> BYTE_SPECIES_128 = ByteVector.SPECIES_128;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ public void testOsqGridPoints() {
}
}

public void testSoarOverspillScore() {
public void testSoarDistance() {
int size = random().nextInt(128, 512);
float deltaEps = 1e-5f * size;
var vector = new float[size];
Expand All @@ -279,8 +279,10 @@ public void testSoarOverspillScore() {
centroid[i] = random().nextFloat();
preResidual[i] = random().nextFloat();
}
var expected = defaultedProvider.getVectorUtilSupport().soarResidual(vector, centroid, preResidual);
var result = defOrPanamaProvider.getVectorUtilSupport().soarResidual(vector, centroid, preResidual);
float soarLambda = random().nextFloat();
float rnorm = random().nextFloat();
var expected = defaultedProvider.getVectorUtilSupport().soarDistance(vector, centroid, preResidual, soarLambda, rnorm);
var result = defOrPanamaProvider.getVectorUtilSupport().soarDistance(vector, centroid, preResidual, soarLambda, rnorm);
assertEquals(expected, result, deltaEps);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ private int[] assignSpilled(FloatVectorValues vectors, List<int[]> neighborhoods
continue;
}
float[] neighborCentroid = centroids[neighbor];
float soar = distanceSoar(diffs, vector, neighborCentroid, vectorCentroidDist);
float soar = ESVectorUtil.soarDistance(vector, neighborCentroid, diffs, soarLambda, vectorCentroidDist);
if (soar < minSoar) {
bestAssignment = neighbor;
minSoar = soar;
Expand All @@ -215,13 +215,6 @@ private int[] assignSpilled(FloatVectorValues vectors, List<int[]> neighborhoods
return spilledAssignments;
}

/**
 * Computes the soar distance between {@code vector} and {@code centroid} using this
 * instance's {@code soarLambda}.
 *
 * <p>Delegates to {@link ESVectorUtil#soarDistance}, which evaluates the squared distance
 * and the residual projection together instead of in two separate calls.
 *
 * @param residual the residual of {@code vector} with its nearest centroid
 * @param vector the vector being assigned
 * @param centroid the candidate centroid
 * @param rnorm normalisation term passed through to {@code soarDistance} (used there as a divisor)
 * @return the soar distance
 */
private float distanceSoar(float[] residual, float[] vector, float[] centroid, float rnorm) {
    // Argument order differs from this method's: soarDistance takes (vector, centroid, residual, lambda, rnorm).
    return ESVectorUtil.soarDistance(vector, centroid, residual, soarLambda, rnorm);
}

/**
* cluster using a lloyd k-means algorithm that is not neighbor aware
*
Expand Down