Skip to content

Commit 6ca0466

Browse files
authored
Use panamized version for windows in Int7VectorScorer (#132311)
1 parent 7d1f135 commit 6ca0466

File tree

5 files changed

+383
-395
lines changed

5 files changed

+383
-395
lines changed

libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES92Int7VectorsScorer.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,13 @@ public ES92Int7VectorsScorer(IndexInput in, int dimensions) {
4141
this.dimensions = dimensions;
4242
}
4343

44+
/**
45+
* Checks if the current implementation supports fast native access.
46+
*/
47+
public boolean hasNativeAccess() {
48+
return false; // This class does not support native access
49+
}
50+
4451
/**
4552
* compute the quantize distance between the provided quantized query and the quantized vector
4653
* that is read from the wrapped {@link IndexInput}.

libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/MemorySegmentES92Int7VectorsScorer.java

Lines changed: 8 additions & 299 deletions
Original file line numberDiff line numberDiff line change
@@ -8,255 +8,32 @@
88
*/
99
package org.elasticsearch.simdvec.internal;
1010

11-
import jdk.incubator.vector.ByteVector;
12-
import jdk.incubator.vector.FloatVector;
13-
import jdk.incubator.vector.IntVector;
14-
import jdk.incubator.vector.ShortVector;
15-
import jdk.incubator.vector.Vector;
16-
import jdk.incubator.vector.VectorOperators;
17-
import jdk.incubator.vector.VectorShape;
18-
import jdk.incubator.vector.VectorSpecies;
19-
2011
import org.apache.lucene.index.VectorSimilarityFunction;
2112
import org.apache.lucene.store.IndexInput;
22-
import org.apache.lucene.util.VectorUtil;
23-
import org.elasticsearch.simdvec.ES92Int7VectorsScorer;
2413

2514
import java.io.IOException;
2615
import java.lang.foreign.MemorySegment;
27-
import java.nio.ByteOrder;
28-
29-
import static java.nio.ByteOrder.LITTLE_ENDIAN;
30-
import static jdk.incubator.vector.VectorOperators.ADD;
31-
import static jdk.incubator.vector.VectorOperators.B2I;
32-
import static jdk.incubator.vector.VectorOperators.B2S;
33-
import static jdk.incubator.vector.VectorOperators.S2I;
34-
import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
35-
import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
3616

3717
/** Panamized scorer for 7-bit quantized vectors stored as an {@link IndexInput}. **/
38-
public final class MemorySegmentES92Int7VectorsScorer extends ES92Int7VectorsScorer {
39-
40-
private static final VectorSpecies<Byte> BYTE_SPECIES_64 = ByteVector.SPECIES_64;
41-
private static final VectorSpecies<Byte> BYTE_SPECIES_128 = ByteVector.SPECIES_128;
42-
43-
private static final VectorSpecies<Short> SHORT_SPECIES_128 = ShortVector.SPECIES_128;
44-
private static final VectorSpecies<Short> SHORT_SPECIES_256 = ShortVector.SPECIES_256;
45-
46-
private static final VectorSpecies<Integer> INT_SPECIES_128 = IntVector.SPECIES_128;
47-
private static final VectorSpecies<Integer> INT_SPECIES_256 = IntVector.SPECIES_256;
48-
private static final VectorSpecies<Integer> INT_SPECIES_512 = IntVector.SPECIES_512;
49-
50-
private static final int VECTOR_BITSIZE;
51-
private static final VectorSpecies<Float> FLOAT_SPECIES;
52-
private static final VectorSpecies<Integer> INT_SPECIES;
53-
54-
static {
55-
// default to platform supported bitsize
56-
VECTOR_BITSIZE = VectorShape.preferredShape().vectorBitSize();
57-
FLOAT_SPECIES = VectorSpecies.of(float.class, VectorShape.forBitSize(VECTOR_BITSIZE));
58-
INT_SPECIES = VectorSpecies.of(int.class, VectorShape.forBitSize(VECTOR_BITSIZE));
59-
}
60-
61-
private final MemorySegment memorySegment;
18+
public final class MemorySegmentES92Int7VectorsScorer extends MemorySegmentES92PanamaInt7VectorsScorer {
6219

6320
public MemorySegmentES92Int7VectorsScorer(IndexInput in, int dimensions, MemorySegment memorySegment) {
64-
super(in, dimensions);
65-
this.memorySegment = memorySegment;
21+
super(in, dimensions, memorySegment);
6622
}
6723

6824
@Override
69-
public long int7DotProduct(byte[] q) throws IOException {
70-
assert dimensions == q.length;
71-
int i = 0;
72-
int res = 0;
73-
// only vectorize if we'll at least enter the loop a single time
74-
if (dimensions >= 16) {
75-
// compute vectorized dot product consistent with VPDPBUSD instruction
76-
if (VECTOR_BITSIZE >= 512) {
77-
i += BYTE_SPECIES_128.loopBound(dimensions);
78-
res += dotProductBody512(q, i);
79-
} else if (VECTOR_BITSIZE == 256) {
80-
i += BYTE_SPECIES_64.loopBound(dimensions);
81-
res += dotProductBody256(q, i);
82-
} else {
83-
// tricky: we don't have SPECIES_32, so we workaround with "overlapping read"
84-
i += BYTE_SPECIES_64.loopBound(dimensions - BYTE_SPECIES_64.length());
85-
res += dotProductBody128(q, i);
86-
}
87-
// scalar tail
88-
while (i < dimensions) {
89-
res += in.readByte() * q[i++];
90-
}
91-
return res;
92-
} else {
93-
return super.int7DotProduct(q);
94-
}
95-
}
96-
97-
private int dotProductBody512(byte[] q, int limit) throws IOException {
98-
IntVector acc = IntVector.zero(INT_SPECIES_512);
99-
long offset = in.getFilePointer();
100-
for (int i = 0; i < limit; i += BYTE_SPECIES_128.length()) {
101-
ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_128, q, i);
102-
ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_128, memorySegment, offset + i, LITTLE_ENDIAN);
103-
104-
// 16-bit multiply: avoid AVX-512 heavy multiply on zmm
105-
Vector<Short> va16 = va8.convertShape(B2S, SHORT_SPECIES_256, 0);
106-
Vector<Short> vb16 = vb8.convertShape(B2S, SHORT_SPECIES_256, 0);
107-
Vector<Short> prod16 = va16.mul(vb16);
108-
109-
// 32-bit add
110-
Vector<Integer> prod32 = prod16.convertShape(S2I, INT_SPECIES_512, 0);
111-
acc = acc.add(prod32);
112-
}
113-
114-
in.seek(offset + limit); // advance the input stream
115-
// reduce
116-
return acc.reduceLanes(ADD);
25+
public boolean hasNativeAccess() {
26+
return false; // This class does not support native access
11727
}
11828

119-
private int dotProductBody256(byte[] q, int limit) throws IOException {
120-
IntVector acc = IntVector.zero(INT_SPECIES_256);
121-
long offset = in.getFilePointer();
122-
for (int i = 0; i < limit; i += BYTE_SPECIES_64.length()) {
123-
ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i);
124-
ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i, LITTLE_ENDIAN);
125-
126-
// 32-bit multiply and add into accumulator
127-
Vector<Integer> va32 = va8.convertShape(B2I, INT_SPECIES_256, 0);
128-
Vector<Integer> vb32 = vb8.convertShape(B2I, INT_SPECIES_256, 0);
129-
acc = acc.add(va32.mul(vb32));
130-
}
131-
in.seek(offset + limit);
132-
// reduce
133-
return acc.reduceLanes(ADD);
134-
}
135-
136-
private int dotProductBody128(byte[] q, int limit) throws IOException {
137-
IntVector acc = IntVector.zero(IntVector.SPECIES_128);
138-
long offset = in.getFilePointer();
139-
// 4 bytes at a time (re-loading half the vector each time!)
140-
for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length() >> 1) {
141-
// load 8 bytes
142-
ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i);
143-
ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i, LITTLE_ENDIAN);
144-
145-
// process first "half" only: 16-bit multiply
146-
Vector<Short> va16 = va8.convert(B2S, 0);
147-
Vector<Short> vb16 = vb8.convert(B2S, 0);
148-
Vector<Short> prod16 = va16.mul(vb16);
149-
150-
// 32-bit add
151-
acc = acc.add(prod16.convertShape(S2I, IntVector.SPECIES_128, 0));
152-
}
153-
in.seek(offset + limit);
154-
// reduce
155-
return acc.reduceLanes(ADD);
29+
@Override
30+
public long int7DotProduct(byte[] q) throws IOException {
31+
return panamaInt7DotProduct(q);
15632
}
15733

