99package org .elasticsearch .simdvec .internal ;
1010
1111import jdk .incubator .vector .ByteVector ;
12+ import jdk .incubator .vector .FloatVector ;
1213import jdk .incubator .vector .IntVector ;
1314import jdk .incubator .vector .ShortVector ;
1415import jdk .incubator .vector .Vector ;
16+ import jdk .incubator .vector .VectorOperators ;
1517import jdk .incubator .vector .VectorShape ;
1618import jdk .incubator .vector .VectorSpecies ;
1719
20+ import org .apache .lucene .index .VectorSimilarityFunction ;
1821import org .apache .lucene .store .IndexInput ;
22+ import org .apache .lucene .util .VectorUtil ;
1923import org .elasticsearch .simdvec .ES92Int7VectorsScorer ;
2024
2125import java .io .IOException ;
2226import java .lang .foreign .MemorySegment ;
27+ import java .nio .ByteOrder ;
2328
2429import static java .nio .ByteOrder .LITTLE_ENDIAN ;
2530import static jdk .incubator .vector .VectorOperators .ADD ;
2631import static jdk .incubator .vector .VectorOperators .B2I ;
2732import static jdk .incubator .vector .VectorOperators .B2S ;
2833import static jdk .incubator .vector .VectorOperators .S2I ;
34+ import static org .apache .lucene .index .VectorSimilarityFunction .EUCLIDEAN ;
35+ import static org .apache .lucene .index .VectorSimilarityFunction .MAXIMUM_INNER_PRODUCT ;
2936
3037/** Panamized scorer for 7-bit quantized vectors stored as an {@link IndexInput}. **/
31- abstract class MemorySegmentES92FallBackInt7VectorsScorer extends ES92Int7VectorsScorer {
38+ abstract class MemorySegmentES92PanamaInt7VectorsScorer extends ES92Int7VectorsScorer {
3239
3340 private static final VectorSpecies <Byte > BYTE_SPECIES_64 = ByteVector .SPECIES_64 ;
3441 private static final VectorSpecies <Byte > BYTE_SPECIES_128 = ByteVector .SPECIES_128 ;
@@ -41,8 +48,8 @@ abstract class MemorySegmentES92FallBackInt7VectorsScorer extends ES92Int7Vector
4148 private static final VectorSpecies <Integer > INT_SPECIES_512 = IntVector .SPECIES_512 ;
4249
4350 private static final int VECTOR_BITSIZE ;
44- protected static final VectorSpecies <Float > FLOAT_SPECIES ;
45- protected static final VectorSpecies <Integer > INT_SPECIES ;
51+ private static final VectorSpecies <Float > FLOAT_SPECIES ;
52+ private static final VectorSpecies <Integer > INT_SPECIES ;
4653
4754 static {
4855 // default to platform supported bitsize
@@ -53,12 +60,12 @@ abstract class MemorySegmentES92FallBackInt7VectorsScorer extends ES92Int7Vector
5360
5461 protected final MemorySegment memorySegment ;
5562
/**
 * Builds the scorer, handing {@code in} and {@code dimensions} to the superclass and keeping a
 * reference to the memory segment that the vectorized routines of this class read from.
 *
 * @param in            index input passed through to the superclass constructor
 * @param dimensions    dimension count passed through to the superclass constructor
 * @param memorySegment segment stored in the {@code memorySegment} field; presumably a view over
 *                      the same data as {@code in} (confirm against callers)
 */
public MemorySegmentES92PanamaInt7VectorsScorer(
    final IndexInput in,
    final int dimensions,
    final MemorySegment memorySegment
) {
    super(in, dimensions);
    this.memorySegment = memorySegment;
}
6067
61- protected long fallbackInt7DotProduct (byte [] q ) throws IOException {
68+ protected long panamaInt7DotProduct (byte [] q ) throws IOException {
6269 assert dimensions == q .length ;
6370 int i = 0 ;
6471 int res = 0 ;
@@ -147,7 +154,7 @@ private int dotProductBody128(byte[] q, int limit) throws IOException {
147154 return acc .reduceLanes (ADD );
148155 }
149156
150- protected void fallbackInt7DotProductBulk (byte [] q , int count , float [] scores ) throws IOException {
157+ protected void panamaInt7DotProductBulk (byte [] q , int count , float [] scores ) throws IOException {
151158 assert dimensions == q .length ;
152159 // only vectorize if we'll at least enter the loop a single time
153160 if (dimensions >= 16 ) {
@@ -249,4 +256,72 @@ private void dotProductBody128Bulk(byte[] q, int count, float[] scores) throws I
249256 scores [iter ] = res ;
250257 }
251258 }
259+
260+ protected void applyCorrectionsBulk (
261+ float queryLowerInterval ,
262+ float queryUpperInterval ,
263+ int queryComponentSum ,
264+ float queryAdditionalCorrection ,
265+ VectorSimilarityFunction similarityFunction ,
266+ float centroidDp ,
267+ float [] scores
268+ ) throws IOException {
269+ int limit = FLOAT_SPECIES .loopBound (BULK_SIZE );
270+ int i = 0 ;
271+ long offset = in .getFilePointer ();
272+ float ay = queryLowerInterval ;
273+ float ly = (queryUpperInterval - ay ) * SEVEN_BIT_SCALE ;
274+ float y1 = queryComponentSum ;
275+ for (; i < limit ; i += FLOAT_SPECIES .length ()) {
276+ var ax = FloatVector .fromMemorySegment (FLOAT_SPECIES , memorySegment , offset + i * Float .BYTES , ByteOrder .LITTLE_ENDIAN );
277+ var lx = FloatVector .fromMemorySegment (
278+ FLOAT_SPECIES ,
279+ memorySegment ,
280+ offset + 4 * BULK_SIZE + i * Float .BYTES ,
281+ ByteOrder .LITTLE_ENDIAN
282+ ).sub (ax ).mul (SEVEN_BIT_SCALE );
283+ var targetComponentSums = IntVector .fromMemorySegment (
284+ INT_SPECIES ,
285+ memorySegment ,
286+ offset + 8 * BULK_SIZE + i * Integer .BYTES ,
287+ ByteOrder .LITTLE_ENDIAN
288+ ).convert (VectorOperators .I2F , 0 );
289+ var additionalCorrections = FloatVector .fromMemorySegment (
290+ FLOAT_SPECIES ,
291+ memorySegment ,
292+ offset + 12 * BULK_SIZE + i * Float .BYTES ,
293+ ByteOrder .LITTLE_ENDIAN
294+ );
295+ var qcDist = FloatVector .fromArray (FLOAT_SPECIES , scores , i );
296+ // ax * ay * dimensions + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly *
297+ // qcDist;
298+ var res1 = ax .mul (ay ).mul (dimensions );
299+ var res2 = lx .mul (ay ).mul (targetComponentSums );
300+ var res3 = ax .mul (ly ).mul (y1 );
301+ var res4 = lx .mul (ly ).mul (qcDist );
302+ var res = res1 .add (res2 ).add (res3 ).add (res4 );
303+ // For euclidean, we need to invert the score and apply the additional correction, which is
304+ // assumed to be the squared l2norm of the centroid centered vectors.
305+ if (similarityFunction == EUCLIDEAN ) {
306+ res = res .mul (-2 ).add (additionalCorrections ).add (queryAdditionalCorrection ).add (1f );
307+ res = FloatVector .broadcast (FLOAT_SPECIES , 1 ).div (res ).max (0 );
308+ res .intoArray (scores , i );
309+ } else {
310+ // For cosine and max inner product, we need to apply the additional correction, which is
311+ // assumed to be the non-centered dot-product between the vector and the centroid
312+ res = res .add (queryAdditionalCorrection ).add (additionalCorrections ).sub (centroidDp );
313+ if (similarityFunction == MAXIMUM_INNER_PRODUCT ) {
314+ res .intoArray (scores , i );
315+ // not sure how to do it better
316+ for (int j = 0 ; j < FLOAT_SPECIES .length (); j ++) {
317+ scores [i + j ] = VectorUtil .scaleMaxInnerProductScore (scores [i + j ]);
318+ }
319+ } else {
320+ res = res .add (1f ).mul (0.5f ).max (0 );
321+ res .intoArray (scores , i );
322+ }
323+ }
324+ }
325+ in .seek (offset + 16L * BULK_SIZE );
326+ }
252327}
0 commit comments