Skip to content

Commit 5bb5635

Browse files
committed
doh
1 parent 9e9de6a commit 5bb5635

File tree

1 file changed

+47
-4
lines changed

1 file changed

+47
-4
lines changed

libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/vectorization/MemorySegmentES91OSQVectorsScorer.java

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import org.apache.lucene.index.VectorSimilarityFunction;
2020
import org.apache.lucene.store.IndexInput;
21+
import org.apache.lucene.util.BitUtil;
2122
import org.apache.lucene.util.VectorUtil;
2223
import org.elasticsearch.simdvec.ES91OSQVectorsScorer;
2324

@@ -102,12 +103,26 @@ private long quantizeScore128(byte[] q) throws IOException {
102103
subRet3 += sum3.reduceLanes(VectorOperators.ADD);
103104
// tail: 8 bytes at a time (long), then 4 bytes at a time (int), then single bytes
104105
in.seek(i + offset);
106+
for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) {
107+
final long value = in.readLong();
108+
subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value);
109+
subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value);
110+
subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value);
111+
subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value);
112+
}
113+
for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) {
114+
final int value = in.readInt();
115+
subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value);
116+
subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value);
117+
subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value);
118+
subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value);
119+
}
105120
for (; i < length; i++) {
106121
int dValue = in.readByte() & 0xFF;
107-
subRet0 += Integer.bitCount((dValue & q[i]) & 0xFF);
108-
subRet1 += Integer.bitCount((dValue & q[i + length]) & 0xFF);
109-
subRet2 += Integer.bitCount((dValue & q[i + 2 * length]) & 0xFF);
110-
subRet3 += Integer.bitCount((dValue & q[i + 3 * length]) & 0xFF);
122+
subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF);
123+
subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF);
124+
subRet2 += Integer.bitCount((q[i + 2 * length] & dValue) & 0xFF);
125+
subRet3 += Integer.bitCount((q[i + 3 * length] & dValue) & 0xFF);
111126
}
112127
return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3);
113128
}
@@ -166,6 +181,20 @@ private long quantizeScore256(byte[] q) throws IOException {
166181
}
167182
// tail: 8 bytes at a time (long), then 4 bytes at a time (int), then single bytes
168183
in.seek(i + offset);
184+
for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) {
185+
final long value = in.readLong();
186+
subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value);
187+
subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value);
188+
subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value);
189+
subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value);
190+
}
191+
for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) {
192+
final int value = in.readInt();
193+
subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value);
194+
subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value);
195+
subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value);
196+
subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value);
197+
}
169198
for (; i < length; i++) {
170199
int dValue = in.readByte() & 0xFF;
171200
subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF);
@@ -253,6 +282,20 @@ private long quantizeScore512(byte[] q) throws IOException {
253282
}
254283
// tail: 8 bytes at a time (long), then 4 bytes at a time (int), then single bytes
255284
in.seek(i + offset);
285+
for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) {
286+
final long value = in.readLong();
287+
subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value);
288+
subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value);
289+
subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value);
290+
subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value);
291+
}
292+
for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) {
293+
final int value = in.readInt();
294+
subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value);
295+
subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value);
296+
subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value);
297+
subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value);
298+
}
256299
for (; i < length; i++) {
257300
int dValue = in.readByte() & 0xFF;
258301
subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF);

0 commit comments

Comments
 (0)