8
8
*/
9
9
package org .elasticsearch .simdvec .internal ;
10
10
11
- import jdk .incubator .vector .ByteVector ;
12
- import jdk .incubator .vector .FloatVector ;
13
- import jdk .incubator .vector .IntVector ;
14
- import jdk .incubator .vector .ShortVector ;
15
- import jdk .incubator .vector .Vector ;
16
- import jdk .incubator .vector .VectorOperators ;
17
- import jdk .incubator .vector .VectorShape ;
18
- import jdk .incubator .vector .VectorSpecies ;
19
-
20
11
import org .apache .lucene .index .VectorSimilarityFunction ;
21
12
import org .apache .lucene .store .IndexInput ;
22
- import org .apache .lucene .util .VectorUtil ;
23
- import org .elasticsearch .simdvec .ES92Int7VectorsScorer ;
24
13
25
14
import java .io .IOException ;
26
15
import java .lang .foreign .MemorySegment ;
27
- import java .nio .ByteOrder ;
28
-
29
- import static java .nio .ByteOrder .LITTLE_ENDIAN ;
30
- import static jdk .incubator .vector .VectorOperators .ADD ;
31
- import static jdk .incubator .vector .VectorOperators .B2I ;
32
- import static jdk .incubator .vector .VectorOperators .B2S ;
33
- import static jdk .incubator .vector .VectorOperators .S2I ;
34
- import static org .apache .lucene .index .VectorSimilarityFunction .EUCLIDEAN ;
35
- import static org .apache .lucene .index .VectorSimilarityFunction .MAXIMUM_INNER_PRODUCT ;
36
16
37
17
/** Panamized scorer for 7-bit quantized vectors stored as an {@link IndexInput}. **/
38
- public final class MemorySegmentES92Int7VectorsScorer extends ES92Int7VectorsScorer {
39
-
40
- private static final VectorSpecies <Byte > BYTE_SPECIES_64 = ByteVector .SPECIES_64 ;
41
- private static final VectorSpecies <Byte > BYTE_SPECIES_128 = ByteVector .SPECIES_128 ;
42
-
43
- private static final VectorSpecies <Short > SHORT_SPECIES_128 = ShortVector .SPECIES_128 ;
44
- private static final VectorSpecies <Short > SHORT_SPECIES_256 = ShortVector .SPECIES_256 ;
45
-
46
- private static final VectorSpecies <Integer > INT_SPECIES_128 = IntVector .SPECIES_128 ;
47
- private static final VectorSpecies <Integer > INT_SPECIES_256 = IntVector .SPECIES_256 ;
48
- private static final VectorSpecies <Integer > INT_SPECIES_512 = IntVector .SPECIES_512 ;
49
-
50
- private static final int VECTOR_BITSIZE ;
51
- private static final VectorSpecies <Float > FLOAT_SPECIES ;
52
- private static final VectorSpecies <Integer > INT_SPECIES ;
53
-
54
- static {
55
- // default to platform supported bitsize
56
- VECTOR_BITSIZE = VectorShape .preferredShape ().vectorBitSize ();
57
- FLOAT_SPECIES = VectorSpecies .of (float .class , VectorShape .forBitSize (VECTOR_BITSIZE ));
58
- INT_SPECIES = VectorSpecies .of (int .class , VectorShape .forBitSize (VECTOR_BITSIZE ));
59
- }
60
-
61
- private final MemorySegment memorySegment ;
18
+ public final class MemorySegmentES92Int7VectorsScorer extends MemorySegmentES92PanamaInt7VectorsScorer {
62
19
63
20
public MemorySegmentES92Int7VectorsScorer (IndexInput in , int dimensions , MemorySegment memorySegment ) {
64
- super (in , dimensions );
65
- this .memorySegment = memorySegment ;
21
+ super (in , dimensions , memorySegment );
66
22
}
67
23
68
24
@ Override
69
- public long int7DotProduct (byte [] q ) throws IOException {
70
- assert dimensions == q .length ;
71
- int i = 0 ;
72
- int res = 0 ;
73
- // only vectorize if we'll at least enter the loop a single time
74
- if (dimensions >= 16 ) {
75
- // compute vectorized dot product consistent with VPDPBUSD instruction
76
- if (VECTOR_BITSIZE >= 512 ) {
77
- i += BYTE_SPECIES_128 .loopBound (dimensions );
78
- res += dotProductBody512 (q , i );
79
- } else if (VECTOR_BITSIZE == 256 ) {
80
- i += BYTE_SPECIES_64 .loopBound (dimensions );
81
- res += dotProductBody256 (q , i );
82
- } else {
83
- // tricky: we don't have SPECIES_32, so we workaround with "overlapping read"
84
- i += BYTE_SPECIES_64 .loopBound (dimensions - BYTE_SPECIES_64 .length ());
85
- res += dotProductBody128 (q , i );
86
- }
87
- // scalar tail
88
- while (i < dimensions ) {
89
- res += in .readByte () * q [i ++];
90
- }
91
- return res ;
92
- } else {
93
- return super .int7DotProduct (q );
94
- }
95
- }
96
-
97
- private int dotProductBody512 (byte [] q , int limit ) throws IOException {
98
- IntVector acc = IntVector .zero (INT_SPECIES_512 );
99
- long offset = in .getFilePointer ();
100
- for (int i = 0 ; i < limit ; i += BYTE_SPECIES_128 .length ()) {
101
- ByteVector va8 = ByteVector .fromArray (BYTE_SPECIES_128 , q , i );
102
- ByteVector vb8 = ByteVector .fromMemorySegment (BYTE_SPECIES_128 , memorySegment , offset + i , LITTLE_ENDIAN );
103
-
104
- // 16-bit multiply: avoid AVX-512 heavy multiply on zmm
105
- Vector <Short > va16 = va8 .convertShape (B2S , SHORT_SPECIES_256 , 0 );
106
- Vector <Short > vb16 = vb8 .convertShape (B2S , SHORT_SPECIES_256 , 0 );
107
- Vector <Short > prod16 = va16 .mul (vb16 );
108
-
109
- // 32-bit add
110
- Vector <Integer > prod32 = prod16 .convertShape (S2I , INT_SPECIES_512 , 0 );
111
- acc = acc .add (prod32 );
112
- }
113
-
114
- in .seek (offset + limit ); // advance the input stream
115
- // reduce
116
- return acc .reduceLanes (ADD );
25
+ public boolean hasNativeAccess () {
26
+ return false ; // This class does not support native access
117
27
}
118
28
119
- private int dotProductBody256 (byte [] q , int limit ) throws IOException {
120
- IntVector acc = IntVector .zero (INT_SPECIES_256 );
121
- long offset = in .getFilePointer ();
122
- for (int i = 0 ; i < limit ; i += BYTE_SPECIES_64 .length ()) {
123
- ByteVector va8 = ByteVector .fromArray (BYTE_SPECIES_64 , q , i );
124
- ByteVector vb8 = ByteVector .fromMemorySegment (BYTE_SPECIES_64 , memorySegment , offset + i , LITTLE_ENDIAN );
125
-
126
- // 32-bit multiply and add into accumulator
127
- Vector <Integer > va32 = va8 .convertShape (B2I , INT_SPECIES_256 , 0 );
128
- Vector <Integer > vb32 = vb8 .convertShape (B2I , INT_SPECIES_256 , 0 );
129
- acc = acc .add (va32 .mul (vb32 ));
130
- }
131
- in .seek (offset + limit );
132
- // reduce
133
- return acc .reduceLanes (ADD );
134
- }
135
-
136
- private int dotProductBody128 (byte [] q , int limit ) throws IOException {
137
- IntVector acc = IntVector .zero (IntVector .SPECIES_128 );
138
- long offset = in .getFilePointer ();
139
- // 4 bytes at a time (re-loading half the vector each time!)
140
- for (int i = 0 ; i < limit ; i += ByteVector .SPECIES_64 .length () >> 1 ) {
141
- // load 8 bytes
142
- ByteVector va8 = ByteVector .fromArray (BYTE_SPECIES_64 , q , i );
143
- ByteVector vb8 = ByteVector .fromMemorySegment (BYTE_SPECIES_64 , memorySegment , offset + i , LITTLE_ENDIAN );
144
-
145
- // process first "half" only: 16-bit multiply
146
- Vector <Short > va16 = va8 .convert (B2S , 0 );
147
- Vector <Short > vb16 = vb8 .convert (B2S , 0 );
148
- Vector <Short > prod16 = va16 .mul (vb16 );
149
-
150
- // 32-bit add
151
- acc = acc .add (prod16 .convertShape (S2I , IntVector .SPECIES_128 , 0 ));
152
- }
153
- in .seek (offset + limit );
154
- // reduce
155
- return acc .reduceLanes (ADD );
29
+ @ Override
30
+ public long int7DotProduct (byte [] q ) throws IOException {
31
+ return panamaInt7DotProduct (q );
156
32
}
157
33
158
34
@ Override
159
35
public void int7DotProductBulk (byte [] q , int count , float [] scores ) throws IOException {
160
- assert dimensions == q .length ;
161
- // only vectorize if we'll at least enter the loop a single time
162
- if (dimensions >= 16 ) {
163
- // compute vectorized dot product consistent with VPDPBUSD instruction
164
- if (VECTOR_BITSIZE >= 512 ) {
165
- dotProductBody512Bulk (q , count , scores );
166
- } else if (VECTOR_BITSIZE == 256 ) {
167
- dotProductBody256Bulk (q , count , scores );
168
- } else {
169
- // tricky: we don't have SPECIES_32, so we workaround with "overlapping read"
170
- dotProductBody128Bulk (q , count , scores );
171
- }
172
- } else {
173
- int7DotProductBulk (q , count , scores );
174
- }
175
- }
176
-
177
- private void dotProductBody512Bulk (byte [] q , int count , float [] scores ) throws IOException {
178
- int limit = BYTE_SPECIES_128 .loopBound (dimensions );
179
- for (int iter = 0 ; iter < count ; iter ++) {
180
- IntVector acc = IntVector .zero (INT_SPECIES_512 );
181
- long offset = in .getFilePointer ();
182
- int i = 0 ;
183
- for (; i < limit ; i += BYTE_SPECIES_128 .length ()) {
184
- ByteVector va8 = ByteVector .fromArray (BYTE_SPECIES_128 , q , i );
185
- ByteVector vb8 = ByteVector .fromMemorySegment (BYTE_SPECIES_128 , memorySegment , offset + i , LITTLE_ENDIAN );
186
-
187
- // 16-bit multiply: avoid AVX-512 heavy multiply on zmm
188
- Vector <Short > va16 = va8 .convertShape (B2S , SHORT_SPECIES_256 , 0 );
189
- Vector <Short > vb16 = vb8 .convertShape (B2S , SHORT_SPECIES_256 , 0 );
190
- Vector <Short > prod16 = va16 .mul (vb16 );
191
-
192
- // 32-bit add
193
- Vector <Integer > prod32 = prod16 .convertShape (S2I , INT_SPECIES_512 , 0 );
194
- acc = acc .add (prod32 );
195
- }
196
-
197
- in .seek (offset + limit ); // advance the input stream
198
- // reduce
199
- long res = acc .reduceLanes (ADD );
200
- for (; i < dimensions ; i ++) {
201
- res += in .readByte () * q [i ];
202
- }
203
- scores [iter ] = res ;
204
- }
205
- }
206
-
207
- private void dotProductBody256Bulk (byte [] q , int count , float [] scores ) throws IOException {
208
- int limit = BYTE_SPECIES_128 .loopBound (dimensions );
209
- for (int iter = 0 ; iter < count ; iter ++) {
210
- IntVector acc = IntVector .zero (INT_SPECIES_256 );
211
- long offset = in .getFilePointer ();
212
- int i = 0 ;
213
- for (; i < limit ; i += BYTE_SPECIES_64 .length ()) {
214
- ByteVector va8 = ByteVector .fromArray (BYTE_SPECIES_64 , q , i );
215
- ByteVector vb8 = ByteVector .fromMemorySegment (BYTE_SPECIES_64 , memorySegment , offset + i , LITTLE_ENDIAN );
216
-
217
- // 32-bit multiply and add into accumulator
218
- Vector <Integer > va32 = va8 .convertShape (B2I , INT_SPECIES_256 , 0 );
219
- Vector <Integer > vb32 = vb8 .convertShape (B2I , INT_SPECIES_256 , 0 );
220
- acc = acc .add (va32 .mul (vb32 ));
221
- }
222
- in .seek (offset + limit );
223
- // reduce
224
- long res = acc .reduceLanes (ADD );
225
- for (; i < dimensions ; i ++) {
226
- res += in .readByte () * q [i ];
227
- }
228
- scores [iter ] = res ;
229
- }
230
- }
231
-
232
- private void dotProductBody128Bulk (byte [] q , int count , float [] scores ) throws IOException {
233
- int limit = BYTE_SPECIES_64 .loopBound (dimensions - BYTE_SPECIES_64 .length ());
234
- for (int iter = 0 ; iter < count ; iter ++) {
235
- IntVector acc = IntVector .zero (IntVector .SPECIES_128 );
236
- long offset = in .getFilePointer ();
237
- // 4 bytes at a time (re-loading half the vector each time!)
238
- int i = 0 ;
239
- for (; i < limit ; i += ByteVector .SPECIES_64 .length () >> 1 ) {
240
- // load 8 bytes
241
- ByteVector va8 = ByteVector .fromArray (BYTE_SPECIES_64 , q , i );
242
- ByteVector vb8 = ByteVector .fromMemorySegment (BYTE_SPECIES_64 , memorySegment , offset + i , LITTLE_ENDIAN );
243
-
244
- // process first "half" only: 16-bit multiply
245
- Vector <Short > va16 = va8 .convert (B2S , 0 );
246
- Vector <Short > vb16 = vb8 .convert (B2S , 0 );
247
- Vector <Short > prod16 = va16 .mul (vb16 );
248
-
249
- // 32-bit add
250
- acc = acc .add (prod16 .convertShape (S2I , IntVector .SPECIES_128 , 0 ));
251
- }
252
- in .seek (offset + limit );
253
- // reduce
254
- long res = acc .reduceLanes (ADD );
255
- for (; i < dimensions ; i ++) {
256
- res += in .readByte () * q [i ];
257
- }
258
- scores [iter ] = res ;
259
- }
36
+ panamaInt7DotProductBulk (q , count , scores );
260
37
}
261
38
262
39
@ Override
@@ -281,72 +58,4 @@ public void scoreBulk(
281
58
scores
282
59
);
283
60
}
284
-
285
- private void applyCorrectionsBulk (
286
- float queryLowerInterval ,
287
- float queryUpperInterval ,
288
- int queryComponentSum ,
289
- float queryAdditionalCorrection ,
290
- VectorSimilarityFunction similarityFunction ,
291
- float centroidDp ,
292
- float [] scores
293
- ) throws IOException {
294
- int limit = FLOAT_SPECIES .loopBound (BULK_SIZE );
295
- int i = 0 ;
296
- long offset = in .getFilePointer ();
297
- float ay = queryLowerInterval ;
298
- float ly = (queryUpperInterval - ay ) * SEVEN_BIT_SCALE ;
299
- float y1 = queryComponentSum ;
300
- for (; i < limit ; i += FLOAT_SPECIES .length ()) {
301
- var ax = FloatVector .fromMemorySegment (FLOAT_SPECIES , memorySegment , offset + i * Float .BYTES , ByteOrder .LITTLE_ENDIAN );
302
- var lx = FloatVector .fromMemorySegment (
303
- FLOAT_SPECIES ,
304
- memorySegment ,
305
- offset + 4 * BULK_SIZE + i * Float .BYTES ,
306
- ByteOrder .LITTLE_ENDIAN
307
- ).sub (ax ).mul (SEVEN_BIT_SCALE );
308
- var targetComponentSums = IntVector .fromMemorySegment (
309
- INT_SPECIES ,
310
- memorySegment ,
311
- offset + 8 * BULK_SIZE + i * Integer .BYTES ,
312
- ByteOrder .LITTLE_ENDIAN
313
- ).convert (VectorOperators .I2F , 0 );
314
- var additionalCorrections = FloatVector .fromMemorySegment (
315
- FLOAT_SPECIES ,
316
- memorySegment ,
317
- offset + 12 * BULK_SIZE + i * Float .BYTES ,
318
- ByteOrder .LITTLE_ENDIAN
319
- );
320
- var qcDist = FloatVector .fromArray (FLOAT_SPECIES , scores , i );
321
- // ax * ay * dimensions + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly *
322
- // qcDist;
323
- var res1 = ax .mul (ay ).mul (dimensions );
324
- var res2 = lx .mul (ay ).mul (targetComponentSums );
325
- var res3 = ax .mul (ly ).mul (y1 );
326
- var res4 = lx .mul (ly ).mul (qcDist );
327
- var res = res1 .add (res2 ).add (res3 ).add (res4 );
328
- // For euclidean, we need to invert the score and apply the additional correction, which is
329
- // assumed to be the squared l2norm of the centroid centered vectors.
330
- if (similarityFunction == EUCLIDEAN ) {
331
- res = res .mul (-2 ).add (additionalCorrections ).add (queryAdditionalCorrection ).add (1f );
332
- res = FloatVector .broadcast (FLOAT_SPECIES , 1 ).div (res ).max (0 );
333
- res .intoArray (scores , i );
334
- } else {
335
- // For cosine and max inner product, we need to apply the additional correction, which is
336
- // assumed to be the non-centered dot-product between the vector and the centroid
337
- res = res .add (queryAdditionalCorrection ).add (additionalCorrections ).sub (centroidDp );
338
- if (similarityFunction == MAXIMUM_INNER_PRODUCT ) {
339
- res .intoArray (scores , i );
340
- // not sure how to do it better
341
- for (int j = 0 ; j < FLOAT_SPECIES .length (); j ++) {
342
- scores [i + j ] = VectorUtil .scaleMaxInnerProductScore (scores [i + j ]);
343
- }
344
- } else {
345
- res = res .add (1f ).mul (0.5f ).max (0 );
346
- res .intoArray (scores , i );
347
- }
348
- }
349
- }
350
- in .seek (offset + 16L * BULK_SIZE );
351
- }
352
61
}
0 commit comments