88 */
99package org .elasticsearch .simdvec .internal ;
1010
11- import jdk .incubator .vector .ByteVector ;
12- import jdk .incubator .vector .FloatVector ;
13- import jdk .incubator .vector .IntVector ;
14- import jdk .incubator .vector .ShortVector ;
15- import jdk .incubator .vector .Vector ;
16- import jdk .incubator .vector .VectorOperators ;
17- import jdk .incubator .vector .VectorShape ;
18- import jdk .incubator .vector .VectorSpecies ;
19-
2011import org .apache .lucene .index .VectorSimilarityFunction ;
2112import org .apache .lucene .store .IndexInput ;
22- import org .apache .lucene .util .VectorUtil ;
23- import org .elasticsearch .simdvec .ES92Int7VectorsScorer ;
2413
2514import java .io .IOException ;
2615import java .lang .foreign .MemorySegment ;
27- import java .nio .ByteOrder ;
28-
29- import static java .nio .ByteOrder .LITTLE_ENDIAN ;
30- import static jdk .incubator .vector .VectorOperators .ADD ;
31- import static jdk .incubator .vector .VectorOperators .B2I ;
32- import static jdk .incubator .vector .VectorOperators .B2S ;
33- import static jdk .incubator .vector .VectorOperators .S2I ;
34- import static org .apache .lucene .index .VectorSimilarityFunction .EUCLIDEAN ;
35- import static org .apache .lucene .index .VectorSimilarityFunction .MAXIMUM_INNER_PRODUCT ;
3616
3717/** Panamized scorer for 7-bit quantized vectors stored as an {@link IndexInput}. **/
38- public final class MemorySegmentES92Int7VectorsScorer extends ES92Int7VectorsScorer {
39-
40- private static final VectorSpecies <Byte > BYTE_SPECIES_64 = ByteVector .SPECIES_64 ;
41- private static final VectorSpecies <Byte > BYTE_SPECIES_128 = ByteVector .SPECIES_128 ;
42-
43- private static final VectorSpecies <Short > SHORT_SPECIES_128 = ShortVector .SPECIES_128 ;
44- private static final VectorSpecies <Short > SHORT_SPECIES_256 = ShortVector .SPECIES_256 ;
45-
46- private static final VectorSpecies <Integer > INT_SPECIES_128 = IntVector .SPECIES_128 ;
47- private static final VectorSpecies <Integer > INT_SPECIES_256 = IntVector .SPECIES_256 ;
48- private static final VectorSpecies <Integer > INT_SPECIES_512 = IntVector .SPECIES_512 ;
49-
50- private static final int VECTOR_BITSIZE ;
51- private static final VectorSpecies <Float > FLOAT_SPECIES ;
52- private static final VectorSpecies <Integer > INT_SPECIES ;
53-
54- static {
55- // default to platform supported bitsize
56- VECTOR_BITSIZE = VectorShape .preferredShape ().vectorBitSize ();
57- FLOAT_SPECIES = VectorSpecies .of (float .class , VectorShape .forBitSize (VECTOR_BITSIZE ));
58- INT_SPECIES = VectorSpecies .of (int .class , VectorShape .forBitSize (VECTOR_BITSIZE ));
59- }
60-
61- private final MemorySegment memorySegment ;
18+ public final class MemorySegmentES92Int7VectorsScorer extends MemorySegmentES92PanamaInt7VectorsScorer {
6219
6320 public MemorySegmentES92Int7VectorsScorer (IndexInput in , int dimensions , MemorySegment memorySegment ) {
64- super (in , dimensions );
65- this .memorySegment = memorySegment ;
21+ super (in , dimensions , memorySegment );
6622 }
6723
6824 @ Override
69- public long int7DotProduct (byte [] q ) throws IOException {
70- assert dimensions == q .length ;
71- int i = 0 ;
72- int res = 0 ;
73- // only vectorize if we'll at least enter the loop a single time
74- if (dimensions >= 16 ) {
75- // compute vectorized dot product consistent with VPDPBUSD instruction
76- if (VECTOR_BITSIZE >= 512 ) {
77- i += BYTE_SPECIES_128 .loopBound (dimensions );
78- res += dotProductBody512 (q , i );
79- } else if (VECTOR_BITSIZE == 256 ) {
80- i += BYTE_SPECIES_64 .loopBound (dimensions );
81- res += dotProductBody256 (q , i );
82- } else {
83- // tricky: we don't have SPECIES_32, so we workaround with "overlapping read"
84- i += BYTE_SPECIES_64 .loopBound (dimensions - BYTE_SPECIES_64 .length ());
85- res += dotProductBody128 (q , i );
86- }
87- // scalar tail
88- while (i < dimensions ) {
89- res += in .readByte () * q [i ++];
90- }
91- return res ;
92- } else {
93- return super .int7DotProduct (q );
94- }
95- }
96-
97- private int dotProductBody512 (byte [] q , int limit ) throws IOException {
98- IntVector acc = IntVector .zero (INT_SPECIES_512 );
99- long offset = in .getFilePointer ();
100- for (int i = 0 ; i < limit ; i += BYTE_SPECIES_128 .length ()) {
101- ByteVector va8 = ByteVector .fromArray (BYTE_SPECIES_128 , q , i );
102- ByteVector vb8 = ByteVector .fromMemorySegment (BYTE_SPECIES_128 , memorySegment , offset + i , LITTLE_ENDIAN );
103-
104- // 16-bit multiply: avoid AVX-512 heavy multiply on zmm
105- Vector <Short > va16 = va8 .convertShape (B2S , SHORT_SPECIES_256 , 0 );
106- Vector <Short > vb16 = vb8 .convertShape (B2S , SHORT_SPECIES_256 , 0 );
107- Vector <Short > prod16 = va16 .mul (vb16 );
108-
109- // 32-bit add
110- Vector <Integer > prod32 = prod16 .convertShape (S2I , INT_SPECIES_512 , 0 );
111- acc = acc .add (prod32 );
112- }
113-
114- in .seek (offset + limit ); // advance the input stream
115- // reduce
116- return acc .reduceLanes (ADD );
25+ public boolean hasNativeAccess () {
26+ return false ; // This class does not support native access
11727 }
11828
119- private int dotProductBody256 (byte [] q , int limit ) throws IOException {
120- IntVector acc = IntVector .zero (INT_SPECIES_256 );
121- long offset = in .getFilePointer ();
122- for (int i = 0 ; i < limit ; i += BYTE_SPECIES_64 .length ()) {
123- ByteVector va8 = ByteVector .fromArray (BYTE_SPECIES_64 , q , i );
124- ByteVector vb8 = ByteVector .fromMemorySegment (BYTE_SPECIES_64 , memorySegment , offset + i , LITTLE_ENDIAN );
125-
126- // 32-bit multiply and add into accumulator
127- Vector <Integer > va32 = va8 .convertShape (B2I , INT_SPECIES_256 , 0 );
128- Vector <Integer > vb32 = vb8 .convertShape (B2I , INT_SPECIES_256 , 0 );
129- acc = acc .add (va32 .mul (vb32 ));
130- }
131- in .seek (offset + limit );
132- // reduce
133- return acc .reduceLanes (ADD );
134- }
135-
136- private int dotProductBody128 (byte [] q , int limit ) throws IOException {
137- IntVector acc = IntVector .zero (IntVector .SPECIES_128 );
138- long offset = in .getFilePointer ();
139- // 4 bytes at a time (re-loading half the vector each time!)
140- for (int i = 0 ; i < limit ; i += ByteVector .SPECIES_64 .length () >> 1 ) {
141- // load 8 bytes
142- ByteVector va8 = ByteVector .fromArray (BYTE_SPECIES_64 , q , i );
143- ByteVector vb8 = ByteVector .fromMemorySegment (BYTE_SPECIES_64 , memorySegment , offset + i , LITTLE_ENDIAN );
144-
145- // process first "half" only: 16-bit multiply
146- Vector <Short > va16 = va8 .convert (B2S , 0 );
147- Vector <Short > vb16 = vb8 .convert (B2S , 0 );
148- Vector <Short > prod16 = va16 .mul (vb16 );
149-
150- // 32-bit add
151- acc = acc .add (prod16 .convertShape (S2I , IntVector .SPECIES_128 , 0 ));
152- }
153- in .seek (offset + limit );
154- // reduce
155- return acc .reduceLanes (ADD );
29+ @ Override
30+ public long int7DotProduct (byte [] q ) throws IOException {
31+ return panamaInt7DotProduct (q );
15632 }
15733
15834 @ Override
15935 public void int7DotProductBulk (byte [] q , int count , float [] scores ) throws IOException {
160- assert dimensions == q .length ;
161- // only vectorize if we'll at least enter the loop a single time
162- if (dimensions >= 16 ) {
163- // compute vectorized dot product consistent with VPDPBUSD instruction
164- if (VECTOR_BITSIZE >= 512 ) {
165- dotProductBody512Bulk (q , count , scores );
166- } else if (VECTOR_BITSIZE == 256 ) {
167- dotProductBody256Bulk (q , count , scores );
168- } else {
169- // tricky: we don't have SPECIES_32, so we workaround with "overlapping read"
170- dotProductBody128Bulk (q , count , scores );
171- }
172- } else {
173- int7DotProductBulk (q , count , scores );
174- }
175- }
176-
177- private void dotProductBody512Bulk (byte [] q , int count , float [] scores ) throws IOException {
178- int limit = BYTE_SPECIES_128 .loopBound (dimensions );
179- for (int iter = 0 ; iter < count ; iter ++) {
180- IntVector acc = IntVector .zero (INT_SPECIES_512 );
181- long offset = in .getFilePointer ();
182- int i = 0 ;
183- for (; i < limit ; i += BYTE_SPECIES_128 .length ()) {
184- ByteVector va8 = ByteVector .fromArray (BYTE_SPECIES_128 , q , i );
185- ByteVector vb8 = ByteVector .fromMemorySegment (BYTE_SPECIES_128 , memorySegment , offset + i , LITTLE_ENDIAN );
186-
187- // 16-bit multiply: avoid AVX-512 heavy multiply on zmm
188- Vector <Short > va16 = va8 .convertShape (B2S , SHORT_SPECIES_256 , 0 );
189- Vector <Short > vb16 = vb8 .convertShape (B2S , SHORT_SPECIES_256 , 0 );
190- Vector <Short > prod16 = va16 .mul (vb16 );
191-
192- // 32-bit add
193- Vector <Integer > prod32 = prod16 .convertShape (S2I , INT_SPECIES_512 , 0 );
194- acc = acc .add (prod32 );
195- }
196-
197- in .seek (offset + limit ); // advance the input stream
198- // reduce
199- long res = acc .reduceLanes (ADD );
200- for (; i < dimensions ; i ++) {
201- res += in .readByte () * q [i ];
202- }
203- scores [iter ] = res ;
204- }
205- }
206-
207- private void dotProductBody256Bulk (byte [] q , int count , float [] scores ) throws IOException {
208- int limit = BYTE_SPECIES_128 .loopBound (dimensions );
209- for (int iter = 0 ; iter < count ; iter ++) {
210- IntVector acc = IntVector .zero (INT_SPECIES_256 );
211- long offset = in .getFilePointer ();
212- int i = 0 ;
213- for (; i < limit ; i += BYTE_SPECIES_64 .length ()) {
214- ByteVector va8 = ByteVector .fromArray (BYTE_SPECIES_64 , q , i );
215- ByteVector vb8 = ByteVector .fromMemorySegment (BYTE_SPECIES_64 , memorySegment , offset + i , LITTLE_ENDIAN );
216-
217- // 32-bit multiply and add into accumulator
218- Vector <Integer > va32 = va8 .convertShape (B2I , INT_SPECIES_256 , 0 );
219- Vector <Integer > vb32 = vb8 .convertShape (B2I , INT_SPECIES_256 , 0 );
220- acc = acc .add (va32 .mul (vb32 ));
221- }
222- in .seek (offset + limit );
223- // reduce
224- long res = acc .reduceLanes (ADD );
225- for (; i < dimensions ; i ++) {
226- res += in .readByte () * q [i ];
227- }
228- scores [iter ] = res ;
229- }
230- }
231-
232- private void dotProductBody128Bulk (byte [] q , int count , float [] scores ) throws IOException {
233- int limit = BYTE_SPECIES_64 .loopBound (dimensions - BYTE_SPECIES_64 .length ());
234- for (int iter = 0 ; iter < count ; iter ++) {
235- IntVector acc = IntVector .zero (IntVector .SPECIES_128 );
236- long offset = in .getFilePointer ();
237- // 4 bytes at a time (re-loading half the vector each time!)
238- int i = 0 ;
239- for (; i < limit ; i += ByteVector .SPECIES_64 .length () >> 1 ) {
240- // load 8 bytes
241- ByteVector va8 = ByteVector .fromArray (BYTE_SPECIES_64 , q , i );
242- ByteVector vb8 = ByteVector .fromMemorySegment (BYTE_SPECIES_64 , memorySegment , offset + i , LITTLE_ENDIAN );
243-
244- // process first "half" only: 16-bit multiply
245- Vector <Short > va16 = va8 .convert (B2S , 0 );
246- Vector <Short > vb16 = vb8 .convert (B2S , 0 );
247- Vector <Short > prod16 = va16 .mul (vb16 );
248-
249- // 32-bit add
250- acc = acc .add (prod16 .convertShape (S2I , IntVector .SPECIES_128 , 0 ));
251- }
252- in .seek (offset + limit );
253- // reduce
254- long res = acc .reduceLanes (ADD );
255- for (; i < dimensions ; i ++) {
256- res += in .readByte () * q [i ];
257- }
258- scores [iter ] = res ;
259- }
36+ panamaInt7DotProductBulk (q , count , scores );
26037 }
26138
26239 @ Override
@@ -281,72 +58,4 @@ public void scoreBulk(
28158 scores
28259 );
28360 }
284-
285- private void applyCorrectionsBulk (
286- float queryLowerInterval ,
287- float queryUpperInterval ,
288- int queryComponentSum ,
289- float queryAdditionalCorrection ,
290- VectorSimilarityFunction similarityFunction ,
291- float centroidDp ,
292- float [] scores
293- ) throws IOException {
294- int limit = FLOAT_SPECIES .loopBound (BULK_SIZE );
295- int i = 0 ;
296- long offset = in .getFilePointer ();
297- float ay = queryLowerInterval ;
298- float ly = (queryUpperInterval - ay ) * SEVEN_BIT_SCALE ;
299- float y1 = queryComponentSum ;
300- for (; i < limit ; i += FLOAT_SPECIES .length ()) {
301- var ax = FloatVector .fromMemorySegment (FLOAT_SPECIES , memorySegment , offset + i * Float .BYTES , ByteOrder .LITTLE_ENDIAN );
302- var lx = FloatVector .fromMemorySegment (
303- FLOAT_SPECIES ,
304- memorySegment ,
305- offset + 4 * BULK_SIZE + i * Float .BYTES ,
306- ByteOrder .LITTLE_ENDIAN
307- ).sub (ax ).mul (SEVEN_BIT_SCALE );
308- var targetComponentSums = IntVector .fromMemorySegment (
309- INT_SPECIES ,
310- memorySegment ,
311- offset + 8 * BULK_SIZE + i * Integer .BYTES ,
312- ByteOrder .LITTLE_ENDIAN
313- ).convert (VectorOperators .I2F , 0 );
314- var additionalCorrections = FloatVector .fromMemorySegment (
315- FLOAT_SPECIES ,
316- memorySegment ,
317- offset + 12 * BULK_SIZE + i * Float .BYTES ,
318- ByteOrder .LITTLE_ENDIAN
319- );
320- var qcDist = FloatVector .fromArray (FLOAT_SPECIES , scores , i );
321- // ax * ay * dimensions + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly *
322- // qcDist;
323- var res1 = ax .mul (ay ).mul (dimensions );
324- var res2 = lx .mul (ay ).mul (targetComponentSums );
325- var res3 = ax .mul (ly ).mul (y1 );
326- var res4 = lx .mul (ly ).mul (qcDist );
327- var res = res1 .add (res2 ).add (res3 ).add (res4 );
328- // For euclidean, we need to invert the score and apply the additional correction, which is
329- // assumed to be the squared l2norm of the centroid centered vectors.
330- if (similarityFunction == EUCLIDEAN ) {
331- res = res .mul (-2 ).add (additionalCorrections ).add (queryAdditionalCorrection ).add (1f );
332- res = FloatVector .broadcast (FLOAT_SPECIES , 1 ).div (res ).max (0 );
333- res .intoArray (scores , i );
334- } else {
335- // For cosine and max inner product, we need to apply the additional correction, which is
336- // assumed to be the non-centered dot-product between the vector and the centroid
337- res = res .add (queryAdditionalCorrection ).add (additionalCorrections ).sub (centroidDp );
338- if (similarityFunction == MAXIMUM_INNER_PRODUCT ) {
339- res .intoArray (scores , i );
340- // not sure how to do it better
341- for (int j = 0 ; j < FLOAT_SPECIES .length (); j ++) {
342- scores [i + j ] = VectorUtil .scaleMaxInnerProductScore (scores [i + j ]);
343- }
344- } else {
345- res = res .add (1f ).mul (0.5f ).max (0 );
346- res .intoArray (scores , i );
347- }
348- }
349- }
350- in .seek (offset + 16L * BULK_SIZE );
351- }
35261}
0 commit comments