Skip to content

Commit 1771d00

Browse files
authored
Speed up tail computation in MemorySegmentES91OSQVectorsScorer (#132001)
1 parent ee671f1 commit 1771d00

File tree

2 files changed

+70
-13
lines changed

2 files changed

+70
-13
lines changed

benchmarks/src/main/java/org/elasticsearch/benchmark/vector/OSQScorerBenchmark.java

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -52,7 +52,7 @@ public class OSQScorerBenchmark {
5252
LogConfigurator.configureESLogging(); // native access requires logging to be initialized
5353
}
5454

55-
@Param({ "1024" })
55+
@Param({ "384", "782", "1024" })
5656
int dims;
5757

5858
int length;

libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java

Lines changed: 69 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -18,6 +18,7 @@
1818

1919
import org.apache.lucene.index.VectorSimilarityFunction;
2020
import org.apache.lucene.store.IndexInput;
21+
import org.apache.lucene.util.BitUtil;
2122
import org.apache.lucene.util.VectorUtil;
2223
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
2324

@@ -118,8 +119,22 @@ private long quantizeScore256(byte[] q) throws IOException {
118119
subRet2 += sum2.reduceLanes(VectorOperators.ADD);
119120
subRet3 += sum3.reduceLanes(VectorOperators.ADD);
120121
}
121-
// tail as bytes
122+
// process scalar tail
122123
in.seek(offset);
124+
for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) {
125+
final long value = in.readLong();
126+
subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value);
127+
subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value);
128+
subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value);
129+
subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value);
130+
}
131+
for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) {
132+
final int value = in.readInt();
133+
subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value);
134+
subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value);
135+
subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value);
136+
subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value);
137+
}
123138
for (; i < length; i++) {
124139
int dValue = in.readByte() & 0xFF;
125140
subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF);
@@ -158,14 +173,28 @@ private long quantizeScore128(byte[] q) throws IOException {
158173
subRet1 += sum1.reduceLanes(VectorOperators.ADD);
159174
subRet2 += sum2.reduceLanes(VectorOperators.ADD);
160175
subRet3 += sum3.reduceLanes(VectorOperators.ADD);
161-
// tail as bytes
176+
// process scalar tail
162177
in.seek(offset);
178+
for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) {
179+
final long value = in.readLong();
180+
subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value);
181+
subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value);
182+
subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value);
183+
subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value);
184+
}
185+
for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) {
186+
final int value = in.readInt();
187+
subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value);
188+
subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value);
189+
subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value);
190+
subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value);
191+
}
163192
for (; i < length; i++) {
164193
int dValue = in.readByte() & 0xFF;
165-
subRet0 += Integer.bitCount((dValue & q[i]) & 0xFF);
166-
subRet1 += Integer.bitCount((dValue & q[i + length]) & 0xFF);
167-
subRet2 += Integer.bitCount((dValue & q[i + 2 * length]) & 0xFF);
168-
subRet3 += Integer.bitCount((dValue & q[i + 3 * length]) & 0xFF);
194+
subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF);
195+
subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF);
196+
subRet2 += Integer.bitCount((q[i + 2 * length] & dValue) & 0xFF);
197+
subRet3 += Integer.bitCount((q[i + 3 * length] & dValue) & 0xFF);
169198
}
170199
return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3);
171200
}
@@ -215,14 +244,28 @@ private void quantizeScore128Bulk(byte[] q, int count, float[] scores) throws IO
215244
subRet1 += sum1.reduceLanes(VectorOperators.ADD);
216245
subRet2 += sum2.reduceLanes(VectorOperators.ADD);
217246
subRet3 += sum3.reduceLanes(VectorOperators.ADD);
218-
// tail as bytes
247+
// process scalar tail
219248
in.seek(offset);
249+
for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) {
250+
final long value = in.readLong();
251+
subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value);
252+
subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value);
253+
subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value);
254+
subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value);
255+
}
256+
for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) {
257+
final int value = in.readInt();
258+
subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value);
259+
subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value);
260+
subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value);
261+
subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value);
262+
}
220263
for (; i < length; i++) {
221264
int dValue = in.readByte() & 0xFF;
222-
subRet0 += Integer.bitCount((dValue & q[i]) & 0xFF);
223-
subRet1 += Integer.bitCount((dValue & q[i + length]) & 0xFF);
224-
subRet2 += Integer.bitCount((dValue & q[i + 2 * length]) & 0xFF);
225-
subRet3 += Integer.bitCount((dValue & q[i + 3 * length]) & 0xFF);
265+
subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF);
266+
subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF);
267+
subRet2 += Integer.bitCount((q[i + 2 * length] & dValue) & 0xFF);
268+
subRet3 += Integer.bitCount((q[i + 3 * length] & dValue) & 0xFF);
226269
}
227270
scores[iter] = subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3);
228271
}
@@ -281,8 +324,22 @@ private void quantizeScore256Bulk(byte[] q, int count, float[] scores) throws IO
281324
subRet2 += sum2.reduceLanes(VectorOperators.ADD);
282325
subRet3 += sum3.reduceLanes(VectorOperators.ADD);
283326
}
284-
// tail as bytes
327+
// process scalar tail
285328
in.seek(offset);
329+
for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) {
330+
final long value = in.readLong();
331+
subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value);
332+
subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value);
333+
subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value);
334+
subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value);
335+
}
336+
for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) {
337+
final int value = in.readInt();
338+
subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value);
339+
subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value);
340+
subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value);
341+
subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value);
342+
}
286343
for (; i < length; i++) {
287344
int dValue = in.readByte() & 0xFF;
288345
subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF);

0 commit comments

Comments (0)