|
18 | 18 |
|
19 | 19 | import org.apache.lucene.index.VectorSimilarityFunction; |
20 | 20 | import org.apache.lucene.store.IndexInput; |
| 21 | +import org.apache.lucene.util.BitUtil; |
21 | 22 | import org.apache.lucene.util.VectorUtil; |
22 | 23 | import org.elasticsearch.simdvec.ES91OSQVectorsScorer; |
23 | 24 |
|
@@ -102,12 +103,26 @@ private long quantizeScore128(byte[] q) throws IOException { |
102 | 103 | subRet3 += sum3.reduceLanes(VectorOperators.ADD); |
103 | 104 | // tail as bytes |
104 | 105 | in.seek(i + offset); |
| 106 | + for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) { |
| 107 | + final long value = in.readLong(); |
| 108 | + subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value); |
| 109 | + subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value); |
| 110 | + subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value); |
| 111 | + subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value); |
| 112 | + } |
| 113 | + for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) { |
| 114 | + final int value = in.readInt(); |
| 115 | + subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value); |
| 116 | + subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value); |
| 117 | + subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value); |
| 118 | + subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value); |
| 119 | + } |
105 | 120 | for (; i < length; i++) { |
106 | 121 | int dValue = in.readByte() & 0xFF; |
107 | | - subRet0 += Integer.bitCount((dValue & q[i]) & 0xFF); |
108 | | - subRet1 += Integer.bitCount((dValue & q[i + length]) & 0xFF); |
109 | | - subRet2 += Integer.bitCount((dValue & q[i + 2 * length]) & 0xFF); |
110 | | - subRet3 += Integer.bitCount((dValue & q[i + 3 * length]) & 0xFF); |
| 122 | + subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF); |
| 123 | + subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF); |
| 124 | + subRet2 += Integer.bitCount((q[i + 2 * length] & dValue) & 0xFF); |
| 125 | + subRet3 += Integer.bitCount((q[i + 3 * length] & dValue) & 0xFF); |
111 | 126 | } |
112 | 127 | return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3); |
113 | 128 | } |
@@ -166,6 +181,20 @@ private long quantizeScore256(byte[] q) throws IOException { |
166 | 181 | } |
167 | 182 | // tail as bytes |
168 | 183 | in.seek(i + offset); |
| 184 | + for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) { |
| 185 | + final long value = in.readLong(); |
| 186 | + subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value); |
| 187 | + subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value); |
| 188 | + subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value); |
| 189 | + subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value); |
| 190 | + } |
| 191 | + for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) { |
| 192 | + final int value = in.readInt(); |
| 193 | + subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value); |
| 194 | + subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value); |
| 195 | + subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value); |
| 196 | + subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value); |
| 197 | + } |
169 | 198 | for (; i < length; i++) { |
170 | 199 | int dValue = in.readByte() & 0xFF; |
171 | 200 | subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF); |
@@ -253,6 +282,20 @@ private long quantizeScore512(byte[] q) throws IOException { |
253 | 282 | } |
254 | 283 | // tail as bytes |
255 | 284 | in.seek(i + offset); |
| 285 | + for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) { |
| 286 | + final long value = in.readLong(); |
| 287 | + subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value); |
| 288 | + subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value); |
| 289 | + subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value); |
| 290 | + subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value); |
| 291 | + } |
| 292 | + for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) { |
| 293 | + final int value = in.readInt(); |
| 294 | + subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value); |
| 295 | + subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value); |
| 296 | + subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value); |
| 297 | + subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value); |
| 298 | + } |
256 | 299 | for (; i < length; i++) { |
257 | 300 | int dValue = in.readByte() & 0xFF; |
258 | 301 | subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF); |
|
0 commit comments