Skip to content

Commit 5b3863e

Browse files
committed
Vectorize BQVectorUtils#packAsBinary
1 parent 26ffd7f commit 5b3863e

File tree

7 files changed

+141
-25
lines changed

7 files changed

+141
-25
lines changed

benchmarks/src/main/java/org/elasticsearch/benchmark/vector/PackAsBinaryBenchmark.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,4 +82,13 @@ public void packAsBinaryLegacy(Blackhole bh) {
8282
bh.consume(packed);
8383
}
8484
}
85+
86+
@Benchmark
87+
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
88+
public void packAsBinaryPanama(Blackhole bh) {
89+
for (int i = 0; i < numVectors; i++) {
90+
BQVectorUtils.packAsBinary(qVectors[i], packed);
91+
bh.consume(packed);
92+
}
93+
}
8594
}

libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,4 +368,17 @@ public static void soarDistanceBulk(
368368
}
369369
IMPL.soarDistanceBulk(v1, c0, c1, c2, c3, originalResidual, soarLambda, rnorm, distances);
370370
}
371+
372+
/**
373+
* Packs the provided int array populated with "0" and "1" values into a byte array.
374+
*
375+
* @param vector the int array to pack, must contain only "0" and "1" values.
376+
* @param packed the byte array to store the packed result, must be large enough to hold the packed data.
377+
*/
378+
public static void packAsBinary(int[] vector, byte[] packed) {
379+
if (packed.length * Byte.SIZE < vector.length) {
380+
throw new IllegalArgumentException("packed array is too small: " + packed.length * Byte.SIZE + " < " + vector.length);
381+
}
382+
IMPL.packAsBinary(vector, packed);
383+
}
371384
}

libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,4 +320,37 @@ public void soarDistanceBulk(
320320
distances[2] = soarDistance(v1, c2, originalResidual, soarLambda, rnorm);
321321
distances[3] = soarDistance(v1, c3, originalResidual, soarLambda, rnorm);
322322
}
323+
324+
@Override
325+
public void packAsBinary(int[] vector, byte[] packed) {
326+
packAsBinaryImpl(vector, packed);
327+
}
328+
329+
public static void packAsBinaryImpl(int[] vector, byte[] packed) {
330+
int limit = vector.length - 7;
331+
int i = 0;
332+
int index = 0;
333+
for (; i < limit; i += 8, index++) {
334+
assert vector[i] == 0 || vector[i] == 1;
335+
assert vector[i + 1] == 0 || vector[i + 1] == 1;
336+
assert vector[i + 2] == 0 || vector[i + 2] == 1;
337+
assert vector[i + 3] == 0 || vector[i + 3] == 1;
338+
assert vector[i + 4] == 0 || vector[i + 4] == 1;
339+
assert vector[i + 5] == 0 || vector[i + 5] == 1;
340+
assert vector[i + 6] == 0 || vector[i + 6] == 1;
341+
assert vector[i + 7] == 0 || vector[i + 7] == 1;
342+
int result = vector[i] << 7 | (vector[i + 1] << 6) | (vector[i + 2] << 5) | (vector[i + 3] << 4) | (vector[i + 4] << 3)
343+
| (vector[i + 5] << 2) | (vector[i + 6] << 1) | (vector[i + 7]);
344+
packed[index] = (byte) result;
345+
}
346+
if (i == vector.length) {
347+
return;
348+
}
349+
byte result = 0;
350+
for (int j = 7; j >= 0 && i < vector.length; i++, j--) {
351+
assert vector[i] == 0 || vector[i] == 1;
352+
result |= (byte) ((vector[i] & 1) << j);
353+
}
354+
packed[index] = result;
355+
}
323356
}

libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,4 +63,6 @@ void soarDistanceBulk(
6363
float rnorm,
6464
float[] distances
6565
);
66+
67+
void packAsBinary(int[] vector, byte[] packed);
6668
}

libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,10 @@
2222
import org.apache.lucene.util.Constants;
2323

2424
import static jdk.incubator.vector.VectorOperators.ADD;
25+
import static jdk.incubator.vector.VectorOperators.LSHL;
2526
import static jdk.incubator.vector.VectorOperators.MAX;
2627
import static jdk.incubator.vector.VectorOperators.MIN;
28+
import static jdk.incubator.vector.VectorOperators.OR;
2729

