Skip to content

Commit a156056

Browse files
committed
vector methods
1 parent f6ffd56 commit a156056

File tree

5 files changed

+67
-10
lines changed

5 files changed

+67
-10
lines changed

libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,4 +435,15 @@ public static int indexOf(byte[] bytes, int offset, int length, byte marker) {
435435
Objects.checkFromIndexSize(offset, length, bytes.length);
436436
return IMPL.indexOf(bytes, offset, length, marker);
437437
}
438+
439+
public static void vectorAccumulateAdd(float[] a, float[] b) {
440+
if (a.length != b.length) {
441+
throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length);
442+
}
443+
IMPL.vectorAccumulateAdd(a, b);
444+
}
445+
446+
public static void vectorScalerDivide(float [] a, float b) {
447+
IMPL.vectorScalerDivide(a, b);
448+
}
438449
}

libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
import org.apache.lucene.util.Constants;
1414
import org.apache.lucene.util.VectorUtil;
1515

16+
import java.util.Arrays;
17+
import java.util.concurrent.atomic.AtomicLong;
18+
1619
final class DefaultESVectorUtilSupport implements ESVectorUtilSupport {
1720

1821
private static float fma(float a, float b, float c) {
@@ -455,4 +458,16 @@ public static void transposeHalfByteImpl(int[] q, byte[] quantQueryByte) {
455458
public int indexOf(byte[] bytes, int offset, int length, byte marker) {
456459
return ByteArrayUtils.indexOf(bytes, offset, length, marker);
457460
}
461+
462+
public void vectorScalerDivide(float [] a, float b) {
463+
for (int d = 0; d < a.length; d++) {
464+
a[d] /= b;
465+
}
466+
}
467+
468+
public void vectorAccumulateAdd(float[] a, float[] b) {
469+
for (int d = 0; d < a.length; d++) {
470+
a[d] += b[d];
471+
}
472+
}
458473
}

libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,4 +71,8 @@ void soarDistanceBulk(
7171
void transposeHalfByte(int[] q, byte[] quantQueryByte);
7272

7373
int indexOf(byte[] bytes, int offset, int length, byte marker);
74+
75+
/**
 * Divides every element of {@code a} by the scalar {@code b}, writing the result back into {@code a}.
 *
 * @param a the vector to scale; modified in place
 * @param b the divisor applied to every element
 */
void vectorScalerDivide(float [] a, float b);
76+
77+
/**
 * Adds {@code b} to {@code a} element-wise, writing the sums back into {@code a}.
 *
 * @param a the accumulator vector; modified in place
 * @param b the vector whose elements are added to {@code a}
 */
void vectorAccumulateAdd(float[] a, float[] b);
7478
}

libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1140,4 +1140,37 @@ public int indexOf(final byte[] bytes, final int offset, final int length, final
11401140
}
11411141
return -1;
11421142
}
1143+
1144+
@Override
1145+
public void vectorAccumulateAdd(float[] a, float[] b) {
1146+
final VectorSpecies<Float> SPECIES = FloatVector.SPECIES_PREFERRED;
1147+
1148+
int i = 0;
1149+
for (; i < SPECIES.loopBound(a.length); i += SPECIES.length()) {
1150+
FloatVector va = FloatVector.fromArray(SPECIES, a, i);
1151+
FloatVector vb = FloatVector.fromArray(SPECIES, b, i);
1152+
FloatVector vc = va.add(vb);
1153+
vc.intoArray(a, i);
1154+
}
1155+
1156+
for (; i < a.length; i++) {
1157+
a[i] += b[i];
1158+
}
1159+
}
1160+
1161+
@Override
1162+
public void vectorScalerDivide(float[] a, float b) {
1163+
final VectorSpecies<Float> SPECIES = FloatVector.SPECIES_PREFERRED;
1164+
1165+
int i = 0;
1166+
for (; i < SPECIES.loopBound(a.length); i += SPECIES.length()) {
1167+
FloatVector va = FloatVector.fromArray(SPECIES, a, i);
1168+
FloatVector vc = va.div(b);
1169+
vc.intoArray(a, i);
1170+
}
1171+
1172+
for (; i < a.length; i++) {
1173+
a[i] = a[i] / b;
1174+
}
1175+
}
11431176
}

server/src/main/java/org/elasticsearch/index/codec/vectors/cluster/KMeansLocal.java

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -106,25 +106,19 @@ private static boolean stepLloyd(
106106
for (int idx = 0; idx < vectors.size(); idx++) {
107107
final int assignment = assignments[translateOrd.apply(idx)];
108108
if (centroidChanged.get(assignment)) {
109-
float[] centroid = centroids[assignment];
110109
if (centroidCounts[assignment]++ == 0) {
111-
Arrays.fill(centroid, 0.0f);
112-
}
113-
float[] vector = vectors.vectorValue(idx);
114-
for (int d = 0; d < dim; d++) {
115-
centroid[d] += vector[d];
110+
centroids[assignment] = vectors.vectorValue(idx);
111+
continue;
116112
}
113+
ESVectorUtil.vectorAccumulateAdd(centroids[assignment], vectors.vectorValue(idx));
117114
}
118115
}
119116

120117
for (int clusterIdx = 0; clusterIdx < centroids.length; clusterIdx++) {
121118
if (centroidChanged.get(clusterIdx)) {
122119
float count = (float) centroidCounts[clusterIdx];
123120
if (count > 0) {
124-
float[] centroid = centroids[clusterIdx];
125-
for (int d = 0; d < dim; d++) {
126-
centroid[d] /= count;
127-
}
121+
ESVectorUtil.vectorScalerDivide(centroids[clusterIdx], count);
128122
}
129123
}
130124
}

0 commit comments

Comments
 (0)