Skip to content

Commit ff18d1b

Browse files
authored
Speed up bit compared with floats or bytes script operations (#117199) (#117841)
Instead of using an `if` statement, which doesn't lend itself to vectorization, I switched to expanding each byte into its individual bits and multiplying by the resulting 1s and 0s. This led to a marginal speed improvement on ARM. I expect that the Panama Vector API could be used here to be even faster, but I didn't want to spend any more time on this for the time being.

```
Benchmark                                              (dims)   Mode  Cnt  Score   Error   Units
IpBitVectorScorerBenchmark.dotProductByteIfStatement      768  thrpt    5  2.952 ± 0.026  ops/us
IpBitVectorScorerBenchmark.dotProductByteUnwrap           768  thrpt    5  4.017 ± 0.068  ops/us
IpBitVectorScorerBenchmark.dotProductFloatIfStatement     768  thrpt    5  2.987 ± 0.124  ops/us
IpBitVectorScorerBenchmark.dotProductFloatUnwrap          768  thrpt    5  4.726 ± 0.136  ops/us
```

Benchmark I used: https://gist.github.com/benwtrent/b0edb3975d2f03356c1a5ea84c72abc9
1 parent f3a58b6 commit ff18d1b

File tree

5 files changed

+86
-21
lines changed

5 files changed

+86
-21
lines changed

docs/changelog/117199.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
pr: 117199
2+
summary: Speed up bit compared with floats or bytes script operations
3+
area: Vector Search
4+
type: enhancement
5+
issues: []

libs/simdvec/src/main/java/org/elasticsearch/simdvec/ESVectorUtil.java

Lines changed: 2 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -61,17 +61,7 @@ public static int ipByteBit(byte[] q, byte[] d) {
6161
if (q.length != d.length * Byte.SIZE) {
6262
throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= " + Byte.SIZE + " x " + d.length);
6363
}
64-
int result = 0;
65-
// now combine the two vectors, summing the byte dimensions where the bit in d is `1`
66-
for (int i = 0; i < d.length; i++) {
67-
byte mask = d[i];
68-
for (int j = Byte.SIZE - 1; j >= 0; j--) {
69-
if ((mask & (1 << j)) != 0) {
70-
result += q[i * Byte.SIZE + Byte.SIZE - 1 - j];
71-
}
72-
}
73-
}
74-
return result;
64+
return IMPL.ipByteBit(q, d);
7565
}
7666

7767
/**
@@ -87,16 +77,7 @@ public static float ipFloatBit(float[] q, byte[] d) {
8777
if (q.length != d.length * Byte.SIZE) {
8878
throw new IllegalArgumentException("vector dimensions incompatible: " + q.length + "!= " + Byte.SIZE + " x " + d.length);
8979
}
90-
float result = 0;
91-
for (int i = 0; i < d.length; i++) {
92-
byte mask = d[i];
93-
for (int j = Byte.SIZE - 1; j >= 0; j--) {
94-
if ((mask & (1 << j)) != 0) {
95-
result += q[i * Byte.SIZE + Byte.SIZE - 1 - j];
96-
}
97-
}
98-
}
99-
return result;
80+
return IMPL.ipFloatBit(q, d);
10081
}
10182

10283
/**

libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/DefaultESVectorUtilSupport.java

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,81 @@
1010
package org.elasticsearch.simdvec.internal.vectorization;
1111

1212
import org.apache.lucene.util.BitUtil;
13+
import org.apache.lucene.util.Constants;
1314

1415
final class DefaultESVectorUtilSupport implements ESVectorUtilSupport {
1516

17+
private static float fma(float a, float b, float c) {
18+
if (Constants.HAS_FAST_SCALAR_FMA) {
19+
return Math.fma(a, b, c);
20+
} else {
21+
return a * b + c;
22+
}
23+
}
24+
1625
DefaultESVectorUtilSupport() {}
1726

1827
// Interface entry point; forwards both arrays unchanged to the static ipByteBinByteImpl.
@Override
1928
public long ipByteBinByte(byte[] q, byte[] d) {
2029
return ipByteBinByteImpl(q, d);
2130
}
2231

32+
// Interface entry point for the byte-query x packed-bit inner product;
// forwards both arrays unchanged to the static ipByteBitImpl.
@Override
33+
public int ipByteBit(byte[] q, byte[] d) {
34+
return ipByteBitImpl(q, d);
35+
}
36+
37+
// Interface entry point for the float-query x packed-bit inner product;
// forwards both arrays unchanged to the static ipFloatBitImpl.
@Override
38+
public float ipFloatBit(float[] q, byte[] d) {
39+
return ipFloatBitImpl(q, d);
40+
}
41+
42+
public static int ipByteBitImpl(byte[] q, byte[] d) {
43+
assert q.length == d.length * Byte.SIZE;
44+
int acc0 = 0;
45+
int acc1 = 0;
46+
int acc2 = 0;
47+
int acc3 = 0;
48+
// now combine the two vectors, summing the byte dimensions where the bit in d is `1`
49+
for (int i = 0; i < d.length; i++) {
50+
byte mask = d[i];
51+
// Make sure its just 1 or 0
52+
53+
acc0 += q[i * Byte.SIZE + 0] * ((mask >> 7) & 1);
54+
acc1 += q[i * Byte.SIZE + 1] * ((mask >> 6) & 1);
55+
acc2 += q[i * Byte.SIZE + 2] * ((mask >> 5) & 1);
56+
acc3 += q[i * Byte.SIZE + 3] * ((mask >> 4) & 1);
57+
58+
acc0 += q[i * Byte.SIZE + 4] * ((mask >> 3) & 1);
59+
acc1 += q[i * Byte.SIZE + 5] * ((mask >> 2) & 1);
60+
acc2 += q[i * Byte.SIZE + 6] * ((mask >> 1) & 1);
61+
acc3 += q[i * Byte.SIZE + 7] * ((mask >> 0) & 1);
62+
}
63+
return acc0 + acc1 + acc2 + acc3;
64+
}
65+
66+
public static float ipFloatBitImpl(float[] q, byte[] d) {
67+
assert q.length == d.length * Byte.SIZE;
68+
float acc0 = 0;
69+
float acc1 = 0;
70+
float acc2 = 0;
71+
float acc3 = 0;
72+
// now combine the two vectors, summing the byte dimensions where the bit in d is `1`
73+
for (int i = 0; i < d.length; i++) {
74+
byte mask = d[i];
75+
acc0 = fma(q[i * Byte.SIZE + 0], (mask >> 7) & 1, acc0);
76+
acc1 = fma(q[i * Byte.SIZE + 1], (mask >> 6) & 1, acc1);
77+
acc2 = fma(q[i * Byte.SIZE + 2], (mask >> 5) & 1, acc2);
78+
acc3 = fma(q[i * Byte.SIZE + 3], (mask >> 4) & 1, acc3);
79+
80+
acc0 = fma(q[i * Byte.SIZE + 4], (mask >> 3) & 1, acc0);
81+
acc1 = fma(q[i * Byte.SIZE + 5], (mask >> 2) & 1, acc1);
82+
acc2 = fma(q[i * Byte.SIZE + 6], (mask >> 1) & 1, acc2);
83+
acc3 = fma(q[i * Byte.SIZE + 7], (mask >> 0) & 1, acc3);
84+
}
85+
return acc0 + acc1 + acc2 + acc3;
86+
}
87+
2388
public static long ipByteBinByteImpl(byte[] q, byte[] d) {
2489
long ret = 0;
2590
int size = d.length;

libs/simdvec/src/main/java/org/elasticsearch/simdvec/internal/vectorization/ESVectorUtilSupport.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,8 @@ public interface ESVectorUtilSupport {
1414
short B_QUERY = 4;
1515

1616
long ipByteBinByte(byte[] q, byte[] d);
17+
18+
int ipByteBit(byte[] q, byte[] d);
19+
20+
float ipFloatBit(float[] q, byte[] d);
1721
}

libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/PanamaESVectorUtilSupport.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,16 @@ public long ipByteBinByte(byte[] q, byte[] d) {
4848
return DefaultESVectorUtilSupport.ipByteBinByteImpl(q, d);
4949
}
5050

51+
// Forwards to DefaultESVectorUtilSupport.ipByteBitImpl; no separate implementation is defined here.
@Override
52+
public int ipByteBit(byte[] q, byte[] d) {
53+
return DefaultESVectorUtilSupport.ipByteBitImpl(q, d);
54+
}
55+
56+
// Forwards to DefaultESVectorUtilSupport.ipFloatBitImpl; no separate implementation is defined here.
@Override
57+
public float ipFloatBit(float[] q, byte[] d) {
58+
return DefaultESVectorUtilSupport.ipFloatBitImpl(q, d);
59+
}
60+
5161
private static final VectorSpecies<Byte> BYTE_SPECIES_128 = ByteVector.SPECIES_128;
5262
private static final VectorSpecies<Byte> BYTE_SPECIES_256 = ByteVector.SPECIES_256;
5363

0 commit comments

Comments
 (0)