Skip to content

Commit d68a17a

Browse files
authored
Vectorize BQVectorUtils#packAsBinary (#132923)
Vectorize by using a helper array that defines the shifts we need to apply to the vector elements.
1 parent 624f497 commit d68a17a

File tree

7 files changed

+153
-25
lines changed

7 files changed

+153
-25
lines changed

benchmarks/src/main/java/org/elasticsearch/benchmark/vector/PackAsBinaryBenchmark.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,4 +82,13 @@ public void packAsBinaryLegacy(Blackhole bh) {
8282
bh.consume(packed);
8383
}
8484
}
85+
86+
@Benchmark
87+
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
88+
public void packAsBinaryPanama(Blackhole bh) {
89+
for (int i = 0; i < numVectors; i++) {
90+
BQVectorUtils.packAsBinary(qVectors[i], packed);
91+
bh.consume(packed);
92+
}
93+
}
8594
}

libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,4 +368,17 @@ public static void soarDistanceBulk(
368368
}
369369
IMPL.soarDistanceBulk(v1, c0, c1, c2, c3, originalResidual, soarLambda, rnorm, distances);
370370
}
371+
372+
/**
373+
* Packs the provided int array populated with "0" and "1" values into a byte array.
374+
*
375+
* @param vector the int array to pack, must contain only "0" and "1" values.
376+
* @param packed the byte array to store the packed result, must be large enough to hold the packed data.
377+
*/
378+
public static void packAsBinary(int[] vector, byte[] packed) {
379+
if (packed.length * Byte.SIZE < vector.length) {
380+
throw new IllegalArgumentException("packed array is too small: " + packed.length * Byte.SIZE + " < " + vector.length);
381+
}
382+
IMPL.packAsBinary(vector, packed);
383+
}
371384
}

libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,4 +320,37 @@ public void soarDistanceBulk(
320320
distances[2] = soarDistance(v1, c2, originalResidual, soarLambda, rnorm);
321321
distances[3] = soarDistance(v1, c3, originalResidual, soarLambda, rnorm);
322322
}
323+
324+
@Override
325+
public void packAsBinary(int[] vector, byte[] packed) {
326+
packAsBinaryImpl(vector, packed);
327+
}
328+
329+
public static void packAsBinaryImpl(int[] vector, byte[] packed) {
330+
int limit = vector.length - 7;
331+
int i = 0;
332+
int index = 0;
333+
for (; i < limit; i += 8, index++) {
334+
assert vector[i] == 0 || vector[i] == 1;
335+
assert vector[i + 1] == 0 || vector[i + 1] == 1;
336+
assert vector[i + 2] == 0 || vector[i + 2] == 1;
337+
assert vector[i + 3] == 0 || vector[i + 3] == 1;
338+
assert vector[i + 4] == 0 || vector[i + 4] == 1;
339+
assert vector[i + 5] == 0 || vector[i + 5] == 1;
340+
assert vector[i + 6] == 0 || vector[i + 6] == 1;
341+
assert vector[i + 7] == 0 || vector[i + 7] == 1;
342+
int result = vector[i] << 7 | (vector[i + 1] << 6) | (vector[i + 2] << 5) | (vector[i + 3] << 4) | (vector[i + 4] << 3)
343+
| (vector[i + 5] << 2) | (vector[i + 6] << 1) | (vector[i + 7]);
344+
packed[index] = (byte) result;
345+
}
346+
if (i == vector.length) {
347+
return;
348+
}
349+
byte result = 0;
350+
for (int j = 7; j >= 0 && i < vector.length; i++, j--) {
351+
assert vector[i] == 0 || vector[i] == 1;
352+
result |= (byte) ((vector[i] & 1) << j);
353+
}
354+
packed[index] = result;
355+
}
323356
}

libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,4 +63,6 @@ void soarDistanceBulk(
6363
float rnorm,
6464
float[] distances
6565
);
66+
67+
void packAsBinary(int[] vector, byte[] packed);
6668
}

libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,10 @@
2222
import org.apache.lucene.util.Constants;
2323

