Skip to content

Commit 8c01b67

Browse files
authored
Vectorize BQSpaceUtils#transposeHalfByte (elastic#132935)
1 parent 8a278ed commit 8c01b67

File tree

7 files changed

+197
-42
lines changed

7 files changed

+197
-42
lines changed

benchmarks/src/main/java/org/elasticsearch/benchmark/vector/TransposeHalfByteBenchmark.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,4 +83,13 @@ public void transposeHalfByteLegacy(Blackhole bh) {
8383
bh.consume(packed);
8484
}
8585
}
86+
87+
@Benchmark
88+
@Fork(jvmArgsPrepend = { "--add-modules=jdk.incubator.vector" })
89+
public void transposeHalfBytePanama(Blackhole bh) {
90+
for (int i = 0; i < numVectors; i++) {
91+
BQSpaceUtils.transposeHalfByte(qVectors[i], packed);
92+
bh.consume(packed);
93+
}
94+
}
8695
}

libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,4 +381,22 @@ public static void packAsBinary(int[] vector, byte[] packed) {
381381
}
382382
IMPL.packAsBinary(vector, packed);
383383
}
384+
385+
/**
386+
* The idea here is to organize the query vector bits such that the first bit
387+
* of every dimension is in the first set dimensions bits, or (dimensions/8) bytes. The second,
388+
* third, and fourth bits are in the second, third, and fourth set of dimensions bits,
389+
* respectively. This allows for direct bitwise comparisons with the stored index vectors through
390+
* summing the bitwise results with the relative required bit shifts.
391+
*
392+
* @param q the query vector, assumed to be half-byte quantized with values between 0 and 15
393+
* @param quantQueryByte the byte array to store the transposed query vector.
394+
*
395+
**/
396+
public static void transposeHalfByte(int[] q, byte[] quantQueryByte) {
397+
if (quantQueryByte.length * Byte.SIZE < 4 * q.length) {
398+
throw new IllegalArgumentException("packed array is too small: " + quantQueryByte.length * Byte.SIZE + " < " + 4 * q.length);
399+
}
400+
IMPL.transposeHalfByte(q, quantQueryByte);
401+
}
384402
}

libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,4 +353,54 @@ public static void packAsBinaryImpl(int[] vector, byte[] packed) {
353353
}
354354
packed[index] = result;
355355
}
356+
357+
@Override
358+
public void transposeHalfByte(int[] q, byte[] quantQueryByte) {
359+
transposeHalfByteImpl(q, quantQueryByte);
360+
}
361+
362+
public static void transposeHalfByteImpl(int[] q, byte[] quantQueryByte) {
363+
int limit = q.length - 7;
364+
int i = 0;
365+
int index = 0;
366+
for (; i < limit; i += 8, index++) {
367+
assert q[i] >= 0 && q[i] <= 15;
368+
assert q[i + 1] >= 0 && q[i + 1] <= 15;
369+
assert q[i + 2] >= 0 && q[i + 2] <= 15;
370+
assert q[i + 3] >= 0 && q[i + 3] <= 15;
371+
assert q[i + 4] >= 0 && q[i + 4] <= 15;
372+
assert q[i + 5] >= 0 && q[i + 5] <= 15;
373+
assert q[i + 6] >= 0 && q[i + 6] <= 15;
374+
assert q[i + 7] >= 0 && q[i + 7] <= 15;
375+
int lowerByte = (q[i] & 1) << 7 | (q[i + 1] & 1) << 6 | (q[i + 2] & 1) << 5 | (q[i + 3] & 1) << 4 | (q[i + 4] & 1) << 3 | (q[i
376+
+ 5] & 1) << 2 | (q[i + 6] & 1) << 1 | (q[i + 7] & 1);
377+
int lowerMiddleByte = ((q[i] >> 1) & 1) << 7 | ((q[i + 1] >> 1) & 1) << 6 | ((q[i + 2] >> 1) & 1) << 5 | ((q[i + 3] >> 1) & 1)
378+
<< 4 | ((q[i + 4] >> 1) & 1) << 3 | ((q[i + 5] >> 1) & 1) << 2 | ((q[i + 6] >> 1) & 1) << 1 | ((q[i + 7] >> 1) & 1);
379+
int upperMiddleByte = ((q[i] >> 2) & 1) << 7 | ((q[i + 1] >> 2) & 1) << 6 | ((q[i + 2] >> 2) & 1) << 5 | ((q[i + 3] >> 2) & 1)
380+
<< 4 | ((q[i + 4] >> 2) & 1) << 3 | ((q[i + 5] >> 2) & 1) << 2 | ((q[i + 6] >> 2) & 1) << 1 | ((q[i + 7] >> 2) & 1);
381+
int upperByte = ((q[i] >> 3) & 1) << 7 | ((q[i + 1] >> 3) & 1) << 6 | ((q[i + 2] >> 3) & 1) << 5 | ((q[i + 3] >> 3) & 1) << 4
382+
| ((q[i + 4] >> 3) & 1) << 3 | ((q[i + 5] >> 3) & 1) << 2 | ((q[i + 6] >> 3) & 1) << 1 | ((q[i + 7] >> 3) & 1);
383+
quantQueryByte[index] = (byte) lowerByte;
384+
quantQueryByte[index + quantQueryByte.length / 4] = (byte) lowerMiddleByte;
385+
quantQueryByte[index + quantQueryByte.length / 2] = (byte) upperMiddleByte;
386+
quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte;
387+
}
388+
if (i == q.length) {
389+
return; // all done
390+
}
391+
int lowerByte = 0;
392+
int lowerMiddleByte = 0;
393+
int upperMiddleByte = 0;
394+
int upperByte = 0;
395+
for (int j = 7; i < q.length; j--, i++) {
396+
lowerByte |= (q[i] & 1) << j;
397+
lowerMiddleByte |= ((q[i] >> 1) & 1) << j;
398+
upperMiddleByte |= ((q[i] >> 2) & 1) << j;
399+
upperByte |= ((q[i] >> 3) & 1) << j;
400+
}
401+
quantQueryByte[index] = (byte) lowerByte;
402+
quantQueryByte[index + quantQueryByte.length / 4] = (byte) lowerMiddleByte;
403+
quantQueryByte[index + quantQueryByte.length / 2] = (byte) upperMiddleByte;
404+
quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte;
405+
}
356406
}

libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,4 +65,6 @@ void soarDistanceBulk(
6565
);
6666

6767
void packAsBinary(int[] vector, byte[] packed);
68+
69+
/**
 * Transposes the half-byte quantized values of {@code q} into four bit planes stored in {@code quantQueryByte}:
 * the lowest bit of every value goes into the first quarter of the array, and the second, third, and fourth
 * bits go into the second, third, and fourth quarters.
 *
 * @param q the query vector, assumed to be half-byte quantized with values between 0 and 15
 * @param quantQueryByte the byte array to store the transposed query vector
 */
void transposeHalfByte(int[] q, byte[] quantQueryByte);
6870
}

libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import org.apache.lucene.util.Constants;
2323