2830
public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport {
2931

@@ -942,4 +944,69 @@ public void soarDistanceBulk(
942944
distances[2] = dsq2 + soarLambda * proj2 * proj2 / rnorm;
943945
distances[3] = dsq3 + soarLambda * proj3 * proj3 / rnorm;
944946
}
947+
948+
private static final VectorSpecies<Integer> INT_SPECIES_128 = IntVector.SPECIES_128;
949+
private static final int[] SHIFTS = new int[] { 7, 6, 5, 4, 3, 2, 1, 0 };
950+
951+
@Override
952+
public void packAsBinary(int[] vector, byte[] packed) {
953+
// 128 / 32 == 4
954+
if (vector.length >= 8 && HAS_FAST_INTEGER_VECTORS) {
955+
// TODO: can we optimize for >= 512?
956+
if (VECTOR_BITSIZE >= 256) {
957+
packAsBinary256(vector, packed);
958+
return;
959+
} else if (VECTOR_BITSIZE == 128) {
960+
packAsBinary128(vector, packed);
961+
return;
962+
}
963+
}
964+
DefaultESVectorUtilSupport.packAsBinaryImpl(vector, packed);
965+
}
966+
967+
private void packAsBinary256(int[] vector, byte[] packed) {
968+
final int limit = INT_SPECIES_256.loopBound(vector.length);
969+
int i = 0;
970+
int index = 0;
971+
IntVector shifts = IntVector.fromArray(INT_SPECIES_256, SHIFTS, 0);
972+
for (; i < limit; i += INT_SPECIES_256.length(), index++) {
973+
IntVector v = IntVector.fromArray(INT_SPECIES_256, vector, i);
974+
int result = v.lanewise(LSHL, shifts).reduceLanes(OR);
975+
packed[index] = (byte) result;
976+
}
977+
if (i == vector.length) {
978+
return; // all done
979+
}
980+
byte result = 0;
981+
for (int j = 7; j >= 0 && i < vector.length; i++, j--) {
982+
assert vector[i] == 0 || vector[i] == 1;
983+
result |= (byte) ((vector[i] & 1) << j);
984+
}
985+
packed[index] = result;
986+
}
987+
988+
private void packAsBinary128(int[] vector, byte[] packed) {
989+
final int limit = INT_SPECIES_128.loopBound(vector.length) - INT_SPECIES_128.length();
990+
int i = 0;
991+
int index = 0;
992+
IntVector highShifts = IntVector.fromArray(INT_SPECIES_128, SHIFTS, 0);
993+
IntVector lowShifts = IntVector.fromArray(INT_SPECIES_128, SHIFTS, INT_SPECIES_128.length());
994+
for (; i < limit; i += 2 * INT_SPECIES_128.length(), index++) {
995+
IntVector v = IntVector.fromArray(INT_SPECIES_128, vector, i);
996+
var v1 = v.lanewise(LSHL, highShifts);
997+
v = IntVector.fromArray(INT_SPECIES_128, vector, i + INT_SPECIES_128.length());
998+
var v2 = v.lanewise(LSHL, lowShifts);
999+
int result = v1.lanewise(OR, v2).reduceLanes(OR);
1000+
packed[index] = (byte) result;
1001+
}
1002+
if (i == vector.length) {
1003+
return; // all done
1004+
}
1005+
byte result = 0;
1006+
for (int j = 7; j >= 0 && i < vector.length; i++, j--) {
1007+
assert vector[i] == 0 || vector[i] == 1;
1008+
result |= (byte) ((vector[i] & 1) << j);
1009+
}
1010+
packed[index] = result;
1011+
}
9451012
}

libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
package org.elasticsearch.simdvec;
1111

12+
import org.elasticsearch.index.codec.vectors.BQVectorUtils;
1213
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
1314
import org.elasticsearch.simdvec.internal.vectorization.BaseVectorizationTests;
1415
import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;
@@ -355,6 +356,20 @@ public void testSoarDistanceBulk() {
355356
assertArrayEquals(expectedDistances, panamaDistances, deltaEps);
356357
}
357358

359+
public void testPackAsBinary() {
360+
int dims = randomIntBetween(16, 2048);
361+
int[] toPack = new int[dims];
362+
for (int i = 0; i < dims; i++) {
363+
toPack[i] = randomInt(1);
364+
}
365+
int length = BQVectorUtils.discretize(dims, 64) / 8;
366+
byte[] packed = new byte[length];
367+
byte[] packedLegacy = new byte[length];
368+
defaultedProvider.getVectorUtilSupport().packAsBinary(toPack, packedLegacy);
369+
defOrPanamaProvider.getVectorUtilSupport().packAsBinary(toPack, packed);
370+
assertArrayEquals(packedLegacy, packed);
371+
}
372+
358373
private float[] generateRandomVector(int size) {
359374
float[] vector = new float[size];
360375
for (int i = 0; i < size; ++i) {

server/src/main/java/org/elasticsearch/index/codec/vectors/BQVectorUtils.java

Lines changed: 2 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.apache.lucene.util.ArrayUtil;
2323
import org.apache.lucene.util.BitUtil;
2424
import org.apache.lucene.util.VectorUtil;
25+
import org.elasticsearch.simdvec.ESVectorUtil;
2526

2627
/** Utility class for vector quantization calculations */
2728
public class BQVectorUtils {
@@ -55,31 +56,7 @@ public static void packAsBinaryLegacy(int[] vector, byte[] packed) {
5556
}
5657

5758
public static void packAsBinary(int[] vector, byte[] packed) {
58-
int limit = vector.length - 7;
59-
int i = 0;
60-
int index = 0;
61-
for (; i < limit; i += 8, index++) {
62-
assert vector[i] == 0 || vector[i] == 1;
63-
assert vector[i + 1] == 0 || vector[i + 1] == 1;
64-
assert vector[i + 2] == 0 || vector[i + 2] == 1;
65-
assert vector[i + 3] == 0 || vector[i + 3] == 1;
66-
assert vector[i + 4] == 0 || vector[i + 4] == 1;
67-
assert vector[i + 5] == 0 || vector[i + 5] == 1;
68-
assert vector[i + 6] == 0 || vector[i + 6] == 1;
69-
assert vector[i + 7] == 0 || vector[i + 7] == 1;
70-
int result = vector[i] << 7 | (vector[i + 1] << 6) | (vector[i + 2] << 5) | (vector[i + 3] << 4) | (vector[i + 4] << 3)
71-
| (vector[i + 5] << 2) | (vector[i + 6] << 1) | (vector[i + 7]);
72-
packed[index] = (byte) result;
73-
}
74-
if (i == vector.length) {
75-
return;
76-
}
77-
byte result = 0;
78-
for (int j = 7; j >= 0 && i < vector.length; i++, j--) {
79-
assert vector[i] == 0 || vector[i] == 1;
80-
result |= (byte) ((vector[i] & 1) << j);
81-
}
82-
packed[index] = result;
59+
ESVectorUtil.packAsBinary(vector, packed);
8360
}
8461

8562
public static int discretize(int value, int bucket) {

0 commit comments

Comments
 (0)