Skip to content

Commit b7e25b4

Browse files
committed
Use panamized version for windows in Int7VectorScorer
1 parent 7d1f135 commit b7e25b4

File tree

5 files changed

+308
-241
lines changed

5 files changed

+308
-241
lines changed

libs/simdvec/src/main/java/org/elasticsearch/simdvec/ES92Int7VectorsScorer.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,13 @@ public ES92Int7VectorsScorer(IndexInput in, int dimensions) {
4141
this.dimensions = dimensions;
4242
}
4343

44+
/**
45+
* Checks if the current implementation supports fast native access.
46+
*/
47+
public boolean hasNativeAccess() {
48+
return false; // This class does not support native access
49+
}
50+
4451
/**
4552
* compute the quantize distance between the provided quantized query and the quantized vector
4653
* that is read from the wrapped {@link IndexInput}.
Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the "Elastic License
4+
* 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side
5+
* Public License v 1"; you may not use this file except in compliance with, at
6+
* your election, the "Elastic License 2.0", the "GNU Affero General Public
7+
* License v3.0 only", or the "Server Side Public License, v 1".
8+
*/
9+
package org.elasticsearch.simdvec.internal;
10+
11+
import jdk.incubator.vector.ByteVector;
12+
import jdk.incubator.vector.IntVector;
13+
import jdk.incubator.vector.ShortVector;
14+
import jdk.incubator.vector.Vector;
15+
import jdk.incubator.vector.VectorShape;
16+
import jdk.incubator.vector.VectorSpecies;
17+
18+
import org.apache.lucene.store.IndexInput;
19+
import org.elasticsearch.simdvec.ES92Int7VectorsScorer;
20+
21+
import java.io.IOException;
22+
import java.lang.foreign.MemorySegment;
23+
24+
import static java.nio.ByteOrder.LITTLE_ENDIAN;
25+
import static jdk.incubator.vector.VectorOperators.ADD;
26+
import static jdk.incubator.vector.VectorOperators.B2I;
27+
import static jdk.incubator.vector.VectorOperators.B2S;
28+
import static jdk.incubator.vector.VectorOperators.S2I;
29+
30+
/** Panamized scorer for 7-bit quantized vectors stored as an {@link IndexInput}. **/
31+
abstract class MemorySegmentES92FallBackInt7VectorsScorer extends ES92Int7VectorsScorer {
32+
33+
private static final VectorSpecies<Byte> BYTE_SPECIES_64 = ByteVector.SPECIES_64;
34+
private static final VectorSpecies<Byte> BYTE_SPECIES_128 = ByteVector.SPECIES_128;
35+
36+
private static final VectorSpecies<Short> SHORT_SPECIES_128 = ShortVector.SPECIES_128;
37+
private static final VectorSpecies<Short> SHORT_SPECIES_256 = ShortVector.SPECIES_256;
38+
39+
private static final VectorSpecies<Integer> INT_SPECIES_128 = IntVector.SPECIES_128;
40+
private static final VectorSpecies<Integer> INT_SPECIES_256 = IntVector.SPECIES_256;
41+
private static final VectorSpecies<Integer> INT_SPECIES_512 = IntVector.SPECIES_512;
42+
43+
private static final int VECTOR_BITSIZE;
44+
protected static final VectorSpecies<Float> FLOAT_SPECIES;
45+
protected static final VectorSpecies<Integer> INT_SPECIES;
46+
47+
static {
48+
// default to platform supported bitsize
49+
VECTOR_BITSIZE = VectorShape.preferredShape().vectorBitSize();
50+
FLOAT_SPECIES = VectorSpecies.of(float.class, VectorShape.forBitSize(VECTOR_BITSIZE));
51+
INT_SPECIES = VectorSpecies.of(int.class, VectorShape.forBitSize(VECTOR_BITSIZE));
52+
}
53+
54+
protected final MemorySegment memorySegment;
55+
56+
public MemorySegmentES92FallBackInt7VectorsScorer(IndexInput in, int dimensions, MemorySegment memorySegment) {
57+
super(in, dimensions);
58+
this.memorySegment = memorySegment;
59+
}
60+
61+
protected long fallbackInt7DotProduct(byte[] q) throws IOException {
62+
assert dimensions == q.length;
63+
int i = 0;
64+
int res = 0;
65+
// only vectorize if we'll at least enter the loop a single time
66+
if (dimensions >= 16) {
67+
// compute vectorized dot product consistent with VPDPBUSD instruction
68+
if (VECTOR_BITSIZE >= 512) {
69+
i += BYTE_SPECIES_128.loopBound(dimensions);
70+
res += dotProductBody512(q, i);
71+
} else if (VECTOR_BITSIZE == 256) {
72+
i += BYTE_SPECIES_64.loopBound(dimensions);
73+
res += dotProductBody256(q, i);
74+
} else {
75+
// tricky: we don't have SPECIES_32, so we workaround with "overlapping read"
76+
i += BYTE_SPECIES_64.loopBound(dimensions - BYTE_SPECIES_64.length());
77+
res += dotProductBody128(q, i);
78+
}
79+
// scalar tail
80+
while (i < dimensions) {
81+
res += in.readByte() * q[i++];
82+
}
83+
return res;
84+
} else {
85+
return super.int7DotProduct(q);
86+
}
87+
}
88+
89+
private int dotProductBody512(byte[] q, int limit) throws IOException {
90+
IntVector acc = IntVector.zero(INT_SPECIES_512);
91+
long offset = in.getFilePointer();
92+
for (int i = 0; i < limit; i += BYTE_SPECIES_128.length()) {
93+
ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_128, q, i);
94+
ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_128, memorySegment, offset + i, LITTLE_ENDIAN);
95+
96+
// 16-bit multiply: avoid AVX-512 heavy multiply on zmm
97+
Vector<Short> va16 = va8.convertShape(B2S, SHORT_SPECIES_256, 0);
98+
Vector<Short> vb16 = vb8.convertShape(B2S, SHORT_SPECIES_256, 0);
99+
Vector<Short> prod16 = va16.mul(vb16);
100+
101+
// 32-bit add
102+
Vector<Integer> prod32 = prod16.convertShape(S2I, INT_SPECIES_512, 0);
103+
acc = acc.add(prod32);
104+
}
105+
106+
in.seek(offset + limit); // advance the input stream
107+
// reduce
108+
return acc.reduceLanes(ADD);
109+
}
110+
111+
private int dotProductBody256(byte[] q, int limit) throws IOException {
112+
IntVector acc = IntVector.zero(INT_SPECIES_256);
113+
long offset = in.getFilePointer();
114+
for (int i = 0; i < limit; i += BYTE_SPECIES_64.length()) {
115+
ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i);
116+
ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i, LITTLE_ENDIAN);
117+
118+
// 32-bit multiply and add into accumulator
119+
Vector<Integer> va32 = va8.convertShape(B2I, INT_SPECIES_256, 0);
120+
Vector<Integer> vb32 = vb8.convertShape(B2I, INT_SPECIES_256, 0);
121+
acc = acc.add(va32.mul(vb32));
122+
}
123+
in.seek(offset + limit);
124+
// reduce
125+
return acc.reduceLanes(ADD);
126+
}
127+
128+
private int dotProductBody128(byte[] q, int limit) throws IOException {
129+
IntVector acc = IntVector.zero(IntVector.SPECIES_128);
130+
long offset = in.getFilePointer();
131+
// 4 bytes at a time (re-loading half the vector each time!)
132+
for (int i = 0; i < limit; i += ByteVector.SPECIES_64.length() >> 1) {
133+
// load 8 bytes
134+
ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i);
135+
ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i, LITTLE_ENDIAN);
136+
137+
// process first "half" only: 16-bit multiply
138+
Vector<Short> va16 = va8.convert(B2S, 0);
139+
Vector<Short> vb16 = vb8.convert(B2S, 0);
140+
Vector<Short> prod16 = va16.mul(vb16);
141+
142+
// 32-bit add
143+
acc = acc.add(prod16.convertShape(S2I, IntVector.SPECIES_128, 0));
144+
}
145+
in.seek(offset + limit);
146+
// reduce
147+
return acc.reduceLanes(ADD);
148+
}
149+
150+
protected void fallbackInt7DotProductBulk(byte[] q, int count, float[] scores) throws IOException {
151+
assert dimensions == q.length;
152+
// only vectorize if we'll at least enter the loop a single time
153+
if (dimensions >= 16) {
154+
// compute vectorized dot product consistent with VPDPBUSD instruction
155+
if (VECTOR_BITSIZE >= 512) {
156+
dotProductBody512Bulk(q, count, scores);
157+
} else if (VECTOR_BITSIZE == 256) {
158+
dotProductBody256Bulk(q, count, scores);
159+
} else {
160+
// tricky: we don't have SPECIES_32, so we workaround with "overlapping read"
161+
dotProductBody128Bulk(q, count, scores);
162+
}
163+
} else {
164+
int7DotProductBulk(q, count, scores);
165+
}
166+
}
167+
168+
private void dotProductBody512Bulk(byte[] q, int count, float[] scores) throws IOException {
169+
int limit = BYTE_SPECIES_128.loopBound(dimensions);
170+
for (int iter = 0; iter < count; iter++) {
171+
IntVector acc = IntVector.zero(INT_SPECIES_512);
172+
long offset = in.getFilePointer();
173+
int i = 0;
174+
for (; i < limit; i += BYTE_SPECIES_128.length()) {
175+
ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_128, q, i);
176+
ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_128, memorySegment, offset + i, LITTLE_ENDIAN);
177+
178+
// 16-bit multiply: avoid AVX-512 heavy multiply on zmm
179+
Vector<Short> va16 = va8.convertShape(B2S, SHORT_SPECIES_256, 0);
180+
Vector<Short> vb16 = vb8.convertShape(B2S, SHORT_SPECIES_256, 0);
181+
Vector<Short> prod16 = va16.mul(vb16);
182+
183+
// 32-bit add
184+
Vector<Integer> prod32 = prod16.convertShape(S2I, INT_SPECIES_512, 0);
185+
acc = acc.add(prod32);
186+
}
187+
188+
in.seek(offset + limit); // advance the input stream
189+
// reduce
190+
long res = acc.reduceLanes(ADD);
191+
for (; i < dimensions; i++) {
192+
res += in.readByte() * q[i];
193+
}
194+
scores[iter] = res;
195+
}
196+
}
197+
198+
private void dotProductBody256Bulk(byte[] q, int count, float[] scores) throws IOException {
199+
int limit = BYTE_SPECIES_128.loopBound(dimensions);
200+
for (int iter = 0; iter < count; iter++) {
201+
IntVector acc = IntVector.zero(INT_SPECIES_256);
202+
long offset = in.getFilePointer();
203+
int i = 0;
204+
for (; i < limit; i += BYTE_SPECIES_64.length()) {
205+
ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i);
206+
ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i, LITTLE_ENDIAN);
207+
208+
// 32-bit multiply and add into accumulator
209+
Vector<Integer> va32 = va8.convertShape(B2I, INT_SPECIES_256, 0);
210+
Vector<Integer> vb32 = vb8.convertShape(B2I, INT_SPECIES_256, 0);
211+
acc = acc.add(va32.mul(vb32));
212+
}
213+
in.seek(offset + limit);
214+
// reduce
215+
long res = acc.reduceLanes(ADD);
216+
for (; i < dimensions; i++) {
217+
res += in.readByte() * q[i];
218+
}
219+
scores[iter] = res;
220+
}
221+
}
222+
223+
private void dotProductBody128Bulk(byte[] q, int count, float[] scores) throws IOException {
224+
int limit = BYTE_SPECIES_64.loopBound(dimensions - BYTE_SPECIES_64.length());
225+
for (int iter = 0; iter < count; iter++) {
226+
IntVector acc = IntVector.zero(IntVector.SPECIES_128);
227+
long offset = in.getFilePointer();
228+
// 4 bytes at a time (re-loading half the vector each time!)
229+
int i = 0;
230+
for (; i < limit; i += ByteVector.SPECIES_64.length() >> 1) {
231+
// load 8 bytes
232+
ByteVector va8 = ByteVector.fromArray(BYTE_SPECIES_64, q, i);
233+
ByteVector vb8 = ByteVector.fromMemorySegment(BYTE_SPECIES_64, memorySegment, offset + i, LITTLE_ENDIAN);
234+
235+
// process first "half" only: 16-bit multiply
236+
Vector<Short> va16 = va8.convert(B2S, 0);
237+
Vector<Short> vb16 = vb8.convert(B2S, 0);
238+
Vector<Short> prod16 = va16.mul(vb16);
239+
240+
// 32-bit add
241+
acc = acc.add(prod16.convertShape(S2I, IntVector.SPECIES_128, 0));
242+
}
243+
in.seek(offset + limit);
244+
// reduce
245+
long res = acc.reduceLanes(ADD);
246+
for (; i < dimensions; i++) {
247+
res += in.readByte() * q[i];
248+
}
249+
scores[iter] = res;
250+
}
251+
}
252+
}

0 commit comments

Comments
 (0)