diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/DistanceBulkBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/DistanceBulkBenchmark.java new file mode 100644 index 0000000000000..2de50622fd6fc --- /dev/null +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/vector/DistanceBulkBenchmark.java @@ -0,0 +1,145 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ +package org.elasticsearch.benchmark.vector; + +import org.apache.lucene.util.VectorUtil; +import org.apache.lucene.util.quantization.OptimizedScalarQuantizer; +import org.elasticsearch.common.logging.LogConfigurator; +import org.elasticsearch.simdvec.ESVectorUtil; +import org.openjdk.jmh.annotations.Benchmark; +import org.openjdk.jmh.annotations.BenchmarkMode; +import org.openjdk.jmh.annotations.Fork; +import org.openjdk.jmh.annotations.Measurement; +import org.openjdk.jmh.annotations.Mode; +import org.openjdk.jmh.annotations.OutputTimeUnit; +import org.openjdk.jmh.annotations.Param; +import org.openjdk.jmh.annotations.Scope; +import org.openjdk.jmh.annotations.Setup; +import org.openjdk.jmh.annotations.State; +import org.openjdk.jmh.annotations.Warmup; +import org.openjdk.jmh.infra.Blackhole; + +import java.io.IOException; +import java.util.Random; +import java.util.concurrent.TimeUnit; + +@BenchmarkMode(Mode.Throughput) +@OutputTimeUnit(TimeUnit.MILLISECONDS) +@State(Scope.Benchmark) +// first iteration is complete garbage, so make sure we really warmup +@Warmup(iterations = 4, time = 1) +// real iterations. not useful to spend tons of time here, better to fork more +@Measurement(iterations = 5, time = 1) +// engage some noise reduction +@Fork(value = 1) +public class DistanceBulkBenchmark { + + static { + LogConfigurator.configureESLogging(); // native access requires logging to be initialized + } + + @Param({ "384", "782", "1024" }) + int dims; + + int length; + + int numVectors = 4 * 100; + int numQueries = 10; + + float[][] vectors; + float[][] queries; + float[] distances = new float[4]; + + @Setup + public void setup() throws IOException { + Random random = new Random(123); + + this.length = OptimizedScalarQuantizer.discretize(dims, 64) / 8; + + vectors = new float[numVectors][dims]; + for (float[] vector : vectors) { + for (int i = 0; i < dims; i++) { + vector[i] = random.nextFloat(); + } + } + + queries = new float[numQueries][dims]; + for (float[] query : queries) { + for (int i = 0; i < dims; i++) { + query[i] = random.nextFloat(); + } + } + } + + @Benchmark + @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) + public void squareDistance(Blackhole bh) { + for (int j = 0; j < numQueries; j++) { + float[] query = queries[j]; + for (int i = 0; i < numVectors; i++) { + float[] vector = vectors[i]; + float distance = VectorUtil.squareDistance(query, vector); + bh.consume(distance); + } + } + } + + @Benchmark + @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) + public void soarDistance(Blackhole bh) { + for (int j = 0; j < numQueries; j++) { + float[] query = queries[j]; + for (int i = 0; i < numVectors; i++) { + float[] vector = vectors[i]; + float distance = ESVectorUtil.soarDistance(query, vector, vector, 1.0f, 1.0f); + bh.consume(distance); + } + } + } + + @Benchmark + @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) + public void squareDistanceBulk(Blackhole bh) { + for (int j = 0; j < numQueries; j++) { + float[] query = queries[j]; + for (int i = 0; i < numVectors; i += 4) { + ESVectorUtil.squareDistanceBulk(query, vectors[i], vectors[i + 1], vectors[i + 2], vectors[i + 3], distances); + for (float distance : distances) { + bh.consume(distance); + } + + } + } + } + + @Benchmark + @Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" }) + public void soarDistanceBulk(Blackhole bh) { + for (int j = 0; j < numQueries; j++) { + float[] query = queries[j]; + for (int i = 0; i < numVectors; i += 4) { + ESVectorUtil.soarDistanceBulk( + query, + vectors[i], + vectors[i + 1], + vectors[i + 2], + vectors[i + 3], + vectors[i], + 1.0f, + 1.0f, + distances + ); + for (float distance : distances) { + bh.consume(distance); + } + + } + } + } +} diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java index 5b14b39d37fb0..a02370a89f931 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java @@ -293,4 +293,79 @@ public static int quantizeVectorWithIntervals(float[] vector, int[] destination, } return IMPL.quantizeVectorWithIntervals(vector, destination, lowInterval, upperInterval, bit); } + + /** + * Bulk computation of square distances between a query vector and four vectors.Result is stored in the provided distances array. + * + * @param q the query vector + * @param v0 the first vector + * @param v1 the second vector + * @param v2 the third vector + * @param v3 the fourth vector + * @param distances an array to store the computed square distances, must have length 4 + * + * @throws IllegalArgumentException if the dimensions of the vectors do not match or if the distances array does not have length 4 + */ + public static void squareDistanceBulk(float[] q, float[] v0, float[] v1, float[] v2, float[] v3, float[] distances) { + if (q.length != v0.length) { + throw new IllegalArgumentException("vector dimensions differ: " + q.length + "!=" + v0.length); + } + if (q.length != v1.length) { + throw new IllegalArgumentException("vector dimensions differ: " + q.length + "!=" + v1.length); + } + if (q.length != v2.length) { + throw new IllegalArgumentException("vector dimensions differ: " + q.length + "!=" + v2.length); + } + if (q.length != v3.length) { + throw new IllegalArgumentException("vector dimensions differ: " + q.length + "!=" + v3.length); + } + if (distances.length != 4) { + throw new IllegalArgumentException("distances array must have length 4, but was: " + distances.length); + } + IMPL.squareDistanceBulk(q, v0, v1, v2, v3, distances); + } + + /** + * Bulk computation of the soar distance for a vector to four centroids + * @param v1 the vector + * @param c0 the first centroid + * @param c1 the second centroid + * @param c2 the third centroid + * @param c3 the fourth centroid + * @param originalResidual the residual with the actually nearest centroid + * @param soarLambda the lambda parameter + * @param rnorm distance to the nearest centroid + * @param distances an array to store the computed soar distances, must have length 4 + */ + public static void soarDistanceBulk( + float[] v1, + float[] c0, + float[] c1, + float[] c2, + float[] c3, + float[] originalResidual, + float soarLambda, + float rnorm, + float[] distances + ) { + if (v1.length != c0.length) { + throw new IllegalArgumentException("vector dimensions differ: " + v1.length + "!=" + c0.length); + } + if (v1.length != c1.length) { + throw new IllegalArgumentException("vector dimensions differ: " + v1.length + "!=" + c1.length); + } + if (v1.length != c2.length) { + throw new IllegalArgumentException("vector dimensions differ: " + v1.length + "!=" + c2.length); + } + if (v1.length != c3.length) { + throw new IllegalArgumentException("vector dimensions differ: " + v1.length + "!=" + c3.length); + } + if (v1.length != originalResidual.length) { + throw new IllegalArgumentException("vector dimensions differ: " + v1.length + "!=" + originalResidual.length); + } + if (distances.length != 4) { + throw new IllegalArgumentException("distances array must have length 4, but was: " + distances.length); + } + IMPL.soarDistanceBulk(v1, c0, c1, c2, c3, originalResidual, soarLambda, rnorm, distances); + } } diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java index 19e7a876ff202..d4aedaf99a1f7 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java @@ -293,4 +293,30 @@ public int quantizeVectorWithIntervals(float[] vector, int[] destination, float } return sumQuery; } + + @Override + public void squareDistanceBulk(float[] query, float[] v0, float[] v1, float[] v2, float[] v3, float[] distances) { + distances[0] = VectorUtil.squareDistance(query, v0); + distances[1] = VectorUtil.squareDistance(query, v1); + distances[2] = VectorUtil.squareDistance(query, v2); + distances[3] = VectorUtil.squareDistance(query, v3); + } + + @Override + public void soarDistanceBulk( + float[] v1, + float[] c0, + float[] c1, + float[] c2, + float[] c3, + float[] originalResidual, + float soarLambda, + float rnorm, + float[] distances + ) { + distances[0] = soarDistance(v1, c0, originalResidual, soarLambda, rnorm); + distances[1] = soarDistance(v1, c1, originalResidual, soarLambda, rnorm); + distances[2] = soarDistance(v1, c2, originalResidual, soarLambda, rnorm); + distances[3] = soarDistance(v1, c3, originalResidual, soarLambda, rnorm); + } } diff --git a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java index 9453a5823bc68..895105a452b0c 100644 --- a/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java +++ b/libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java @@ -50,4 +50,17 @@ float calculateOSQLoss( int quantizeVectorWithIntervals(float[] vector, int[] quantize, float lowInterval, float upperInterval, byte bit); + void squareDistanceBulk(float[] query, float[] v0, float[] v1, float[] v2, float[] v3, float[] distances); + + void soarDistanceBulk( + float[] v1, + float[] c0, + float[] c1, + float[] c2, + float[] c3, + float[] originalResidual, + float soarLambda, + float rnorm, + float[] distances + ); } diff --git a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java index cc4d8b931a692..1196661ae010f 100644 --- a/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java +++ b/libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java @@ -822,4 +822,122 @@ public int quantizeVectorWithIntervals(float[] vector, int[] destination, float } return sumQuery; } + + @Override + public void squareDistanceBulk(float[] query, float[] v0, float[] v1, float[] v2, float[] v3, float[] distances) { + FloatVector sv0 = FloatVector.zero(FLOAT_SPECIES); + FloatVector sv1 = FloatVector.zero(FLOAT_SPECIES); + FloatVector sv2 = FloatVector.zero(FLOAT_SPECIES); + FloatVector sv3 = FloatVector.zero(FLOAT_SPECIES); + final int limit = FLOAT_SPECIES.loopBound(query.length); + int i = 0; + for (; i < limit; i += FLOAT_SPECIES.length()) { + FloatVector qv = FloatVector.fromArray(FLOAT_SPECIES, query, i); + FloatVector dv0 = FloatVector.fromArray(FLOAT_SPECIES, v0, i); + FloatVector dv1 = FloatVector.fromArray(FLOAT_SPECIES, v1, i); + FloatVector dv2 = FloatVector.fromArray(FLOAT_SPECIES, v2, i); + FloatVector dv3 = FloatVector.fromArray(FLOAT_SPECIES, v3, i); + FloatVector diff0 = qv.sub(dv0); + sv0 = fma(diff0, diff0, sv0); + FloatVector diff1 = qv.sub(dv1); + sv1 = fma(diff1, diff1, sv1); + FloatVector diff2 = qv.sub(dv2); + sv2 = fma(diff2, diff2, sv2); + FloatVector diff3 = qv.sub(dv3); + sv3 = fma(diff3, diff3, sv3); + } + float distance0 = sv0.reduceLanes(VectorOperators.ADD); + float distance1 = sv1.reduceLanes(VectorOperators.ADD); + float distance2 = sv2.reduceLanes(VectorOperators.ADD); + float distance3 = sv3.reduceLanes(VectorOperators.ADD); + + for (; i < query.length; i++) { + final float qValue = query[i]; + final float diff0 = qValue - v0[i]; + final float diff1 = qValue - v1[i]; + final float diff2 = qValue - v2[i]; + final float diff3 = qValue - v3[i]; + distance0 = fma(diff0, diff0, distance0); + distance1 = fma(diff1, diff1, distance1); + distance2 = fma(diff2, diff2, distance2); + distance3 = fma(diff3, diff3, distance3); + } + distances[0] = distance0; + distances[1] = distance1; + distances[2] = distance2; + distances[3] = distance3; + } + + @Override + public void soarDistanceBulk( + float[] v1, + float[] c0, + float[] c1, + float[] c2, + float[] c3, + float[] originalResidual, + float soarLambda, + float rnorm, + float[] distances + ) { + + FloatVector projVec0 = FloatVector.zero(FLOAT_SPECIES); + FloatVector projVec1 = FloatVector.zero(FLOAT_SPECIES); + FloatVector projVec2 = FloatVector.zero(FLOAT_SPECIES); + FloatVector projVec3 = FloatVector.zero(FLOAT_SPECIES); + FloatVector acc0 = FloatVector.zero(FLOAT_SPECIES); + FloatVector acc1 = FloatVector.zero(FLOAT_SPECIES); + FloatVector acc2 = FloatVector.zero(FLOAT_SPECIES); + FloatVector acc3 = FloatVector.zero(FLOAT_SPECIES); + final int limit = FLOAT_SPECIES.loopBound(v1.length); + int i = 0; + for (; i < limit; i += FLOAT_SPECIES.length()) { + FloatVector v1Vec = FloatVector.fromArray(FLOAT_SPECIES, v1, i); + FloatVector c0Vec = FloatVector.fromArray(FLOAT_SPECIES, c0, i); + FloatVector c1Vec = FloatVector.fromArray(FLOAT_SPECIES, c1, i); + FloatVector c2Vec = FloatVector.fromArray(FLOAT_SPECIES, c2, i); + FloatVector c3Vec = FloatVector.fromArray(FLOAT_SPECIES, c3, i); + FloatVector originalResidualVec = FloatVector.fromArray(FLOAT_SPECIES, originalResidual, i); + FloatVector djkVec0 = v1Vec.sub(c0Vec); + FloatVector djkVec1 = v1Vec.sub(c1Vec); + FloatVector djkVec2 = v1Vec.sub(c2Vec); + FloatVector djkVec3 = v1Vec.sub(c3Vec); + projVec0 = fma(djkVec0, originalResidualVec, projVec0); + projVec1 = fma(djkVec1, originalResidualVec, projVec1); + projVec2 = fma(djkVec2, originalResidualVec, projVec2); + projVec3 = fma(djkVec3, originalResidualVec, projVec3); + acc0 = fma(djkVec0, djkVec0, acc0); + acc1 = fma(djkVec1, djkVec1, acc1); + acc2 = fma(djkVec2, djkVec2, acc2); + acc3 = fma(djkVec3, djkVec3, acc3); + } + float proj0 = projVec0.reduceLanes(ADD); + float dsq0 = acc0.reduceLanes(ADD); + float proj1 = projVec1.reduceLanes(ADD); + float dsq1 = acc1.reduceLanes(ADD); + float proj2 = projVec2.reduceLanes(ADD); + float dsq2 = acc2.reduceLanes(ADD); + float proj3 = projVec3.reduceLanes(ADD); + float dsq3 = acc3.reduceLanes(ADD); + // tail + for (; i < v1.length; i++) { + float v = v1[i]; + float djk0 = v - c0[i]; + float djk1 = v - c1[i]; + float djk2 = v - c2[i]; + float djk3 = v - c3[i]; + proj0 = fma(djk0, originalResidual[i], proj0); + proj1 = fma(djk1, originalResidual[i], proj1); + proj2 = fma(djk2, originalResidual[i], proj2); + proj3 = fma(djk3, originalResidual[i], proj3); + dsq0 = fma(djk0, djk0, dsq0); + dsq1 = fma(djk1, djk1, dsq1); + dsq2 = fma(djk2, djk2, dsq2); + dsq3 = fma(djk3, djk3, dsq3); + } + distances[0] = dsq0 + soarLambda * proj0 * proj0 / rnorm; + distances[1] = dsq1 + soarLambda * proj1 * proj1 / rnorm; + distances[2] = dsq2 + soarLambda * proj2 * proj2 / rnorm; + distances[3] = dsq3 + soarLambda * proj3 * proj3 / rnorm; + } } diff --git a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java index 9cc215ba595a7..b51fc25fab9f1 100644 --- a/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java +++ b/libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java @@ -323,6 +323,46 @@ public void testQuantizeVectorWithIntervals() { assertEquals(expected, result, 0f); } + public void testSquareDistanceBulk() { + int vectorSize = randomIntBetween(1, 2048); + float[] query = generateRandomVector(vectorSize); + float[] v0 = generateRandomVector(vectorSize); + float[] v1 = generateRandomVector(vectorSize); + float[] v2 = generateRandomVector(vectorSize); + float[] v3 = generateRandomVector(vectorSize); + float[] expectedDistances = new float[4]; + float[] panamaDistances = new float[4]; + defaultedProvider.getVectorUtilSupport().squareDistanceBulk(query, v0, v1, v2, v3, expectedDistances); + defOrPanamaProvider.getVectorUtilSupport().squareDistanceBulk(query, v0, v1, v2, v3, panamaDistances); + assertArrayEquals(expectedDistances, panamaDistances, 1e-3f); + } + + public void testSoarDistanceBulk() { + int vectorSize = randomIntBetween(1, 2048); + float deltaEps = 1e-3f * vectorSize; + float[] query = generateRandomVector(vectorSize); + float[] v0 = generateRandomVector(vectorSize); + float[] v1 = generateRandomVector(vectorSize); + float[] v2 = generateRandomVector(vectorSize); + float[] v3 = generateRandomVector(vectorSize); + float[] diff = generateRandomVector(vectorSize); + float soarLambda = random().nextFloat(); + float rnorm = random().nextFloat(); + float[] expectedDistances = new float[4]; + float[] panamaDistances = new float[4]; + defaultedProvider.getVectorUtilSupport().soarDistanceBulk(query, v0, v1, v2, v3, diff, soarLambda, rnorm, expectedDistances); + defOrPanamaProvider.getVectorUtilSupport().soarDistanceBulk(query, v0, v1, v2, v3, diff, soarLambda, rnorm, panamaDistances); + assertArrayEquals(expectedDistances, panamaDistances, deltaEps); + } + + private float[] generateRandomVector(int size) { + float[] vector = new float[size]; + for (int i = 0; i < size; ++i) { + vector[i] = random().nextFloat(); + } + return vector; + } + void testIpByteBinImpl(ToLongBiFunction ipByteBinFunc) { int iterations = atLeast(50); for (int i = 0; i < iterations; i++) { diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java index a1e480fb73266..0aabdc9d74590 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java @@ -81,16 +81,16 @@ private static boolean stepLloyd( for (float[] nextCentroid : nextCentroids) { Arrays.fill(nextCentroid, 0.0f); } - + final float[] distances = new float[4]; for (int idx = 0; idx < vectors.size(); idx++) { float[] vector = vectors.vectorValue(idx); int vectorOrd = translateOrd.apply(idx); final int assignment = assignments[vectorOrd]; final int bestCentroidOffset; if (neighborhoods != null) { - bestCentroidOffset = getBestCentroidFromNeighbours(centroids, vector, assignment, neighborhoods[assignment]); + bestCentroidOffset = getBestCentroidFromNeighbours(centroids, vector, assignment, neighborhoods[assignment], distances); } else { - bestCentroidOffset = getBestCentroid(centroids, vector); + bestCentroidOffset = getBestCentroid(centroids, vector, distances); } if (assignment != bestCentroidOffset) { assignments[vectorOrd] = bestCentroidOffset; @@ -114,19 +114,49 @@ private static boolean stepLloyd( return changed; } - private static int getBestCentroidFromNeighbours(float[][] centroids, float[] vector, int centroidIdx, NeighborHood neighborhood) { + private static int getBestCentroidFromNeighbours( + float[][] centroids, + float[] vector, + int centroidIdx, + NeighborHood neighborhood, + float[] distances + ) { + final int limit = neighborhood.neighbors.length - 3; int bestCentroidOffset = centroidIdx; assert centroidIdx >= 0 && centroidIdx < centroids.length; float minDsq = VectorUtil.squareDistance(vector, centroids[centroidIdx]); - for (int i = 0; i < neighborhood.neighbors.length; i++) { - int offset = neighborhood.neighbors[i]; - // float score = neighborhood.scores[i]; - assert offset >= 0 && offset < centroids.length : "Invalid neighbor offset: " + offset; + int i = 0; + for (; i < limit; i += 4) { if (minDsq < neighborhood.maxIntraDistance) { // if the distance found is smaller than the maximum intra-cluster distance // we don't consider it for further re-assignment return bestCentroidOffset; } + ESVectorUtil.squareDistanceBulk( + vector, + centroids[neighborhood.neighbors[i]], + centroids[neighborhood.neighbors[i + 1]], + centroids[neighborhood.neighbors[i + 2]], + centroids[neighborhood.neighbors[i + 3]], + distances + ); + for (int j = 0; j < distances.length; j++) { + float dsq = distances[j]; + if (dsq < minDsq) { + minDsq = dsq; + bestCentroidOffset = neighborhood.neighbors[i + j]; + } + } + } + for (; i < neighborhood.neighbors.length; i++) { + if (minDsq < neighborhood.maxIntraDistance) { + // if the distance found is smaller than the maximum intra-cluster distance + // we don't consider it for further re-assignment + return bestCentroidOffset; + } + int offset = neighborhood.neighbors[i]; + // float score = neighborhood.scores[i]; + assert offset >= 0 && offset < centroids.length : "Invalid neighbor offset: " + offset; // compute the distance to the centroid float dsq = VectorUtil.squareDistance(vector, centroids[offset]); if (dsq < minDsq) { @@ -137,10 +167,22 @@ private static int getBestCentroidFromNeighbours(float[][] centroids, float[] ve return bestCentroidOffset; } - private static int getBestCentroid(float[][] centroids, float[] vector) { + private static int getBestCentroid(float[][] centroids, float[] vector, float[] distances) { + final int limit = centroids.length - 3; int bestCentroidOffset = 0; float minDsq = Float.MAX_VALUE; - for (int i = 0; i < centroids.length; i++) { + int i = 0; + for (; i < limit; i += 4) { + ESVectorUtil.squareDistanceBulk(vector, centroids[i], centroids[i + 1], centroids[i + 2], centroids[i + 3], distances); + for (int j = 0; j < distances.length; j++) { + float dsq = distances[j]; + if (dsq < minDsq) { + minDsq = dsq; + bestCentroidOffset = i + j; + } + } + } + for (; i < centroids.length; i++) { float dsq = VectorUtil.squareDistance(vector, centroids[i]); if (dsq < minDsq) { minDsq = dsq; @@ -157,9 +199,20 @@ private NeighborHood[] computeNeighborhoods(float[][] centers, int clustersPerNe for (int i = 0; i < k; i++) { neighborQueues[i] = new NeighborQueue(clustersPerNeighborhood, true); } + final float[] scores = new float[4]; + final int limit = k - 3; for (int i = 0; i < k - 1; i++) { - for (int j = i + 1; j < k; j++) { - float dsq = VectorUtil.squareDistance(centers[i], centers[j]); + float[] center = centers[i]; + int j = i + 1; + for (; j < limit; j += 4) { + ESVectorUtil.squareDistanceBulk(center, centers[j], centers[j + 1], centers[j + 2], centers[j + 3], scores); + for (int h = 0; h < 4; h++) { + neighborQueues[j + h].insertWithOverflow(i, scores[h]); + neighborQueues[i].insertWithOverflow(j + h, scores[h]); + } + } + for (; j < k; j++) { + float dsq = VectorUtil.squareDistance(center, centers[j]); neighborQueues[j].insertWithOverflow(i, dsq); neighborQueues[i].insertWithOverflow(j, dsq); } @@ -208,6 +261,7 @@ private void assignSpilled( float[][] centroids = kmeansIntermediate.centroids(); float[] diffs = new float[vectors.dimension()]; + final float[] distances = new float[4]; for (int i = 0; i < vectors.size(); i++) { float[] vector = vectors.vectorValue(i); @@ -219,33 +273,66 @@ private void assignSpilled( if (vectorCentroidDist > SOAR_MIN_DISTANCE) { for (int j = 0; j < vectors.dimension(); j++) { - float diff = vector[j] - currentCentroid[j]; - diffs[j] = diff; + diffs[j] = vector[j] - currentCentroid[j]; } } - int bestAssignment = -1; - float minSoar = Float.MAX_VALUE; - int centroidCount = centroids.length; - IntToIntFunction centroidOrds = c -> c; + final int centroidCount; + final IntToIntFunction centroidOrds; if (neighborhoods != null) { assert neighborhoods[currAssignment] != null; NeighborHood neighborhood = neighborhoods[currAssignment]; centroidCount = neighborhood.neighbors.length; centroidOrds = c -> neighborhood.neighbors[c]; + } else { + centroidCount = centroids.length - 1; + centroidOrds = c -> c < currAssignment ? c : c + 1; // skip the current centroid } - for (int j = 0; j < centroidCount; j++) { - int centroidOrd = centroidOrds.apply(j); - if (centroidOrd == currAssignment) { - continue; // skip the current assignment + final int limit = centroidCount - 3; + int bestAssignment = -1; + float minSoar = Float.MAX_VALUE; + int j = 0; + for (; j < limit; j += 4) { + if (vectorCentroidDist > SOAR_MIN_DISTANCE) { + ESVectorUtil.soarDistanceBulk( + vector, + centroids[centroidOrds.apply(j)], + centroids[centroidOrds.apply(j + 1)], + centroids[centroidOrds.apply(j + 2)], + centroids[centroidOrds.apply(j + 3)], + diffs, + soarLambda, + vectorCentroidDist, + distances + ); + } else { + // if the vector is very close to the centroid, we look for the second-nearest centroid + ESVectorUtil.squareDistanceBulk( + vector, + centroids[centroidOrds.apply(j)], + centroids[centroidOrds.apply(j + 1)], + centroids[centroidOrds.apply(j + 2)], + centroids[centroidOrds.apply(j + 3)], + distances + ); } - float[] centroid = centroids[centroidOrd]; + for (int k = 0; k < distances.length; k++) { + float soar = distances[k]; + if (soar < minSoar) { + minSoar = soar; + bestAssignment = centroidOrds.apply(j + k); + } + } + } + + for (; j < centroidCount; j++) { + int centroidOrd = centroidOrds.apply(j); float soar; if (vectorCentroidDist > SOAR_MIN_DISTANCE) { - soar = ESVectorUtil.soarDistance(vector, centroid, diffs, soarLambda, vectorCentroidDist); + soar = ESVectorUtil.soarDistance(vector, centroids[centroidOrd], diffs, soarLambda, vectorCentroidDist); } else { // if the vector is very close to the centroid, we look for the second-nearest centroid - soar = VectorUtil.squareDistance(vector, centroid); + soar = VectorUtil.squareDistance(vector, centroids[centroidOrd]); } if (soar < minSoar) { minSoar = soar;