88 */ 
99package  org .elasticsearch .simdvec .internal ;
1010
11- import  jdk .incubator .vector .ByteVector ;
12- import  jdk .incubator .vector .FloatVector ;
13- import  jdk .incubator .vector .IntVector ;
14- import  jdk .incubator .vector .ShortVector ;
15- import  jdk .incubator .vector .Vector ;
16- import  jdk .incubator .vector .VectorOperators ;
17- import  jdk .incubator .vector .VectorShape ;
18- import  jdk .incubator .vector .VectorSpecies ;
19- 
2011import  org .apache .lucene .index .VectorSimilarityFunction ;
2112import  org .apache .lucene .store .IndexInput ;
22- import  org .apache .lucene .util .VectorUtil ;
23- import  org .elasticsearch .simdvec .ES92Int7VectorsScorer ;
2413
2514import  java .io .IOException ;
2615import  java .lang .foreign .MemorySegment ;
27- import  java .nio .ByteOrder ;
28- 
29- import  static  java .nio .ByteOrder .LITTLE_ENDIAN ;
30- import  static  jdk .incubator .vector .VectorOperators .ADD ;
31- import  static  jdk .incubator .vector .VectorOperators .B2I ;
32- import  static  jdk .incubator .vector .VectorOperators .B2S ;
33- import  static  jdk .incubator .vector .VectorOperators .S2I ;
34- import  static  org .apache .lucene .index .VectorSimilarityFunction .EUCLIDEAN ;
35- import  static  org .apache .lucene .index .VectorSimilarityFunction .MAXIMUM_INNER_PRODUCT ;
3616
3717/** Scorer for 7-bit quantized vectors read from an {@link IndexInput}, implemented with the Panama Vector API. */
38- public  final  class  MemorySegmentES92Int7VectorsScorer  extends  ES92Int7VectorsScorer  {
39- 
40-     private  static  final  VectorSpecies <Byte > BYTE_SPECIES_64  = ByteVector .SPECIES_64 ;
41-     private  static  final  VectorSpecies <Byte > BYTE_SPECIES_128  = ByteVector .SPECIES_128 ;
42- 
43-     private  static  final  VectorSpecies <Short > SHORT_SPECIES_128  = ShortVector .SPECIES_128 ;
44-     private  static  final  VectorSpecies <Short > SHORT_SPECIES_256  = ShortVector .SPECIES_256 ;
45- 
46-     private  static  final  VectorSpecies <Integer > INT_SPECIES_128  = IntVector .SPECIES_128 ;
47-     private  static  final  VectorSpecies <Integer > INT_SPECIES_256  = IntVector .SPECIES_256 ;
48-     private  static  final  VectorSpecies <Integer > INT_SPECIES_512  = IntVector .SPECIES_512 ;
49- 
50-     private  static  final  int  VECTOR_BITSIZE ;
51-     private  static  final  VectorSpecies <Float > FLOAT_SPECIES ;
52-     private  static  final  VectorSpecies <Integer > INT_SPECIES ;
53- 
54-     static  {
55-         // default to platform supported bitsize 
56-         VECTOR_BITSIZE  = VectorShape .preferredShape ().vectorBitSize ();
57-         FLOAT_SPECIES  = VectorSpecies .of (float .class , VectorShape .forBitSize (VECTOR_BITSIZE ));
58-         INT_SPECIES  = VectorSpecies .of (int .class , VectorShape .forBitSize (VECTOR_BITSIZE ));
59-     }
60- 
61-     private  final  MemorySegment  memorySegment ;
18+ public  final  class  MemorySegmentES92Int7VectorsScorer  extends  MemorySegmentES92PanamaInt7VectorsScorer  {
6219
6320    public  MemorySegmentES92Int7VectorsScorer (IndexInput  in , int  dimensions , MemorySegment  memorySegment ) {
64-         super (in , dimensions );
65-         this .memorySegment  = memorySegment ;
21+         super (in , dimensions , memorySegment );
6622    }
6723
6824    @ Override 
69-     public  long  int7DotProduct (byte [] q ) throws  IOException  {
70-         assert  dimensions  == q .length ;
71-         int  i  = 0 ;
72-         int  res  = 0 ;
73-         // only vectorize if we'll at least enter the loop a single time 
74-         if  (dimensions  >= 16 ) {
75-             // compute vectorized dot product consistent with VPDPBUSD instruction 
76-             if  (VECTOR_BITSIZE  >= 512 ) {
77-                 i  += BYTE_SPECIES_128 .loopBound (dimensions );
78-                 res  += dotProductBody512 (q , i );
79-             } else  if  (VECTOR_BITSIZE  == 256 ) {
80-                 i  += BYTE_SPECIES_64 .loopBound (dimensions );
81-                 res  += dotProductBody256 (q , i );
82-             } else  {
83-                 // tricky: there is no SPECIES_32, so we work around it with overlapping 8-byte reads that advance 4 bytes at a time
84-                 i  += BYTE_SPECIES_64 .loopBound (dimensions  - BYTE_SPECIES_64 .length ());
85-                 res  += dotProductBody128 (q , i );
86-             }
87-             // scalar tail 
88-             while  (i  < dimensions ) {
89-                 res  += in .readByte () * q [i ++];
90-             }
91-             return  res ;
92-         } else  {
93-             return  super .int7DotProduct (q );
94-         }
95-     }
96- 
97-     private  int  dotProductBody512 (byte [] q , int  limit ) throws  IOException  {
98-         IntVector  acc  = IntVector .zero (INT_SPECIES_512 );
99-         long  offset  = in .getFilePointer ();
100-         for  (int  i  = 0 ; i  < limit ; i  += BYTE_SPECIES_128 .length ()) {
101-             ByteVector  va8  = ByteVector .fromArray (BYTE_SPECIES_128 , q , i );
102-             ByteVector  vb8  = ByteVector .fromMemorySegment (BYTE_SPECIES_128 , memorySegment , offset  + i , LITTLE_ENDIAN );
103- 
104-             // 16-bit multiply: avoid AVX-512 heavy multiply on zmm 
105-             Vector <Short > va16  = va8 .convertShape (B2S , SHORT_SPECIES_256 , 0 );
106-             Vector <Short > vb16  = vb8 .convertShape (B2S , SHORT_SPECIES_256 , 0 );
107-             Vector <Short > prod16  = va16 .mul (vb16 );
108- 
109-             // 32-bit add 
110-             Vector <Integer > prod32  = prod16 .convertShape (S2I , INT_SPECIES_512 , 0 );
111-             acc  = acc .add (prod32 );
112-         }
113- 
114-         in .seek (offset  + limit ); // advance the input stream 
115-         // reduce 
116-         return  acc .reduceLanes (ADD );
25+     public  boolean  hasNativeAccess () {
26+         return  false ; // This class does not support native access 
11727    }
11828
119-     private  int  dotProductBody256 (byte [] q , int  limit ) throws  IOException  {
120-         IntVector  acc  = IntVector .zero (INT_SPECIES_256 );
121-         long  offset  = in .getFilePointer ();
122-         for  (int  i  = 0 ; i  < limit ; i  += BYTE_SPECIES_64 .length ()) {
123-             ByteVector  va8  = ByteVector .fromArray (BYTE_SPECIES_64 , q , i );
124-             ByteVector  vb8  = ByteVector .fromMemorySegment (BYTE_SPECIES_64 , memorySegment , offset  + i , LITTLE_ENDIAN );
125- 
126-             // 32-bit multiply and add into accumulator 
127-             Vector <Integer > va32  = va8 .convertShape (B2I , INT_SPECIES_256 , 0 );
128-             Vector <Integer > vb32  = vb8 .convertShape (B2I , INT_SPECIES_256 , 0 );
129-             acc  = acc .add (va32 .mul (vb32 ));
130-         }
131-         in .seek (offset  + limit );
132-         // reduce 
133-         return  acc .reduceLanes (ADD );
134-     }
135- 
136-     private  int  dotProductBody128 (byte [] q , int  limit ) throws  IOException  {
137-         IntVector  acc  = IntVector .zero (IntVector .SPECIES_128 );
138-         long  offset  = in .getFilePointer ();
139-         // 4 bytes at a time (re-loading half the vector each time!) 
140-         for  (int  i  = 0 ; i  < limit ; i  += ByteVector .SPECIES_64 .length () >> 1 ) {
141-             // load 8 bytes 
142-             ByteVector  va8  = ByteVector .fromArray (BYTE_SPECIES_64 , q , i );
143-             ByteVector  vb8  = ByteVector .fromMemorySegment (BYTE_SPECIES_64 , memorySegment , offset  + i , LITTLE_ENDIAN );
144- 
145-             // process first "half" only: 16-bit multiply 
146-             Vector <Short > va16  = va8 .convert (B2S , 0 );
147-             Vector <Short > vb16  = vb8 .convert (B2S , 0 );
148-             Vector <Short > prod16  = va16 .mul (vb16 );
149- 
150-             // 32-bit add 
151-             acc  = acc .add (prod16 .convertShape (S2I , IntVector .SPECIES_128 , 0 ));
152-         }
153-         in .seek (offset  + limit );
154-         // reduce 
155-         return  acc .reduceLanes (ADD );
29+     @ Override 
30+     public  long  int7DotProduct (byte [] q ) throws  IOException  {
31+         return  panamaInt7DotProduct (q );
15632    }
15733
15834    @ Override 
15935    public  void  int7DotProductBulk (byte [] q , int  count , float [] scores ) throws  IOException  {
160-         assert  dimensions  == q .length ;
161-         // only vectorize if we'll at least enter the loop a single time 
162-         if  (dimensions  >= 16 ) {
163-             // compute vectorized dot product consistent with VPDPBUSD instruction 
164-             if  (VECTOR_BITSIZE  >= 512 ) {
165-                 dotProductBody512Bulk (q , count , scores );
166-             } else  if  (VECTOR_BITSIZE  == 256 ) {
167-                 dotProductBody256Bulk (q , count , scores );
168-             } else  {
169-                 // tricky: there is no SPECIES_32, so we work around it with overlapping 8-byte reads that advance 4 bytes at a time
170-                 dotProductBody128Bulk (q , count , scores );
171-             }
172-         } else  {
173-             int7DotProductBulk (q , count , scores );
174-         }
175-     }
176- 
177-     private  void  dotProductBody512Bulk (byte [] q , int  count , float [] scores ) throws  IOException  {
178-         int  limit  = BYTE_SPECIES_128 .loopBound (dimensions );
179-         for  (int  iter  = 0 ; iter  < count ; iter ++) {
180-             IntVector  acc  = IntVector .zero (INT_SPECIES_512 );
181-             long  offset  = in .getFilePointer ();
182-             int  i  = 0 ;
183-             for  (; i  < limit ; i  += BYTE_SPECIES_128 .length ()) {
184-                 ByteVector  va8  = ByteVector .fromArray (BYTE_SPECIES_128 , q , i );
185-                 ByteVector  vb8  = ByteVector .fromMemorySegment (BYTE_SPECIES_128 , memorySegment , offset  + i , LITTLE_ENDIAN );
186- 
187-                 // 16-bit multiply: avoid AVX-512 heavy multiply on zmm 
188-                 Vector <Short > va16  = va8 .convertShape (B2S , SHORT_SPECIES_256 , 0 );
189-                 Vector <Short > vb16  = vb8 .convertShape (B2S , SHORT_SPECIES_256 , 0 );
190-                 Vector <Short > prod16  = va16 .mul (vb16 );
191- 
192-                 // 32-bit add 
193-                 Vector <Integer > prod32  = prod16 .convertShape (S2I , INT_SPECIES_512 , 0 );
194-                 acc  = acc .add (prod32 );
195-             }
196- 
197-             in .seek (offset  + limit ); // advance the input stream 
198-             // reduce 
199-             long  res  = acc .reduceLanes (ADD );
200-             for  (; i  < dimensions ; i ++) {
201-                 res  += in .readByte () * q [i ];
202-             }
203-             scores [iter ] = res ;
204-         }
205-     }
206- 
207-     private  void  dotProductBody256Bulk (byte [] q , int  count , float [] scores ) throws  IOException  {
208-         int  limit  = BYTE_SPECIES_128 .loopBound (dimensions );
209-         for  (int  iter  = 0 ; iter  < count ; iter ++) {
210-             IntVector  acc  = IntVector .zero (INT_SPECIES_256 );
211-             long  offset  = in .getFilePointer ();
212-             int  i  = 0 ;
213-             for  (; i  < limit ; i  += BYTE_SPECIES_64 .length ()) {
214-                 ByteVector  va8  = ByteVector .fromArray (BYTE_SPECIES_64 , q , i );
215-                 ByteVector  vb8  = ByteVector .fromMemorySegment (BYTE_SPECIES_64 , memorySegment , offset  + i , LITTLE_ENDIAN );
216- 
217-                 // 32-bit multiply and add into accumulator 
218-                 Vector <Integer > va32  = va8 .convertShape (B2I , INT_SPECIES_256 , 0 );
219-                 Vector <Integer > vb32  = vb8 .convertShape (B2I , INT_SPECIES_256 , 0 );
220-                 acc  = acc .add (va32 .mul (vb32 ));
221-             }
222-             in .seek (offset  + limit );
223-             // reduce 
224-             long  res  = acc .reduceLanes (ADD );
225-             for  (; i  < dimensions ; i ++) {
226-                 res  += in .readByte () * q [i ];
227-             }
228-             scores [iter ] = res ;
229-         }
230-     }
231- 
232-     private  void  dotProductBody128Bulk (byte [] q , int  count , float [] scores ) throws  IOException  {
233-         int  limit  = BYTE_SPECIES_64 .loopBound (dimensions  - BYTE_SPECIES_64 .length ());
234-         for  (int  iter  = 0 ; iter  < count ; iter ++) {
235-             IntVector  acc  = IntVector .zero (IntVector .SPECIES_128 );
236-             long  offset  = in .getFilePointer ();
237-             // 4 bytes at a time (re-loading half the vector each time!) 
238-             int  i  = 0 ;
239-             for  (; i  < limit ; i  += ByteVector .SPECIES_64 .length () >> 1 ) {
240-                 // load 8 bytes 
241-                 ByteVector  va8  = ByteVector .fromArray (BYTE_SPECIES_64 , q , i );
242-                 ByteVector  vb8  = ByteVector .fromMemorySegment (BYTE_SPECIES_64 , memorySegment , offset  + i , LITTLE_ENDIAN );
243- 
244-                 // process first "half" only: 16-bit multiply 
245-                 Vector <Short > va16  = va8 .convert (B2S , 0 );
246-                 Vector <Short > vb16  = vb8 .convert (B2S , 0 );
247-                 Vector <Short > prod16  = va16 .mul (vb16 );
248- 
249-                 // 32-bit add 
250-                 acc  = acc .add (prod16 .convertShape (S2I , IntVector .SPECIES_128 , 0 ));
251-             }
252-             in .seek (offset  + limit );
253-             // reduce 
254-             long  res  = acc .reduceLanes (ADD );
255-             for  (; i  < dimensions ; i ++) {
256-                 res  += in .readByte () * q [i ];
257-             }
258-             scores [iter ] = res ;
259-         }
36+         panamaInt7DotProductBulk (q , count , scores );
26037    }
26138
26239    @ Override 
@@ -281,72 +58,4 @@ public void scoreBulk(
28158            scores 
28259        );
28360    }
284- 
285-     private  void  applyCorrectionsBulk (
286-         float  queryLowerInterval ,
287-         float  queryUpperInterval ,
288-         int  queryComponentSum ,
289-         float  queryAdditionalCorrection ,
290-         VectorSimilarityFunction  similarityFunction ,
291-         float  centroidDp ,
292-         float [] scores 
293-     ) throws  IOException  {
294-         int  limit  = FLOAT_SPECIES .loopBound (BULK_SIZE );
295-         int  i  = 0 ;
296-         long  offset  = in .getFilePointer ();
297-         float  ay  = queryLowerInterval ;
298-         float  ly  = (queryUpperInterval  - ay ) * SEVEN_BIT_SCALE ;
299-         float  y1  = queryComponentSum ;
300-         for  (; i  < limit ; i  += FLOAT_SPECIES .length ()) {
301-             var  ax  = FloatVector .fromMemorySegment (FLOAT_SPECIES , memorySegment , offset  + i  * Float .BYTES , ByteOrder .LITTLE_ENDIAN );
302-             var  lx  = FloatVector .fromMemorySegment (
303-                 FLOAT_SPECIES ,
304-                 memorySegment ,
305-                 offset  + 4  * BULK_SIZE  + i  * Float .BYTES ,
306-                 ByteOrder .LITTLE_ENDIAN 
307-             ).sub (ax ).mul (SEVEN_BIT_SCALE );
308-             var  targetComponentSums  = IntVector .fromMemorySegment (
309-                 INT_SPECIES ,
310-                 memorySegment ,
311-                 offset  + 8  * BULK_SIZE  + i  * Integer .BYTES ,
312-                 ByteOrder .LITTLE_ENDIAN 
313-             ).convert (VectorOperators .I2F , 0 );
314-             var  additionalCorrections  = FloatVector .fromMemorySegment (
315-                 FLOAT_SPECIES ,
316-                 memorySegment ,
317-                 offset  + 12  * BULK_SIZE  + i  * Float .BYTES ,
318-                 ByteOrder .LITTLE_ENDIAN 
319-             );
320-             var  qcDist  = FloatVector .fromArray (FLOAT_SPECIES , scores , i );
321-             // ax * ay * dimensions + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly * 
322-             // qcDist; 
323-             var  res1  = ax .mul (ay ).mul (dimensions );
324-             var  res2  = lx .mul (ay ).mul (targetComponentSums );
325-             var  res3  = ax .mul (ly ).mul (y1 );
326-             var  res4  = lx .mul (ly ).mul (qcDist );
327-             var  res  = res1 .add (res2 ).add (res3 ).add (res4 );
328-             // For euclidean, we need to invert the score and apply the additional correction, which is 
329-             // assumed to be the squared l2norm of the centroid centered vectors. 
330-             if  (similarityFunction  == EUCLIDEAN ) {
331-                 res  = res .mul (-2 ).add (additionalCorrections ).add (queryAdditionalCorrection ).add (1f );
332-                 res  = FloatVector .broadcast (FLOAT_SPECIES , 1 ).div (res ).max (0 );
333-                 res .intoArray (scores , i );
334-             } else  {
335-                 // For cosine and max inner product, we need to apply the additional correction, which is 
336-                 // assumed to be the non-centered dot-product between the vector and the centroid 
337-                 res  = res .add (queryAdditionalCorrection ).add (additionalCorrections ).sub (centroidDp );
338-                 if  (similarityFunction  == MAXIMUM_INNER_PRODUCT ) {
339-                     res .intoArray (scores , i );
340-                     // scaleMaxInnerProductScore takes a single float, so it is applied lane by lane after the store (TODO: vectorize if possible)
341-                     for  (int  j  = 0 ; j  < FLOAT_SPECIES .length (); j ++) {
342-                         scores [i  + j ] = VectorUtil .scaleMaxInnerProductScore (scores [i  + j ]);
343-                     }
344-                 } else  {
345-                     res  = res .add (1f ).mul (0.5f ).max (0 );
346-                     res .intoArray (scores , i );
347-                 }
348-             }
349-         }
350-         in .seek (offset  + 16L  * BULK_SIZE );
351-     }
35261}
0 commit comments