|
18 | 18 |
|
19 | 19 | import org.apache.lucene.index.VectorSimilarityFunction; |
20 | 20 | import org.apache.lucene.store.IndexInput; |
| 21 | +import org.apache.lucene.util.BitUtil; |
21 | 22 | import org.apache.lucene.util.VectorUtil; |
22 | 23 | import org.elasticsearch.simdvec.ES91OSQVectorsScorer; |
23 | 24 |
|
@@ -118,8 +119,22 @@ private long quantizeScore256(byte[] q) throws IOException { |
118 | 119 | subRet2 += sum2.reduceLanes(VectorOperators.ADD); |
119 | 120 | subRet3 += sum3.reduceLanes(VectorOperators.ADD); |
120 | 121 | } |
121 | | - // tail as bytes |
| 122 | + // process scalar tail |
122 | 123 | in.seek(offset); |
| 124 | + for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) { |
| 125 | + final long value = in.readLong(); |
| 126 | + subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value); |
| 127 | + subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value); |
| 128 | + subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value); |
| 129 | + subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value); |
| 130 | + } |
| 131 | + for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) { |
| 132 | + final int value = in.readInt(); |
| 133 | + subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value); |
| 134 | + subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value); |
| 135 | + subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value); |
| 136 | + subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value); |
| 137 | + } |
123 | 138 | for (; i < length; i++) { |
124 | 139 | int dValue = in.readByte() & 0xFF; |
125 | 140 | subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF); |
@@ -158,14 +173,28 @@ private long quantizeScore128(byte[] q) throws IOException { |
158 | 173 | subRet1 += sum1.reduceLanes(VectorOperators.ADD); |
159 | 174 | subRet2 += sum2.reduceLanes(VectorOperators.ADD); |
160 | 175 | subRet3 += sum3.reduceLanes(VectorOperators.ADD); |
161 | | - // tail as bytes |
| 176 | + // process scalar tail |
162 | 177 | in.seek(offset); |
| 178 | + for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) { |
| 179 | + final long value = in.readLong(); |
| 180 | + subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value); |
| 181 | + subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value); |
| 182 | + subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value); |
| 183 | + subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value); |
| 184 | + } |
| 185 | + for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) { |
| 186 | + final int value = in.readInt(); |
| 187 | + subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value); |
| 188 | + subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value); |
| 189 | + subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value); |
| 190 | + subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value); |
| 191 | + } |
163 | 192 | for (; i < length; i++) { |
164 | 193 | int dValue = in.readByte() & 0xFF; |
165 | | - subRet0 += Integer.bitCount((dValue & q[i]) & 0xFF); |
166 | | - subRet1 += Integer.bitCount((dValue & q[i + length]) & 0xFF); |
167 | | - subRet2 += Integer.bitCount((dValue & q[i + 2 * length]) & 0xFF); |
168 | | - subRet3 += Integer.bitCount((dValue & q[i + 3 * length]) & 0xFF); |
| 194 | + subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF); |
| 195 | + subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF); |
| 196 | + subRet2 += Integer.bitCount((q[i + 2 * length] & dValue) & 0xFF); |
| 197 | + subRet3 += Integer.bitCount((q[i + 3 * length] & dValue) & 0xFF); |
169 | 198 | } |
170 | 199 | return subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3); |
171 | 200 | } |
@@ -215,14 +244,28 @@ private void quantizeScore128Bulk(byte[] q, int count, float[] scores) throws IO |
215 | 244 | subRet1 += sum1.reduceLanes(VectorOperators.ADD); |
216 | 245 | subRet2 += sum2.reduceLanes(VectorOperators.ADD); |
217 | 246 | subRet3 += sum3.reduceLanes(VectorOperators.ADD); |
218 | | - // tail as bytes |
| 247 | + // process scalar tail |
219 | 248 | in.seek(offset); |
| 249 | + for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) { |
| 250 | + final long value = in.readLong(); |
| 251 | + subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value); |
| 252 | + subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value); |
| 253 | + subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value); |
| 254 | + subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value); |
| 255 | + } |
| 256 | + for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) { |
| 257 | + final int value = in.readInt(); |
| 258 | + subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value); |
| 259 | + subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value); |
| 260 | + subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value); |
| 261 | + subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value); |
| 262 | + } |
220 | 263 | for (; i < length; i++) { |
221 | 264 | int dValue = in.readByte() & 0xFF; |
222 | | - subRet0 += Integer.bitCount((dValue & q[i]) & 0xFF); |
223 | | - subRet1 += Integer.bitCount((dValue & q[i + length]) & 0xFF); |
224 | | - subRet2 += Integer.bitCount((dValue & q[i + 2 * length]) & 0xFF); |
225 | | - subRet3 += Integer.bitCount((dValue & q[i + 3 * length]) & 0xFF); |
| 265 | + subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF); |
| 266 | + subRet1 += Integer.bitCount((q[i + length] & dValue) & 0xFF); |
| 267 | + subRet2 += Integer.bitCount((q[i + 2 * length] & dValue) & 0xFF); |
| 268 | + subRet3 += Integer.bitCount((q[i + 3 * length] & dValue) & 0xFF); |
226 | 269 | } |
227 | 270 | scores[iter] = subRet0 + (subRet1 << 1) + (subRet2 << 2) + (subRet3 << 3); |
228 | 271 | } |
@@ -281,8 +324,22 @@ private void quantizeScore256Bulk(byte[] q, int count, float[] scores) throws IO |
281 | 324 | subRet2 += sum2.reduceLanes(VectorOperators.ADD); |
282 | 325 | subRet3 += sum3.reduceLanes(VectorOperators.ADD); |
283 | 326 | } |
284 | | - // tail as bytes |
| 327 | + // process scalar tail |
285 | 328 | in.seek(offset); |
| 329 | + for (final int upperBound = length & -Long.BYTES; i < upperBound; i += Long.BYTES) { |
| 330 | + final long value = in.readLong(); |
| 331 | + subRet0 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i) & value); |
| 332 | + subRet1 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + length) & value); |
| 333 | + subRet2 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 2 * length) & value); |
| 334 | + subRet3 += Long.bitCount((long) BitUtil.VH_LE_LONG.get(q, i + 3 * length) & value); |
| 335 | + } |
| 336 | + for (final int upperBound = length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) { |
| 337 | + final int value = in.readInt(); |
| 338 | + subRet0 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i) & value); |
| 339 | + subRet1 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + length) & value); |
| 340 | + subRet2 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 2 * length) & value); |
| 341 | + subRet3 += Integer.bitCount((int) BitUtil.VH_LE_INT.get(q, i + 3 * length) & value); |
| 342 | + } |
286 | 343 | for (; i < length; i++) { |
287 | 344 | int dValue = in.readByte() & 0xFF; |
288 | 345 | subRet0 += Integer.bitCount((q[i] & dValue) & 0xFF); |
|
0 commit comments