15834
@Override
15935
public void int7DotProductBulk(byte[] q, int count, float[] scores) throws IOException {
160-
assert dimensions == q.length;
161-
// only vectorize if we'll at least enter the loop a single time
162-
if (dimensions >= 16) {
163-
// compute vectorized dot product consistent with VPDPBUSD instruction
164-
if (VECTOR_BITSIZE >= 512) {
165-
dotProductBody512Bulk(q, count, scores);
166-
} else if (VECTOR_BITSIZE == 256) {
167-
dotProductBody256Bulk(q, count, scores);
168-
} else {
169-
// tricky: we don't have SPECIES_32, so we workaround with "overlapping read"
170-
dotProductBody128Bulk(q, count, scores);
171-
}
172-
} else {
173-
int7DotProductBulk(q, count, scores);
174-
}
175-
}
176-
177-
private void dotProductBody512Bulk(byte[] q, int count, float[] scores) throws IOException {
178-
int limit = BYTE_SPECIES_128.loopBound(dimensions);
179-
for (int iter = 0; iter < count; iter++) {
180-
IntVector acc = IntVector.zero(INT_SPECIES_512);
181-
long offset = in.getFilePointer();
182-
int i = 0;
183-
for (; i < limit; i += BYTE_SPECIES_128.length()) {
184-
ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_128, q, i);
185-
ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_128, memorySegment, offset + i, LITTLE_ENDIAN);
186-
187-
// 16-bit multiply: avoid AVX-512 heavy multiply on zmm
188-
Vector<Short> va16 = va8.convertShape(B2S, SHORT_SPECIES_256, 0);
189-
Vector<Short> vb16 = vb8.convertShape(B2S, SHORT_SPECIES_256, 0);
190-
Vector<Short> prod16 = va16.mul(vb16);
191-
192-
// 32-bit add
193-
Vector<Integer> prod32 = prod16.convertShape(S2I, INT_SPECIES_512, 0);
194-
acc = acc.add(prod32);
195-
}
196-
197-
in.seek(offset + limit); // advance the input stream
198-
// reduce
199-
long res = acc.reduceLanes(ADD);
200-
for (; i < dimensions; i++) {
201-
res += in.readByte() * q[i];
202-
}
203-
scores[iter] = res;
204-
}
205-
}
206-
207-
private void dotProductBody256Bulk(byte[] q, int count, float[] scores) throws IOException {
208-
int limit = BYTE_SPECIES_128.loopBound(dimensions);
209-
for (int iter = 0; iter < count; iter++) {
210-
IntVector acc = IntVector.zero(INT_SPECIES_256);
211-
long offset = in.getFilePointer();
212-
int i = 0;
213-
for (; i < limit; i += BYTE_SPECIES_64.length()) {
214-
ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i);
215-
ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i, LITTLE_ENDIAN);
216-
217-
// 32-bit multiply and add into accumulator
218-
Vector<Integer> va32 = va8.convertShape(B2I, INT_SPECIES_256, 0);
219-
Vector<Integer> vb32 = vb8.convertShape(B2I, INT_SPECIES_256, 0);
220-
acc = acc.add(va32.mul(vb32));
221-
}
222-
in.seek(offset + limit);
223-
// reduce
224-
long res = acc.reduceLanes(ADD);
225-
for (; i < dimensions; i++) {
226-
res += in.readByte() * q[i];
227-
}
228-
scores[iter] = res;
229-
}
230-
}
231-
232-
private void dotProductBody128Bulk(byte[] q, int count, float[] scores) throws IOException {
233-
int limit = BYTE_SPECIES_64.loopBound(dimensions - BYTE_SPECIES_64.length());
234-
for (int iter = 0; iter < count; iter++) {
235-
IntVector acc = IntVector.zero(IntVector.SPECIES_128);
236-
long offset = in.getFilePointer();
237-
// 4 bytes at a time (re-loading half the vector each time!)
238-
int i = 0;
239-
for (; i < limit; i += ByteVector.SPECIES_64.length() >> 1) {
240-
// load 8 bytes
241-
ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i);
242-
ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i, LITTLE_ENDIAN);
243-
244-
// process first "half" only: 16-bit multiply
245-
Vector<Short> va16 = va8.convert(B2S, 0);
246-
Vector<Short> vb16 = vb8.convert(B2S, 0);
247-
Vector<Short> prod16 = va16.mul(vb16);
248-
249-
// 32-bit add
250-
acc = acc.add(prod16.convertShape(S2I, IntVector.SPECIES_128, 0));
251-
}
252-
in.seek(offset + limit);
253-
// reduce
254-
long res = acc.reduceLanes(ADD);
255-
for (; i < dimensions; i++) {
256-
res += in.readByte() * q[i];
257-
}
258-
scores[iter] = res;
259-
}
36+
panamaInt7DotProductBulk(q, count, scores);
26037
}
26138