2424
import static jdk.incubator.vector.VectorOperators.ADD;
25+
import static jdk.incubator.vector.VectorOperators.LSHL;
2526
import static jdk.incubator.vector.VectorOperators.MAX;
2627
import static jdk.incubator.vector.VectorOperators.MIN;
28+
import static jdk.incubator.vector.VectorOperators.OR;
2729

2830
public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport {
2931

@@ -942,4 +944,81 @@ public void soarDistanceBulk(
942944
distances[2] = dsq2 + soarLambda * proj2 * proj2 / rnorm;
943945
distances[3] = dsq3 + soarLambda * proj3 * proj3 / rnorm;
944946
}
947+
948+
private static final VectorSpecies<Integer> INT_SPECIES_128 = IntVector.SPECIES_128;
949+
private static final IntVector SHIFTS_256;
950+
private static final IntVector HIGH_SHIFTS_128;
951+
private static final IntVector LOW_SHIFTS_128;
952+
static {
953+
final int[] shifts = new int[] { 7, 6, 5, 4, 3, 2, 1, 0 };
954+
if (VECTOR_BITSIZE == 128) {
955+
HIGH_SHIFTS_128 = IntVector.fromArray(INT_SPECIES_128, shifts, 0);
956+
LOW_SHIFTS_128 = IntVector.fromArray(INT_SPECIES_128, shifts, INT_SPECIES_128.length());
957+
SHIFTS_256 = null;
958+
} else {
959+
SHIFTS_256 = IntVector.fromArray(INT_SPECIES_256, shifts, 0);
960+
HIGH_SHIFTS_128 = null;
961+
LOW_SHIFTS_128 = null;
962+
}
963+
}
964+
private static final int[] SHIFTS = new int[] { 7, 6, 5, 4, 3, 2, 1, 0 };
965+
966+
@Override
967+
public void packAsBinary(int[] vector, byte[] packed) {
968+
// 128 / 32 == 4
969+
if (vector.length >= 8 && HAS_FAST_INTEGER_VECTORS) {
970+
// TODO: can we optimize for >= 512?
971+
if (VECTOR_BITSIZE >= 256) {
972+
packAsBinary256(vector, packed);
973+
return;
974+
} else if (VECTOR_BITSIZE == 128) {
975+
packAsBinary128(vector, packed);
976+
return;
977+
}
978+
}
979+
DefaultESVectorUtilSupport.packAsBinaryImpl(vector, packed);
980+
}
981+
982+
private void packAsBinary256(int[] vector, byte[] packed) {
983+
final int limit = INT_SPECIES_256.loopBound(vector.length);
984+
int i = 0;
985+
int index = 0;
986+
for (; i < limit; i += INT_SPECIES_256.length(), index++) {
987+
IntVector v = IntVector.fromArray(INT_SPECIES_256, vector, i);
988+
int result = v.lanewise(LSHL, SHIFTS_256).reduceLanes(OR);
989+
packed[index] = (byte) result;
990+
}
991+
if (i == vector.length) {
992+
return; // all done
993+
}
994+
byte result = 0;
995+
for (int j = 7; j >= 0 && i < vector.length; i++, j--) {
996+
assert vector[i] == 0 || vector[i] == 1;
997+
result |= (byte) ((vector[i] & 1) << j);
998+
}
999+
packed[index] = result;
1000+
}
1001+
1002+
private void packAsBinary128(int[] vector, byte[] packed) {
1003+
final int limit = INT_SPECIES_128.loopBound(vector.length) - INT_SPECIES_128.length();
1004+
int i = 0;
1005+
int index = 0;
1006+
for (; i < limit; i += 2 * INT_SPECIES_128.length(), index++) {
1007+
IntVector v = IntVector.fromArray(INT_SPECIES_128, vector, i);
1008+
var v1 = v.lanewise(LSHL, HIGH_SHIFTS_128);
1009+
v = IntVector.fromArray(INT_SPECIES_128, vector, i + INT_SPECIES_128.length());
1010+
var v2 = v.lanewise(LSHL, LOW_SHIFTS_128);
1011+
int result = v1.lanewise(OR, v2).reduceLanes(OR);
1012+
packed[index] = (byte) result;
1013+
}
1014+
if (i == vector.length) {
1015+
return; // all done
1016+
}
1017+
byte result = 0;
1018+
for (int j = 7; j >= 0 && i < vector.length; i++, j--) {
1019+
assert vector[i] == 0 || vector[i] == 1;
1020+
result |= (byte) ((vector[i] & 1) << j);
1021+
}
1022+
packed[index] = result;
1023+
}
9451024
}

libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
package org.elasticsearch.simdvec;
1111

12+
import org.elasticsearch.index.codec.vectors.BQVectorUtils;
1213
import org.elasticsearch.index.codec.vectors.OptimizedScalarQuantizer;
1314
import org.elasticsearch.simdvec.internal.vectorization.BaseVectorizationTests;
1415
import org.elasticsearch.simdvec.internal.vectorization.ESVectorizationProvider;
@@ -355,6 +356,20 @@ public void testSoarDistanceBulk() {
355356
assertArrayEquals(expectedDistances, panamaDistances, deltaEps);
356357
}
357358

359+
public void testPackAsBinary() {
360+
int dims = randomIntBetween(16, 2048);
361+
int[] toPack = new int[dims];
362+
for (int i = 0; i < dims; i++) {
363+
toPack[i] = randomInt(1);
364+
}
365+
int length = BQVectorUtils.discretize(dims, 64) / 8;
366+
byte[] packed = new byte[length];
367+
byte[] packedLegacy = new byte[length];
368+
defaultedProvider.getVectorUtilSupport().packAsBinary(toPack, packedLegacy);
369+
defOrPanamaProvider.getVectorUtilSupport().packAsBinary(toPack, packed);
370+
assertArrayEquals(packedLegacy, packed);
371+
}
372+
358373
private float[] generateRandomVector(int size) {
359374
float[] vector = new float[size];
360375
for (int i = 0; i < size; ++i) {

server/src/main/java/org/elasticsearch/index/codec/vectors/BQVectorUtils.java

Lines changed: 2 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.apache.lucene.util.ArrayUtil;
2323
import org.apache.lucene.util.BitUtil;
2424
import org.apache.lucene.util.VectorUtil;
25+
import org.elasticsearch.simdvec.ESVectorUtil;
2526

2627
/** Utility class for vector quantization calculations */
2728
public class BQVectorUtils {
@@ -55,31 +56,7 @@ public static void packAsBinaryLegacy(int[] vector, byte[] packed) {
5556
}
5657

5758
public static void packAsBinary(int[] vector, byte[] packed) {
58-
int limit = vector.length - 7;
59-
int i = 0;
60-
int index = 0;
61-
for (; i < limit; i += 8, index++) {
62-
assert vector[i] == 0 || vector[i] == 1;
63-
assert vector[i + 1] == 0 || vector[i + 1] == 1;
64-
assert vector[i + 2] == 0 || vector[i + 2] == 1;
65-
assert vector[i + 3] == 0 || vector[i + 3] == 1;
66-
assert vector[i + 4] == 0 || vector[i + 4] == 1;
67-
assert vector[i + 5] == 0 || vector[i + 5] == 1;
68-
assert vector[i + 6] == 0 || vector[i + 6] == 1;
69-
assert vector[i + 7] == 0 || vector[i + 7] == 1;
70-
int result = vector[i] << 7 | (vector[i + 1] << 6) | (vector[i + 2] << 5) | (vector[i + 3] << 4) | (vector[i + 4] << 3)
71-
| (vector[i + 5] << 2) | (vector[i + 6] << 1) | (vector[i + 7]);
72-
packed[index] = (byte) result;
73-
}
74-
if (i == vector.length) {
75-
return;
76-
}
77-
byte result = 0;
78-
for (int j = 7; j >= 0 && i < vector.length; i++, j--) {
79-
assert vector[i] == 0 || vector[i] == 1;
80-
result |= (byte) ((vector[i] & 1) << j);
81-
}
82-
packed[index] = result;
59+
ESVectorUtil.packAsBinary(vector, packed);
8360
}
8461

8562
public static int discretize(int value, int bucket) {

0 commit comments

Comments
 (0)