Skip to content

Commit 8c1ac09

Browse files
committed
Add byte-bit implementation
1 parent 987dd5e commit 8c1ac09

File tree

3 files changed

+79
-16
lines changed

3 files changed

+79
-16
lines changed

libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,17 @@ public float ipFloatByte(float[] q, byte[] d) {
4545
}
4646

4747
public static int ipByteBitImpl(byte[] q, byte[] d) {
48+
return ipByteBitImpl(q, d, 0);
49+
}
50+
51+
public static int ipByteBitImpl(byte[] q, byte[] d, int start) {
4852
assert q.length == d.length * Byte.SIZE;
4953
int acc0 = 0;
5054
int acc1 = 0;
5155
int acc2 = 0;
5256
int acc3 = 0;
5357
// now combine the two vectors, summing the byte dimensions where the bit in d is `1`
54-
for (int i = 0; i < d.length; i++) {
58+
for (int i = start; i < d.length; i++) {
5559
byte mask = d[i];
5660
// Make sure it's just 1 or 0
5761

libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java

Lines changed: 70 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,13 @@
1313
import jdk.incubator.vector.FloatVector;
1414
import jdk.incubator.vector.IntVector;
1515
import jdk.incubator.vector.LongVector;
16+
import jdk.incubator.vector.Vector;
1617
import jdk.incubator.vector.VectorMask;
1718
import jdk.incubator.vector.VectorOperators;
1819
import jdk.incubator.vector.VectorShape;
1920
import jdk.incubator.vector.VectorSpecies;
2021

22+
import org.apache.lucene.util.BitUtil;
2123
import org.apache.lucene.util.Constants;
2224

2325
public final class PanamaESVectorUtilSupport implements ESVectorUtilSupport {
@@ -52,6 +54,13 @@ public long ipByteBinByte(byte[] q, byte[] d) {
5254

5355
@Override
5456
public int ipByteBit(byte[] q, byte[] d) {
57+
if (d.length >= 16 && HAS_FAST_INTEGER_VECTORS) {
58+
if (VECTOR_BITSIZE >= 512) {
59+
return ipByteBit512(q, d);
60+
} else if (VECTOR_BITSIZE == 256) {
61+
return ipByteBit256(q, d);
62+
}
63+
}
5564
return DefaultESVectorUtilSupport.ipByteBitImpl(q, d);
5665
}
5766

@@ -175,25 +184,71 @@ public static long ipByteBin128(byte[] q, byte[] d) {
175184
return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3);
176185
}
177186

178-
private static final VectorSpecies<Float> FLOAT_SPECIES_8 = FloatVector.SPECIES_256;
179-
private static final VectorSpecies<Float> FLOAT_SPECIES_16 = FloatVector.SPECIES_512;
187+
private static final VectorSpecies<Integer> INT_SPECIES_512 = IntVector.SPECIES_512;
188+
private static final VectorSpecies<Byte> BYTE_SPECIES_FOR_INT_512 = VectorSpecies.of(
189+
byte.class,
190+
VectorShape.forBitSize(INT_SPECIES_512.vectorBitSize() / Integer.BYTES)
191+
);
192+
private static final VectorSpecies<Integer> INT_SPECIES_256 = IntVector.SPECIES_256;
193+
private static final VectorSpecies<Byte> BYTE_SPECIES_FOR_INT_256 = VectorSpecies.of(
194+
byte.class,
195+
VectorShape.forBitSize(INT_SPECIES_256.vectorBitSize() / Integer.BYTES)
196+
);
197+
198+
static int ipByteBit512(byte[] q, byte[] d) {
199+
assert q.length == d.length * Byte.SIZE;
200+
IntVector acc = IntVector.zero(INT_SPECIES_512);
201+
202+
int i = 0;
203+
for (; i < BYTE_SPECIES_FOR_INT_512.loopBound(q.length); i += BYTE_SPECIES_FOR_INT_512.length()) {
204+
Vector<Integer> bytes = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_512, q, i).castShape(INT_SPECIES_512, 0);
205+
long maskBits = Integer.reverse((short) BitUtil.VH_BE_SHORT.get(d, i / 8)) >> 16;
206+
207+
acc = acc.add(bytes, VectorMask.fromLong(INT_SPECIES_512, maskBits));
208+
}
209+
210+
int sum = acc.reduceLanes(VectorOperators.ADD);
211+
if (i < q.length) {
212+
// do the tail
213+
sum += DefaultESVectorUtilSupport.ipByteBitImpl(q, d, i);
214+
}
215+
return sum;
216+
}
217+
218+
static int ipByteBit256(byte[] q, byte[] d) {
219+
assert q.length == d.length * Byte.SIZE;
220+
IntVector acc = IntVector.zero(INT_SPECIES_256);
221+
222+
int i = 0;
223+
for (; i < BYTE_SPECIES_FOR_INT_256.loopBound(q.length); i += BYTE_SPECIES_FOR_INT_256.length()) {
224+
Vector<Integer> bytes = ByteVector.fromArray(BYTE_SPECIES_FOR_INT_256, q, i).castShape(INT_SPECIES_256, 0);
225+
long maskBits = Integer.reverse(d[i / 8]) >> 24;
226+
227+
acc = acc.add(bytes, VectorMask.fromLong(INT_SPECIES_256, maskBits));
228+
}
180229

181-
private static long reverse(byte b) {
182-
// see https://graphics.stanford.edu/~seander/bithacks.html#ReverseByteWith64Bits
183-
return ((((b & 0xff) * 0x80200802L) & 0x0884422110L) * 0x0101010101L >> 32) & 0xff;
230+
int sum = acc.reduceLanes(VectorOperators.ADD);
231+
if (i < q.length) {
232+
// do the tail
233+
sum += DefaultESVectorUtilSupport.ipByteBitImpl(q, d, i);
234+
}
235+
return sum;
184236
}
185237

238+
private static final VectorSpecies<Float> FLOAT_SPECIES_512 = FloatVector.SPECIES_512;
239+
private static final VectorSpecies<Float> FLOAT_SPECIES_256 = FloatVector.SPECIES_256;
240+
186241
static float ipFloatBit512(float[] q, byte[] d) {
187242
assert q.length == d.length * Byte.SIZE;
188-
FloatVector acc = FloatVector.zero(FLOAT_SPECIES_16);
243+
FloatVector acc = FloatVector.zero(FLOAT_SPECIES_512);
189244

190245
int i = 0;
191-
for (; i < FLOAT_SPECIES_16.loopBound(q.length); i += FLOAT_SPECIES_16.length()) {
192-
FloatVector floats = FloatVector.fromArray(FLOAT_SPECIES_16, q, i);
246+
for (; i < FLOAT_SPECIES_512.loopBound(q.length); i += FLOAT_SPECIES_512.length()) {
247+
FloatVector floats = FloatVector.fromArray(FLOAT_SPECIES_512, q, i);
193248
// use the two bytes corresponding to the same sections
194249
// of the bit vector as a mask for addition
195-
long maskBits = reverse(d[i / 8]) | reverse(d[i / 8 + 1]) << 8;
196-
acc = acc.add(floats, VectorMask.fromLong(FLOAT_SPECIES_16, maskBits));
250+
long maskBits = Integer.reverse((short) BitUtil.VH_BE_SHORT.get(d, i / 8)) >> 16;
251+
acc = acc.add(floats, VectorMask.fromLong(FLOAT_SPECIES_512, maskBits));
197252
}
198253

199254
float sum = acc.reduceLanes(VectorOperators.ADD);
@@ -207,15 +262,15 @@ static float ipFloatBit512(float[] q, byte[] d) {
207262

208263
static float ipFloatBit256(float[] q, byte[] d) {
209264
assert q.length == d.length * Byte.SIZE;
210-
FloatVector acc = FloatVector.zero(FLOAT_SPECIES_8);
265+
FloatVector acc = FloatVector.zero(FLOAT_SPECIES_256);
211266

212267
int i = 0;
213-
for (; i < FLOAT_SPECIES_8.loopBound(q.length); i += FLOAT_SPECIES_8.length()) {
214-
FloatVector floats = FloatVector.fromArray(FLOAT_SPECIES_8, q, i);
268+
for (; i < FLOAT_SPECIES_256.loopBound(q.length); i += FLOAT_SPECIES_256.length()) {
269+
FloatVector floats = FloatVector.fromArray(FLOAT_SPECIES_256, q, i);
215270
// use the byte corresponding to the same section
216271
// of the bit vector as a mask for addition
217-
long maskBits = reverse(d[i / 8]);
218-
acc = acc.add(floats, VectorMask.fromLong(FLOAT_SPECIES_8, maskBits));
272+
long maskBits = Integer.reverse(d[i / 8]) >> 24;
273+
acc = acc.add(floats, VectorMask.fromLong(FLOAT_SPECIES_256, maskBits));
219274
}
220275

221276
float sum = acc.reduceLanes(VectorOperators.ADD);

libs/simdvec/src/test/java/org/elasticsearch/simdvec/ESVectorUtilTests.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ public void testIpByteBit() {
2727
random().nextBytes(q);
2828
int expected = q[1] + q[2] + q[6] + q[8] + q[10] + q[13] + q[14] + q[15];
2929
assertEquals(expected, ESVectorUtil.ipByteBit(q, d));
30+
assertEquals(expected, defaultedProvider.getVectorUtilSupport().ipByteBit(q, d));
31+
assertEquals(expected, defOrPanamaProvider.getVectorUtilSupport().ipByteBit(q, d));
3032
}
3133

3234
public void testIpFloatBit() {
@@ -37,6 +39,8 @@ public void testIpFloatBit() {
3739
}
3840
float expected = q[1] + q[2] + q[6] + q[8] + q[10] + q[13] + q[14] + q[15];
3941
assertEquals(expected, ESVectorUtil.ipFloatBit(q, d), 1e-6);
42+
assertEquals(expected, defaultedProvider.getVectorUtilSupport().ipFloatBit(q, d), 1e-6);
43+
assertEquals(expected, defOrPanamaProvider.getVectorUtilSupport().ipFloatBit(q, d), 1e-6);
4044
}
4145

4246
public void testIpFloatByte() {

0 commit comments

Comments
 (0)