26239
@Override
@@ -281,72 +58,4 @@ public void scoreBulk(
28158
scores
28259
);
28360
}
284-
285-
private void applyCorrectionsBulk(
286-
float queryLowerInterval,
287-
float queryUpperInterval,
288-
int queryComponentSum,
289-
float queryAdditionalCorrection,
290-
VectorSimilarityFunction similarityFunction,
291-
float centroidDp,
292-
float[] scores
293-
) throws IOException {
294-
int limit = FLOAT_SPECIES.loopBound(BULK_SIZE);
295-
int i = 0;
296-
long offset = in.getFilePointer();
297-
float ay = queryLowerInterval;
298-
float ly = (queryUpperInterval - ay) * SEVEN_BIT_SCALE;
299-
float y1 = queryComponentSum;
300-
for (; i < limit; i += FLOAT_SPECIES.length()) {
301-
var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
302-
var lx = FloatVector.fromMemorySegment(
303-
FLOAT_SPECIES,
304-
memorySegment,
305-
offset + 4 * BULK_SIZE + i * Float.BYTES,
306-
ByteOrder.LITTLE_ENDIAN
307-
).sub(ax).mul(SEVEN_BIT_SCALE);
308-
var targetComponentSums = IntVector.fromMemorySegment(
309-
INT_SPECIES,
310-
memorySegment,
311-
offset + 8 * BULK_SIZE + i * Integer.BYTES,
312-
ByteOrder.LITTLE_ENDIAN
313-
).convert(VectorOperators.I2F, 0);
314-
var additionalCorrections = FloatVector.fromMemorySegment(
315-
FLOAT_SPECIES,
316-
memorySegment,
317-
offset + 12 * BULK_SIZE + i * Float.BYTES,
318-
ByteOrder.LITTLE_ENDIAN
319-
);
320-
var qcDist = FloatVector.fromArray(FLOAT_SPECIES, scores, i);
321-
// ax * ay * dimensions + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly *
322-
// qcDist;
323-
var res1 = ax.mul(ay).mul(dimensions);
324-
var res2 = lx.mul(ay).mul(targetComponentSums);
325-
var res3 = ax.mul(ly).mul(y1);
326-
var res4 = lx.mul(ly).mul(qcDist);
327-
var res = res1.add(res2).add(res3).add(res4);
328-
// For euclidean, we need to invert the score and apply the additional correction, which is
329-
// assumed to be the squared l2norm of the centroid centered vectors.
330-
if (similarityFunction == EUCLIDEAN) {
331-
res = res.mul(-2).add(additionalCorrections).add(queryAdditionalCorrection).add(1f);
332-
res = FloatVector.broadcast(FLOAT_SPECIES, 1).div(res).max(0);
333-
res.intoArray(scores, i);
334-
} else {
335-
// For cosine and max inner product, we need to apply the additional correction, which is
336-
// assumed to be the non-centered dot-product between the vector and the centroid
337-
res = res.add(queryAdditionalCorrection).add(additionalCorrections).sub(centroidDp);
338-
if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
339-
res.intoArray(scores, i);
340-
// not sure how to do it better
341-
for (int j = 0; j < FLOAT_SPECIES.length(); j++) {
342-
scores[i + j] = VectorUtil.scaleMaxInnerProductScore(scores[i + j]);
343-
}
344-
} else {
345-
res = res.add(1f).mul(0.5f).max(0);
346-
res.intoArray(scores, i);
347-
}
348-
}
349-
}
350-
in.seek(offset + 16L * BULK_SIZE);
351-
}
35261
}

0 commit comments

Comments
 (0)