Skip to content

Commit 987dd5e

Browse files
committed
Add float-bit panama implementation
1 parent 1493794 commit 987dd5e

File tree

3 files changed

+68
-3
lines changed

3 files changed

+68
-3
lines changed

benchmarks/src/main/java/org/elasticsearch/benchmark/vector/DistanceFunctionBenchmark.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -148,11 +148,11 @@ public void findBenchmarkImpl() {
148148

149149
float[] floatDocVector = new float[dims];
150150
byte[] byteDocVector = new byte[dims];
151-
byte[] bitDocVector = new byte[dims/8];
151+
byte[] bitDocVector = new byte[dims / 8];
152152

153153
float[] floatQueryVector = new float[dims];
154154
byte[] byteQueryVector = new byte[dims];
155-
byte[] bitQueryVector = new byte[dims/8];
155+
byte[] bitQueryVector = new byte[dims / 8];
156156

157157
r.nextBytes(byteDocVector);
158158
r.nextBytes(bitDocVector);

libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,17 @@ public static int ipByteBitImpl(byte[] q, byte[] d) {
6969
}
7070

7171
public static float ipFloatBitImpl(float[] q, byte[] d) {
72+
return ipFloatBitImpl(q, d, 0);
73+
}
74+
75+
static float ipFloatBitImpl(float[] q, byte[] d, int start) {
7276
assert q.length == d.length * Byte.SIZE;
7377
float acc0 = 0;
7478
float acc1 = 0;
7579
float acc2 = 0;
7680
float acc3 = 0;
7781
// now combine the two vectors, summing the byte dimensions where the bit in d is `1`
78-
for (int i = 0; i < d.length; i++) {
82+
for (int i = start; i < d.length; i++) {
7983
byte mask = d[i];
8084
acc0 = fma(q[i * Byte.SIZE + 0], (mask >> 7) & 1, acc0);
8185
acc1 = fma(q[i * Byte.SIZE + 1], (mask >> 6) & 1, acc1);

libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@
1010
package org.elasticsearch.simdvec.internal.vectorization;
1111

1212
import jdk.incubator.vector.ByteVector;
13+
import jdk.incubator.vector.FloatVector;
1314
import jdk.incubator.vector.IntVector;
1415
import jdk.incubator.vector.LongVector;
16+
import jdk.incubator.vector.VectorMask;
1517
import jdk.incubator.vector.VectorOperators;
1618
import jdk.incubator.vector.VectorShape;
1719
import jdk.incubator.vector.VectorSpecies;
@@ -55,6 +57,13 @@ public int ipByteBit(byte[] q, byte[] d) {
5557

5658
@Override
5759
public float ipFloatBit(float[] q, byte[] d) {
60+
if (q.length >= 16) {
61+
if (VECTOR_BITSIZE >= 512) {
62+
return ipFloatBit512(q, d);
63+
} else if (VECTOR_BITSIZE == 256) {
64+
return ipFloatBit256(q, d);
65+
}
66+
}
5867
return DefaultESVectorUtilSupport.ipFloatBitImpl(q, d);
5968
}
6069

@@ -165,4 +174,56 @@ public static long ipByteBin128(byte[] q, byte[] d) {
165174
}
166175
return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3);
167176
}
177+
178+
private static final VectorSpecies<Float> FLOAT_SPECIES_8 = FloatVector.SPECIES_256;
179+
private static final VectorSpecies<Float> FLOAT_SPECIES_16 = FloatVector.SPECIES_512;
180+
181+
private static long reverse(byte b) {
182+
// see https://graphics.stanford.edu/~seander/bithacks.html#ReverseByteWith64Bits
183+
return ((((b & 0xff) * 0x80200802L) & 0x0884422110L) * 0x0101010101L >> 32) & 0xff;
184+
}
185+
186+
static float ipFloatBit512(float[] q, byte[] d) {
187+
assert q.length == d.length * Byte.SIZE;
188+
FloatVector acc = FloatVector.zero(FLOAT_SPECIES_16);
189+
190+
int i = 0;
191+
for (; i < FLOAT_SPECIES_16.loopBound(q.length); i += FLOAT_SPECIES_16.length()) {
192+
FloatVector floats = FloatVector.fromArray(FLOAT_SPECIES_16, q, i);
193+
// use the two bytes corresponding to the same sections
194+
// of the bit vector as a mask for addition
195+
long maskBits = reverse(d[i / 8]) | reverse(d[i / 8 + 1]) << 8;
196+
acc = acc.add(floats, VectorMask.fromLong(FLOAT_SPECIES_16, maskBits));
197+
}
198+
199+
float sum = acc.reduceLanes(VectorOperators.ADD);
200+
if (i < q.length) {
201+
// do the tail
202+
sum += DefaultESVectorUtilSupport.ipFloatBitImpl(q, d, i);
203+
}
204+
205+
return sum;
206+
}
207+
208+
static float ipFloatBit256(float[] q, byte[] d) {
209+
assert q.length == d.length * Byte.SIZE;
210+
FloatVector acc = FloatVector.zero(FLOAT_SPECIES_8);
211+
212+
int i = 0;
213+
for (; i < FLOAT_SPECIES_8.loopBound(q.length); i += FLOAT_SPECIES_8.length()) {
214+
FloatVector floats = FloatVector.fromArray(FLOAT_SPECIES_8, q, i);
215+
// use the byte corresponding to the same section
216+
// of the bit vector as a mask for addition
217+
long maskBits = reverse(d[i / 8]);
218+
acc = acc.add(floats, VectorMask.fromLong(FLOAT_SPECIES_8, maskBits));
219+
}
220+
221+
float sum = acc.reduceLanes(VectorOperators.ADD);
222+
if (i < q.length) {
223+
// do the tail
224+
sum += DefaultESVectorUtilSupport.ipFloatBitImpl(q, d, i);
225+
}
226+
227+
return sum;
228+
}
168229
}

0 commit comments

Comments
 (0)