2424
import static jdk.incubator.vector.VectorOperators.ADD;
25+
import static jdk.incubator.vector.VectorOperators.ASHR;
2526
import static jdk.incubator.vector.VectorOperators.LSHL;
2627
import static jdk.incubator.vector.VectorOperators.MAX;
2728
import static jdk.incubator.vector.VectorOperators.MIN;
@@ -1021,4 +1022,104 @@ private void packAsBinary128(int[] vector, byte[] packed) {
10211022
}
10221023
packed[index] = result;
10231024
}
1025+
1026+
@Override
1027+
public void transposeHalfByte(int[] q, byte[] quantQueryByte) {
1028+
// 128 / 32 == 4
1029+
if (q.length >= 8 && HAS_FAST_INTEGER_VECTORS) {
1030+
if (VECTOR_BITSIZE >= 256) {
1031+
transposeHalfByte256(q, quantQueryByte);
1032+
return;
1033+
} else if (VECTOR_BITSIZE == 128) {
1034+
transposeHalfByte128(q, quantQueryByte);
1035+
return;
1036+
}
1037+
}
1038+
DefaultESVectorUtilSupport.transposeHalfByteImpl(q, quantQueryByte);
1039+
}
1040+
1041+
private void transposeHalfByte256(int[] q, byte[] quantQueryByte) {
1042+
final int limit = INT_SPECIES_256.loopBound(q.length);
1043+
int i = 0;
1044+
int index = 0;
1045+
for (; i < limit; i += INT_SPECIES_256.length(), index++) {
1046+
IntVector v = IntVector.fromArray(INT_SPECIES_256, q, i);
1047+
1048+
int lowerByte = v.and(1).lanewise(LSHL, SHIFTS_256).reduceLanes(VectorOperators.OR);
1049+
int lowerMiddleByte = v.lanewise(ASHR, 1).and(1).lanewise(LSHL, SHIFTS_256).reduceLanes(VectorOperators.OR);
1050+
int upperMiddleByte = v.lanewise(ASHR, 2).and(1).lanewise(LSHL, SHIFTS_256).reduceLanes(VectorOperators.OR);
1051+
int upperByte = v.lanewise(ASHR, 3).and(1).lanewise(LSHL, SHIFTS_256).reduceLanes(VectorOperators.OR);
1052+
1053+
quantQueryByte[index] = (byte) lowerByte;
1054+
quantQueryByte[index + quantQueryByte.length / 4] = (byte) lowerMiddleByte;
1055+
quantQueryByte[index + quantQueryByte.length / 2] = (byte) upperMiddleByte;
1056+
quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte;
1057+
1058+
}
1059+
if (i == q.length) {
1060+
return; // all done
1061+
}
1062+
int lowerByte = 0;
1063+
int lowerMiddleByte = 0;
1064+
int upperMiddleByte = 0;
1065+
int upperByte = 0;
1066+
for (int j = 7; i < q.length; j--, i++) {
1067+
lowerByte |= (q[i] & 1) << j;
1068+
lowerMiddleByte |= ((q[i] >> 1) & 1) << j;
1069+
upperMiddleByte |= ((q[i] >> 2) & 1) << j;
1070+
upperByte |= ((q[i] >> 3) & 1) << j;
1071+
}
1072+
quantQueryByte[index] = (byte) lowerByte;
1073+
quantQueryByte[index + quantQueryByte.length / 4] = (byte) lowerMiddleByte;
1074+
quantQueryByte[index + quantQueryByte.length / 2] = (byte) upperMiddleByte;
1075+
quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte;
1076+
}
1077+
1078+
private void transposeHalfByte128(int[] q, byte[] quantQueryByte) {
1079+
final int limit = INT_SPECIES_128.loopBound(q.length) - INT_SPECIES_128.length();
1080+
int i = 0;
1081+
int index = 0;
1082+
for (; i < limit; i += 2 * INT_SPECIES_128.length(), index++) {
1083+
IntVector v = IntVector.fromArray(INT_SPECIES_128, q, i);
1084+
1085+
var lowerByteHigh = v.and(1).lanewise(LSHL, HIGH_SHIFTS_128);
1086+
var lowerMiddleByteHigh = v.lanewise(ASHR, 1).and(1).lanewise(LSHL, HIGH_SHIFTS_128);
1087+
var upperMiddleByteHigh = v.lanewise(ASHR, 2).and(1).lanewise(LSHL, HIGH_SHIFTS_128);
1088+
var upperByteHigh = v.lanewise(ASHR, 3).and(1).lanewise(LSHL, HIGH_SHIFTS_128);
1089+
1090+
v = IntVector.fromArray(INT_SPECIES_128, q, i + INT_SPECIES_128.length());
1091+
var lowerByteLow = v.and(1).lanewise(LSHL, LOW_SHIFTS_128);
1092+
var lowerMiddleByteLow = v.lanewise(ASHR, 1).and(1).lanewise(LSHL, LOW_SHIFTS_128);
1093+
var upperMiddleByteLow = v.lanewise(ASHR, 2).and(1).lanewise(LSHL, LOW_SHIFTS_128);
1094+
var upperByteLow = v.lanewise(ASHR, 3).and(1).lanewise(LSHL, LOW_SHIFTS_128);
1095+
1096+
int lowerByte = lowerByteHigh.lanewise(OR, lowerByteLow).reduceLanes(OR);
1097+
int lowerMiddleByte = lowerMiddleByteHigh.lanewise(OR, lowerMiddleByteLow).reduceLanes(OR);
1098+
int upperMiddleByte = upperMiddleByteHigh.lanewise(OR, upperMiddleByteLow).reduceLanes(OR);
1099+
int upperByte = upperByteHigh.lanewise(OR, upperByteLow).reduceLanes(OR);
1100+
1101+
quantQueryByte[index] = (byte) lowerByte;
1102+
quantQueryByte[index + quantQueryByte.length / 4] = (byte) lowerMiddleByte;
1103+
quantQueryByte[index + quantQueryByte.length / 2] = (byte) upperMiddleByte;
1104+
quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte;
1105+
1106+
}
1107+
if (i == q.length) {
1108+
return; // all done
1109+
}
1110+
int lowerByte = 0;
1111+
int lowerMiddleByte = 0;
1112+
int upperMiddleByte = 0;
1113+
int upperByte = 0;
1114+
for (int j = 7; i < q.length; j--, i++) {
1115+
lowerByte |= (q[i] & 1) << j;
1116+
lowerMiddleByte |= ((q[i] >> 1) & 1) << j;
1117+
upperMiddleByte |= ((q[i] >> 2) & 1) << j;
1118+
upperByte |= ((q[i] >> 3) & 1) << j;
1119+
}
1120+
quantQueryByte[index] = (byte) lowerByte;
1121+
quantQueryByte[index + quantQueryByte.length / 4] = (byte) lowerMiddleByte;
1122+
quantQueryByte[index + quantQueryByte.length / 2] = (byte) upperMiddleByte;
1123+
quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte;
1124+
}
10241125
}

libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -370,6 +370,20 @@ public void testPackAsBinary() {
370370
assertArrayEquals(packedLegacy, packed);
371371
}
372372

373+
public void testTransposeHalfByte() {
374+
int dims = randomIntBetween(16, 2048);
375+
int[] toPack = new int[dims];
376+
for (int i = 0; i < dims; i++) {
377+
toPack[i] = randomInt(15);
378+
}
379+
int length = 4 * BQVectorUtils.discretize(dims, 64) / 8;
380+
byte[] packed = new byte[length];
381+
byte[] packedLegacy = new byte[length];
382+
defaultedProvider.getVectorUtilSupport().transposeHalfByte(toPack, packedLegacy);
383+
defOrPanamaProvider.getVectorUtilSupport().transposeHalfByte(toPack, packed);
384+
assertArrayEquals(packedLegacy, packed);
385+
}
386+
373387
private float[] generateRandomVector(int size) {
374388
float[] vector = new float[size];
375389
for (int i = 0; i < size; ++i) {

server/src/main/java/org/elasticsearch/index/codec/vectors/BQSpaceUtils.java

Lines changed: 3 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
*/
2020
package org.elasticsearch.index.codec.vectors;
2121

22+
import org.elasticsearch.simdvec.ESVectorUtil;
23+
2224
/** Utility class for quantization calculations */
2325
public class BQSpaceUtils {
2426

@@ -117,48 +119,7 @@ public static void transposeHalfByteLegacy(byte[] q, byte[] quantQueryByte) {
117119
* @param quantQueryByte the byte array to store the transposed query vector
118120
* */
119121
public static void transposeHalfByte(int[] q, byte[] quantQueryByte) {
120-
int limit = q.length - 7;
121-
int i = 0;
122-
int index = 0;
123-
for (; i < limit; i += 8, index++) {
124-
assert q[i] >= 0 && q[i] <= 15;
125-
assert q[i + 1] >= 0 && q[i + 1] <= 15;
126-
assert q[i + 2] >= 0 && q[i + 2] <= 15;
127-
assert q[i + 3] >= 0 && q[i + 3] <= 15;
128-
assert q[i + 4] >= 0 && q[i + 4] <= 15;
129-
assert q[i + 5] >= 0 && q[i + 5] <= 15;
130-
assert q[i + 6] >= 0 && q[i + 6] <= 15;
131-
assert q[i + 7] >= 0 && q[i + 7] <= 15;
132-
int lowerByte = (q[i] & 1) << 7 | (q[i + 1] & 1) << 6 | (q[i + 2] & 1) << 5 | (q[i + 3] & 1) << 4 | (q[i + 4] & 1) << 3 | (q[i
133-
+ 5] & 1) << 2 | (q[i + 6] & 1) << 1 | (q[i + 7] & 1);
134-
int lowerMiddleByte = ((q[i] >> 1) & 1) << 7 | ((q[i + 1] >> 1) & 1) << 6 | ((q[i + 2] >> 1) & 1) << 5 | ((q[i + 3] >> 1) & 1)
135-
<< 4 | ((q[i + 4] >> 1) & 1) << 3 | ((q[i + 5] >> 1) & 1) << 2 | ((q[i + 6] >> 1) & 1) << 1 | ((q[i + 7] >> 1) & 1);
136-
int upperMiddleByte = ((q[i] >> 2) & 1) << 7 | ((q[i + 1] >> 2) & 1) << 6 | ((q[i + 2] >> 2) & 1) << 5 | ((q[i + 3] >> 2) & 1)
137-
<< 4 | ((q[i + 4] >> 2) & 1) << 3 | ((q[i + 5] >> 2) & 1) << 2 | ((q[i + 6] >> 2) & 1) << 1 | ((q[i + 7] >> 2) & 1);
138-
int upperByte = ((q[i] >> 3) & 1) << 7 | ((q[i + 1] >> 3) & 1) << 6 | ((q[i + 2] >> 3) & 1) << 5 | ((q[i + 3] >> 3) & 1) << 4
139-
| ((q[i + 4] >> 3) & 1) << 3 | ((q[i + 5] >> 3) & 1) << 2 | ((q[i + 6] >> 3) & 1) << 1 | ((q[i + 7] >> 3) & 1);
140-
quantQueryByte[index] = (byte) lowerByte;
141-
quantQueryByte[index + quantQueryByte.length / 4] = (byte) lowerMiddleByte;
142-
quantQueryByte[index + quantQueryByte.length / 2] = (byte) upperMiddleByte;
143-
quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte;
144-
}
145-
if (i == q.length) {
146-
return; // all done
147-
}
148-
int lowerByte = 0;
149-
int lowerMiddleByte = 0;
150-
int upperMiddleByte = 0;
151-
int upperByte = 0;
152-
for (int j = 7; i < q.length; j--, i++) {
153-
lowerByte |= (q[i] & 1) << j;
154-
lowerMiddleByte |= ((q[i] >> 1) & 1) << j;
155-
upperMiddleByte |= ((q[i] >> 2) & 1) << j;
156-
upperByte |= ((q[i] >> 3) & 1) << j;
157-
}
158-
quantQueryByte[index] = (byte) lowerByte;
159-
quantQueryByte[index + quantQueryByte.length / 4] = (byte) lowerMiddleByte;
160-
quantQueryByte[index + quantQueryByte.length / 2] = (byte) upperMiddleByte;
161-
quantQueryByte[index + 3 * quantQueryByte.length / 4] = (byte) upperByte;
122+
ESVectorUtil.transposeHalfByte(q, quantQueryByte);
162123
}
163124

164125
/**

0 commit comments

Comments
 (0)