Skip to content

Commit 6578b9e

Browse files
authored
Speed up hierarchical k-means by computing distances in bulk (#132384)
This commit adds on-heap bulk distance computations. In particular, it implements the methods `ESVectorUtil#squareDistanceBulk` and `ESVectorUtil#soarDistanceBulk`, each of which computes four distances in one method call.
1 parent fcf0408 commit 6578b9e

File tree

7 files changed

+529
-25
lines changed

7 files changed

+529
-25
lines changed
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
package org.elasticsearch.benchmark.vector;
10+
11+
import org.apache.lucene.util.VectorUtil;
12+
import org.apache.lucene.util.quantization.OptimizedScalarQuantizer;
13+
import org.elasticsearch.common.logging.LogConfigurator;
14+
import org.elasticsearch.simdvec.ESVectorUtil;
15+
import org.openjdk.jmh.annotations.Benchmark;
16+
import org.openjdk.jmh.annotations.BenchmarkMode;
17+
import org.openjdk.jmh.annotations.Fork;
18+
import org.openjdk.jmh.annotations.Measurement;
19+
import org.openjdk.jmh.annotations.Mode;
20+
import org.openjdk.jmh.annotations.OutputTimeUnit;
21+
import org.openjdk.jmh.annotations.Param;
22+
import org.openjdk.jmh.annotations.Scope;
23+
import org.openjdk.jmh.annotations.Setup;
24+
import org.openjdk.jmh.annotations.State;
25+
import org.openjdk.jmh.annotations.Warmup;
26+
import org.openjdk.jmh.infra.Blackhole;
27+
28+
import java.io.IOException;
29+
import java.util.Random;
30+
import java.util.concurrent.TimeUnit;
31+
32+
@BenchmarkMode(Mode.Throughput)
33+
@OutputTimeUnit(TimeUnit.MILLISECONDS)
34+
@State(Scope.Benchmark)
35+
// first iteration is complete garbage, so make sure we really warmup
36+
@Warmup(iterations = 4, time = 1)
37+
// real iterations. not useful to spend tons of time here, better to fork more
38+
@Measurement(iterations = 5, time = 1)
39+
// engage some noise reduction
40+
@Fork(value = 1)
41+
public class DistanceBulkBenchmark {
42+
43+
static {
44+
LogConfigurator.configureESLogging(); // native access requires logging to be initialized
45+
}
46+
47+
@Param({ "384", "782", "1024" })
48+
int dims;
49+
50+
int length;
51+
52+
int numVectors = 4 * 100;
53+
int numQueries = 10;
54+
55+
float[][] vectors;
56+
float[][] queries;
57+
float[] distances = new float[4];
58+
59+
@Setup
60+
public void setup() throws IOException {
61+
Random random = new Random(123);
62+
63+
this.length = OptimizedScalarQuantizer.discretize(dims, 64) / 8;
64+
65+
vectors = new float[numVectors][dims];
66+
for (float[] vector : vectors) {
67+
for (int i = 0; i < dims; i++) {
68+
vector[i] = random.nextFloat();
69+
}
70+
}
71+
72+
queries = new float[numQueries][dims];
73+
for (float[] query : queries) {
74+
for (int i = 0; i < dims; i++) {
75+
query[i] = random.nextFloat();
76+
}
77+
}
78+
}
79+
80+
@Benchmark
81+
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
82+
public void squareDistance(Blackhole bh) {
83+
for (int j = 0; j < numQueries; j++) {
84+
float[] query = queries[j];
85+
for (int i = 0; i < numVectors; i++) {
86+
float[] vector = vectors[i];
87+
float distance = VectorUtil.squareDistance(query, vector);
88+
bh.consume(distance);
89+
}
90+
}
91+
}
92+
93+
@Benchmark
94+
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
95+
public void soarDistance(Blackhole bh) {
96+
for (int j = 0; j < numQueries; j++) {
97+
float[] query = queries[j];
98+
for (int i = 0; i < numVectors; i++) {
99+
float[] vector = vectors[i];
100+
float distance = ESVectorUtil.soarDistance(query, vector, vector, 1.0f, 1.0f);
101+
bh.consume(distance);
102+
}
103+
}
104+
}
105+
106+
@Benchmark
107+
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
108+
public void squareDistanceBulk(Blackhole bh) {
109+
for (int j = 0; j < numQueries; j++) {
110+
float[] query = queries[j];
111+
for (int i = 0; i < numVectors; i += 4) {
112+
ESVectorUtil.squareDistanceBulk(query, vectors[i], vectors[i + 1], vectors[i + 2], vectors[i + 3], distances);
113+
for (float distance : distances) {
114+
bh.consume(distance);
115+
}
116+
117+
}
118+
}
119+
}
120+
121+
@Benchmark
122+
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
123+
public void soarDistanceBulk(Blackhole bh) {
124+
for (int j = 0; j < numQueries; j++) {
125+
float[] query = queries[j];
126+
for (int i = 0; i < numVectors; i += 4) {
127+
ESVectorUtil.soarDistanceBulk(
128+
query,
129+
vectors[i],
130+
vectors[i + 1],
131+
vectors[i + 2],
132+
vectors[i + 3],
133+
vectors[i],
134+
1.0f,
135+
1.0f,
136+
distances
137+
);
138+
for (float distance : distances) {
139+
bh.consume(distance);
140+
}
141+
142+
}
143+
}
144+
}
145+
}

libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,4 +293,79 @@ public static int quantizeVectorWithIntervals(float[] vector, int[] destination,
293293
}
294294
return IMPL.quantizeVectorWithIntervals(vector, destination, lowInterval, upperInterval, bit);
295295
}
296+
297+
/**
298+
* Bulk computation of square distances between a query vector and four vectors.Result is stored in the provided distances array.
299+
*
300+
* @param q the query vector
301+
* @param v0 the first vector
302+
* @param v1 the second vector
303+
* @param v2 the third vector
304+
* @param v3 the fourth vector
305+
* @param distances an array to store the computed square distances, must have length 4
306+
*
307+
* @throws IllegalArgumentException if the dimensions of the vectors do not match or if the distances array does not have length 4
308+
*/
309+
public static void squareDistanceBulk(float[] q, float[] v0, float[] v1, float[] v2, float[] v3, float[] distances) {
310+
if (q.length != v0.length) {
311+
throw new IllegalArgumentException("vector dimensions differ: " + q.length + "!=" + v0.length);
312+
}
313+
if (q.length != v1.length) {
314+
throw new IllegalArgumentException("vector dimensions differ: " + q.length + "!=" + v1.length);
315+
}
316+
if (q.length != v2.length) {
317+
throw new IllegalArgumentException("vector dimensions differ: " + q.length + "!=" + v2.length);
318+
}
319+
if (q.length != v3.length) {
320+
throw new IllegalArgumentException("vector dimensions differ: " + q.length + "!=" + v3.length);
321+
}
322+
if (distances.length != 4) {
323+
throw new IllegalArgumentException("distances array must have length 4, but was: " + distances.length);
324+
}
325+
IMPL.squareDistanceBulk(q, v0, v1, v2, v3, distances);
326+
}
327+
328+
/**
329+
* Bulk computation of the soar distance for a vector to four centroids
330+
* @param v1 the vector
331+
* @param c0 the first centroid
332+
* @param c1 the second centroid
333+
* @param c2 the third centroid
334+
* @param c3 the fourth centroid
335+
* @param originalResidual the residual with the actually nearest centroid
336+
* @param soarLambda the lambda parameter
337+
* @param rnorm distance to the nearest centroid
338+
* @param distances an array to store the computed soar distances, must have length 4
339+
*/
340+
public static void soarDistanceBulk(
341+
float[] v1,
342+
float[] c0,
343+
float[] c1,
344+
float[] c2,
345+
float[] c3,
346+
float[] originalResidual,
347+
float soarLambda,
348+
float rnorm,
349+
float[] distances
350+
) {
351+
if (v1.length != c0.length) {
352+
throw new IllegalArgumentException("vector dimensions differ: " + v1.length + "!=" + c0.length);
353+
}
354+
if (v1.length != c1.length) {
355+
throw new IllegalArgumentException("vector dimensions differ: " + v1.length + "!=" + c1.length);
356+
}
357+
if (v1.length != c2.length) {
358+
throw new IllegalArgumentException("vector dimensions differ: " + v1.length + "!=" + c2.length);
359+
}
360+
if (v1.length != c3.length) {
361+
throw new IllegalArgumentException("vector dimensions differ: " + v1.length + "!=" + c3.length);
362+
}
363+
if (v1.length != originalResidual.length) {
364+
throw new IllegalArgumentException("vector dimensions differ: " + v1.length + "!=" + originalResidual.length);
365+
}
366+
if (distances.length != 4) {
367+
throw new IllegalArgumentException("distances array must have length 4, but was: " + distances.length);
368+
}
369+
IMPL.soarDistanceBulk(v1, c0, c1, c2, c3, originalResidual, soarLambda, rnorm, distances);
370+
}
296371
}

libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -293,4 +293,30 @@ public int quantizeVectorWithIntervals(float[] vector, int[] destination, float
293293
}
294294
return sumQuery;
295295
}
296+
297+
@Override
298+
public void squareDistanceBulk(float[] query, float[] v0, float[] v1, float[] v2, float[] v3, float[] distances) {
299+
distances[0] = VectorUtil.squareDistance(query, v0);
300+
distances[1] = VectorUtil.squareDistance(query, v1);
301+
distances[2] = VectorUtil.squareDistance(query, v2);
302+
distances[3] = VectorUtil.squareDistance(query, v3);
303+
}
304+
305+
@Override
306+
public void soarDistanceBulk(
307+
float[] v1,
308+
float[] c0,
309+
float[] c1,
310+
float[] c2,
311+
float[] c3,
312+
float[] originalResidual,
313+
float soarLambda,
314+
float rnorm,
315+
float[] distances
316+
) {
317+
distances[0] = soarDistance(v1, c0, originalResidual, soarLambda, rnorm);
318+
distances[1] = soarDistance(v1, c1, originalResidual, soarLambda, rnorm);
319+
distances[2] = soarDistance(v1, c2, originalResidual, soarLambda, rnorm);
320+
distances[3] = soarDistance(v1, c3, originalResidual, soarLambda, rnorm);
321+
}
296322
}

libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,4 +50,17 @@ float calculateOSQLoss(
5050

5151
int quantizeVectorWithIntervals(float[] vector, int[] quantize, float lowInterval, float upperInterval, byte bit);
5252

53+
/**
 * Computes the square distance between {@code query} and each of {@code v0}..{@code v3},
 * writing the four results into {@code distances}.
 *
 * @param query the query vector
 * @param v0 the first vector
 * @param v1 the second vector
 * @param v2 the third vector
 * @param v3 the fourth vector
 * @param distances output array for the four square distances
 */
void squareDistanceBulk(float[] query, float[] v0, float[] v1, float[] v2, float[] v3, float[] distances);
54+
55+
/**
 * Computes the soar distance of {@code v1} to each of the centroids {@code c0}..{@code c3},
 * writing the four results into {@code distances}.
 *
 * @param v1 the vector
 * @param c0 the first centroid
 * @param c1 the second centroid
 * @param c2 the third centroid
 * @param c3 the fourth centroid
 * @param originalResidual the residual with the actually nearest centroid
 * @param soarLambda the lambda parameter
 * @param rnorm distance to the nearest centroid
 * @param distances output array for the four soar distances
 */
void soarDistanceBulk(
56+
float[] v1,
57+
float[] c0,
58+
float[] c1,
59+
float[] c2,
60+
float[] c3,
61+
float[] originalResidual,
62+
float soarLambda,
63+
float rnorm,
64+
float[] distances
65+
);
5366
}

libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -822,4 +822,122 @@ public int quantizeVectorWithIntervals(float[] vector, int[] destination, float
822822
}
823823
return sumQuery;
824824
}
825+
826+
@Override
827+
public void squareDistanceBulk(float[] query, float[] v0, float[] v1, float[] v2, float[] v3, float[] distances) {
828+
FloatVector sv0 = FloatVector.zero(FLOAT_SPECIES);
829+
FloatVector sv1 = FloatVector.zero(FLOAT_SPECIES);
830+
FloatVector sv2 = FloatVector.zero(FLOAT_SPECIES);
831+
FloatVector sv3 = FloatVector.zero(FLOAT_SPECIES);
832+
final int limit = FLOAT_SPECIES.loopBound(query.length);
833+
int i = 0;
834+
for (; i < limit; i += FLOAT_SPECIES.length()) {
835+
FloatVector qv = FloatVector.fromArray(FLOAT_SPECIES, query, i);
836+
FloatVector dv0 = FloatVector.fromArray(FLOAT_SPECIES, v0, i);
837+
FloatVector dv1 = FloatVector.fromArray(FLOAT_SPECIES, v1, i);
838+
FloatVector dv2 = FloatVector.fromArray(FLOAT_SPECIES, v2, i);
839+
FloatVector dv3 = FloatVector.fromArray(FLOAT_SPECIES, v3, i);
840+
FloatVector diff0 = qv.sub(dv0);
841+
sv0 = fma(diff0, diff0, sv0);
842+
FloatVector diff1 = qv.sub(dv1);
843+
sv1 = fma(diff1, diff1, sv1);
844+
FloatVector diff2 = qv.sub(dv2);
845+
sv2 = fma(diff2, diff2, sv2);
846+
FloatVector diff3 = qv.sub(dv3);
847+
sv3 = fma(diff3, diff3, sv3);
848+
}
849+
float distance0 = sv0.reduceLanes(VectorOperators.ADD);
850+
float distance1 = sv1.reduceLanes(VectorOperators.ADD);
851+
float distance2 = sv2.reduceLanes(VectorOperators.ADD);
852+
float distance3 = sv3.reduceLanes(VectorOperators.ADD);
853+
854+
for (; i < query.length; i++) {
855+
final float qValue = query[i];
856+
final float diff0 = qValue - v0[i];
857+
final float diff1 = qValue - v1[i];
858+
final float diff2 = qValue - v2[i];
859+
final float diff3 = qValue - v3[i];
860+
distance0 = fma(diff0, diff0, distance0);
861+
distance1 = fma(diff1, diff1, distance1);
862+
distance2 = fma(diff2, diff2, distance2);
863+
distance3 = fma(diff3, diff3, distance3);
864+
}
865+
distances[0] = distance0;
866+
distances[1] = distance1;
867+
distances[2] = distance2;
868+
distances[3] = distance3;
869+
}
870+
871+
@Override
872+
public void soarDistanceBulk(
873+
float[] v1,
874+
float[] c0,
875+
float[] c1,
876+
float[] c2,
877+
float[] c3,
878+
float[] originalResidual,
879+
float soarLambda,
880+
float rnorm,
881+
float[] distances
882+
) {
883+
884+
FloatVector projVec0 = FloatVector.zero(FLOAT_SPECIES);
885+
FloatVector projVec1 = FloatVector.zero(FLOAT_SPECIES);
886+
FloatVector projVec2 = FloatVector.zero(FLOAT_SPECIES);
887+
FloatVector projVec3 = FloatVector.zero(FLOAT_SPECIES);
888+
FloatVector acc0 = FloatVector.zero(FLOAT_SPECIES);
889+
FloatVector acc1 = FloatVector.zero(FLOAT_SPECIES);
890+
FloatVector acc2 = FloatVector.zero(FLOAT_SPECIES);
891+
FloatVector acc3 = FloatVector.zero(FLOAT_SPECIES);
892+
final int limit = FLOAT_SPECIES.loopBound(v1.length);
893+
int i = 0;
894+
for (; i < limit; i += FLOAT_SPECIES.length()) {
895+
FloatVector v1Vec = FloatVector.fromArray(FLOAT_SPECIES, v1, i);
896+
FloatVector c0Vec = FloatVector.fromArray(FLOAT_SPECIES, c0, i);
897+
FloatVector c1Vec = FloatVector.fromArray(FLOAT_SPECIES, c1, i);
898+
FloatVector c2Vec = FloatVector.fromArray(FLOAT_SPECIES, c2, i);
899+
FloatVector c3Vec = FloatVector.fromArray(FLOAT_SPECIES, c3, i);
900+
FloatVector originalResidualVec = FloatVector.fromArray(FLOAT_SPECIES, originalResidual, i);
901+
FloatVector djkVec0 = v1Vec.sub(c0Vec);
902+
FloatVector djkVec1 = v1Vec.sub(c1Vec);
903+
FloatVector djkVec2 = v1Vec.sub(c2Vec);
904+
FloatVector djkVec3 = v1Vec.sub(c3Vec);
905+
projVec0 = fma(djkVec0, originalResidualVec, projVec0);
906+
projVec1 = fma(djkVec1, originalResidualVec, projVec1);
907+
projVec2 = fma(djkVec2, originalResidualVec, projVec2);
908+
projVec3 = fma(djkVec3, originalResidualVec, projVec3);
909+
acc0 = fma(djkVec0, djkVec0, acc0);
910+
acc1 = fma(djkVec1, djkVec1, acc1);
911+
acc2 = fma(djkVec2, djkVec2, acc2);
912+
acc3 = fma(djkVec3, djkVec3, acc3);
913+
}
914+
float proj0 = projVec0.reduceLanes(ADD);
915+
float dsq0 = acc0.reduceLanes(ADD);
916+
float proj1 = projVec1.reduceLanes(ADD);
917+
float dsq1 = acc1.reduceLanes(ADD);
918+
float proj2 = projVec2.reduceLanes(ADD);
919+
float dsq2 = acc2.reduceLanes(ADD);
920+
float proj3 = projVec3.reduceLanes(ADD);
921+
float dsq3 = acc3.reduceLanes(ADD);
922+
// tail
923+
for (; i < v1.length; i++) {
924+
float v = v1[i];
925+
float djk0 = v - c0[i];
926+
float djk1 = v - c1[i];
927+
float djk2 = v - c2[i];
928+
float djk3 = v - c3[i];
929+
proj0 = fma(djk0, originalResidual[i], proj0);
930+
proj1 = fma(djk1, originalResidual[i], proj1);
931+
proj2 = fma(djk2, originalResidual[i], proj2);
932+
proj3 = fma(djk3, originalResidual[i], proj3);
933+
dsq0 = fma(djk0, djk0, dsq0);
934+
dsq1 = fma(djk1, djk1, dsq1);
935+
dsq2 = fma(djk2, djk2, dsq2);
936+
dsq3 = fma(djk3, djk3, dsq3);
937+
}
938+
distances[0] = dsq0 + soarLambda * proj0 * proj0 / rnorm;
939+
distances[1] = dsq1 + soarLambda * proj1 * proj1 / rnorm;
940+
distances[2] = dsq2 + soarLambda * proj2 * proj2 / rnorm;
941+
distances[3] = dsq3 + soarLambda * proj3 * proj3 / rnorm;
942+
}
825943
}

0 commit comments

Comments
 (0)