Skip to content

Commit 775b706

Browse files
committed
simplify
1 parent b7e25b4 commit 775b706

File tree

3 files changed

+87
-166
lines changed

3 files changed

+87
-166
lines changed

libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/MemorySegmentES92Int7VectorsScorer.java

Lines changed: 3 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -8,23 +8,14 @@
88
*/
99
package org.elasticsearch.simdvec.internal;
1010

11-
import jdk.incubator.vector.FloatVector;
12-
import jdk.incubator.vector.IntVector;
13-
import jdk.incubator.vector.VectorOperators;
14-
1511
import org.apache.lucene.index.VectorSimilarityFunction;
1612
import org.apache.lucene.store.IndexInput;
17-
import org.apache.lucene.util.VectorUtil;
1813

1914
import java.io.IOException;
2015
import java.lang.foreign.MemorySegment;
21-
import java.nio.ByteOrder;
22-
23-
import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
24-
import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
2516

2617
/** Panamized scorer for 7-bit quantized vectors stored as an {@link IndexInput}. **/
27-
public final class MemorySegmentES92Int7VectorsScorer extends MemorySegmentES92FallBackInt7VectorsScorer {
18+
public final class MemorySegmentES92Int7VectorsScorer extends MemorySegmentES92PanamaInt7VectorsScorer {
2819

2920
public MemorySegmentES92Int7VectorsScorer(IndexInput in, int dimensions, MemorySegment memorySegment) {
3021
super(in, dimensions, memorySegment);
@@ -37,12 +28,12 @@ public boolean hasNativeAccess() {
3728

3829
@Override
3930
public long int7DotProduct(byte[] q) throws IOException {
40-
return fallbackInt7DotProduct(q);
31+
return panamaInt7DotProduct(q);
4132
}
4233

4334
@Override
4435
public void int7DotProductBulk(byte[] q, int count, float[] scores) throws IOException {
45-
fallbackInt7DotProductBulk(q, count, scores);
36+
panamaInt7DotProductBulk(q, count, scores);
4637
}
4738

4839
@Override
@@ -67,72 +58,4 @@ public void scoreBulk(
6758
scores
6859
);
6960
}
70-
71-
private void applyCorrectionsBulk(
72-
float queryLowerInterval,
73-
float queryUpperInterval,
74-
int queryComponentSum,
75-
float queryAdditionalCorrection,
76-
VectorSimilarityFunction similarityFunction,
77-
float centroidDp,
78-
float[] scores
79-
) throws IOException {
80-
int limit = FLOAT_SPECIES.loopBound(BULK_SIZE);
81-
int i = 0;
82-
long offset = in.getFilePointer();
83-
float ay = queryLowerInterval;
84-
float ly = (queryUpperInterval - ay) * SEVEN_BIT_SCALE;
85-
float y1 = queryComponentSum;
86-
for (; i < limit; i += FLOAT_SPECIES.length()) {
87-
var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
88-
var lx = FloatVector.fromMemorySegment(
89-
FLOAT_SPECIES,
90-
memorySegment,
91-
offset + 4 * BULK_SIZE + i * Float.BYTES,
92-
ByteOrder.LITTLE_ENDIAN
93-
).sub(ax).mul(SEVEN_BIT_SCALE);
94-
var targetComponentSums = IntVector.fromMemorySegment(
95-
INT_SPECIES,
96-
memorySegment,
97-
offset + 8 * BULK_SIZE + i * Integer.BYTES,
98-
ByteOrder.LITTLE_ENDIAN
99-
).convert(VectorOperators.I2F, 0);
100-
var additionalCorrections = FloatVector.fromMemorySegment(
101-
FLOAT_SPECIES,
102-
memorySegment,
103-
offset + 12 * BULK_SIZE + i * Float.BYTES,
104-
ByteOrder.LITTLE_ENDIAN
105-
);
106-
var qcDist = FloatVector.fromArray(FLOAT_SPECIES, scores, i);
107-
// ax * ay * dimensions + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly *
108-
// qcDist;
109-
var res1 = ax.mul(ay).mul(dimensions);
110-
var res2 = lx.mul(ay).mul(targetComponentSums);
111-
var res3 = ax.mul(ly).mul(y1);
112-
var res4 = lx.mul(ly).mul(qcDist);
113-
var res = res1.add(res2).add(res3).add(res4);
114-
// For euclidean, we need to invert the score and apply the additional correction, which is
115-
// assumed to be the squared l2norm of the centroid centered vectors.
116-
if (similarityFunction == EUCLIDEAN) {
117-
res = res.mul(-2).add(additionalCorrections).add(queryAdditionalCorrection).add(1f);
118-
res = FloatVector.broadcast(FLOAT_SPECIES, 1).div(res).max(0);
119-
res.intoArray(scores, i);
120-
} else {
121-
// For cosine and max inner product, we need to apply the additional correction, which is
122-
// assumed to be the non-centered dot-product between the vector and the centroid
123-
res = res.add(queryAdditionalCorrection).add(additionalCorrections).sub(centroidDp);
124-
if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
125-
res.intoArray(scores, i);
126-
// not sure how to do it better
127-
for (int j = 0; j < FLOAT_SPECIES.length(); j++) {
128-
scores[i + j] = VectorUtil.scaleMaxInnerProductScore(scores[i + j]);
129-
}
130-
} else {
131-
res = res.add(1f).mul(0.5f).max(0);
132-
res.intoArray(scores, i);
133-
}
134-
}
135-
}
136-
in.seek(offset + 16L * BULK_SIZE);
137-
}
13861
}
Lines changed: 81 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,26 +9,33 @@
99
package org.elasticsearch.simdvec.internal;
1010

1111
import jdk.incubator.vector.ByteVector;
12+
import jdk.incubator.vector.FloatVector;
1213
import jdk.incubator.vector.IntVector;
1314
import jdk.incubator.vector.ShortVector;
1415
import jdk.incubator.vector.Vector;
16+
import jdk.incubator.vector.VectorOperators;
1517
import jdk.incubator.vector.VectorShape;
1618
import jdk.incubator.vector.VectorSpecies;
1719

20+
import org.apache.lucene.index.VectorSimilarityFunction;
1821
import org.apache.lucene.store.IndexInput;
22+
import org.apache.lucene.util.VectorUtil;
1923
import org.elasticsearch.simdvec.ES92Int7VectorsScorer;
2024

2125
import java.io.IOException;
2226
import java.lang.foreign.MemorySegment;
27+
import java.nio.ByteOrder;
2328

2429
import static java.nio.ByteOrder.LITTLE_ENDIAN;
2530
import static jdk.incubator.vector.VectorOperators.ADD;
2631
import static jdk.incubator.vector.VectorOperators.B2I;
2732
import static jdk.incubator.vector.VectorOperators.B2S;
2833
import static jdk.incubator.vector.VectorOperators.S2I;
34+
import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
35+
import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
2936

3037
/** Panamized scorer for 7-bit quantized vectors stored as an {@link IndexInput}. **/
31-
abstract class MemorySegmentES92FallBackInt7VectorsScorer extends ES92Int7VectorsScorer {
38+
abstract class MemorySegmentES92PanamaInt7VectorsScorer extends ES92Int7VectorsScorer {
3239

3340
private static final VectorSpecies<Byte> BYTE_SPECIES_64 = ByteVector.SPECIES_64;
3441
private static final VectorSpecies<Byte> BYTE_SPECIES_128 = ByteVector.SPECIES_128;
@@ -41,8 +48,8 @@ abstract class MemorySegmentES92FallBackInt7VectorsScorer extends ES92Int7Vector
4148
private static final VectorSpecies<Integer> INT_SPECIES_512 = IntVector.SPECIES_512;
4249

4350
private static final int VECTOR_BITSIZE;
44-
protected static final VectorSpecies<Float> FLOAT_SPECIES;
45-
protected static final VectorSpecies<Integer> INT_SPECIES;
51+
private static final VectorSpecies<Float> FLOAT_SPECIES;
52+
private static final VectorSpecies<Integer> INT_SPECIES;
4653

4754
static {
4855
// default to platform supported bitsize
@@ -53,12 +60,12 @@ abstract class MemorySegmentES92FallBackInt7VectorsScorer extends ES92Int7Vector
5360

5461
protected final MemorySegment memorySegment;
5562

56-
public MemorySegmentES92FallBackInt7VectorsScorer(IndexInput in, int dimensions, MemorySegment memorySegment) {
63+
public MemorySegmentES92PanamaInt7VectorsScorer(IndexInput in, int dimensions, MemorySegment memorySegment) {
5764
super(in, dimensions);
5865
this.memorySegment = memorySegment;
5966
}
6067

61-
protected long fallbackInt7DotProduct(byte[] q) throws IOException {
68+
protected long panamaInt7DotProduct(byte[] q) throws IOException {
6269
assert dimensions == q.length;
6370
int i = 0;
6471
int res = 0;
@@ -147,7 +154,7 @@ private int dotProductBody128(byte[] q, int limit) throws IOException {
147154
return acc.reduceLanes(ADD);
148155
}
149156

150-
protected void fallbackInt7DotProductBulk(byte[] q, int count, float[] scores) throws IOException {
157+
protected void panamaInt7DotProductBulk(byte[] q, int count, float[] scores) throws IOException {
151158
assert dimensions == q.length;
152159
// only vectorize if we'll at least enter the loop a single time
153160
if (dimensions >= 16) {
@@ -249,4 +256,72 @@ private void dotProductBody128Bulk(byte[] q, int count, float[] scores) throws I
249256
scores[iter] = res;
250257
}
251258
}
259+
260+
/**
 * Applies the per-vector quantization corrections to a bulk of raw scores, in place.
 *
 * <p>On entry {@code scores[i]} holds the value loaded below as {@code qcDist}; on exit it
 * holds the corrected score for the given similarity function. The per-vector correction terms
 * are read directly from {@code memorySegment}, starting at the current file pointer of
 * {@code in}, as four consecutive little-endian arrays of {@code BULK_SIZE} 4-byte entries:
 * floats at {@code offset}, floats at {@code offset + 4 * BULK_SIZE}, ints at
 * {@code offset + 8 * BULK_SIZE} and floats at {@code offset + 12 * BULK_SIZE}. Afterwards
 * {@code in} is positioned just past those {@code 16 * BULK_SIZE} bytes.
 *
 * <p>NOTE(review): only the first {@code FLOAT_SPECIES.loopBound(BULK_SIZE)} entries are
 * processed and there is no scalar tail loop here - presumably BULK_SIZE is a multiple of the
 * vector lane count; confirm against the definition of BULK_SIZE.
 *
 * @param queryLowerInterval lower interval of the query; used as the scalar {@code ay}
 * @param queryUpperInterval upper interval of the query; {@code (upper - lower) * SEVEN_BIT_SCALE} gives {@code ly}
 * @param queryComponentSum component sum of the query; used as the scalar {@code y1}
 * @param queryAdditionalCorrection additional correction of the query, added to every score
 * @param similarityFunction selects which final transformation is applied to the score
 * @param centroidDp value subtracted from the score in the non-euclidean branches
 * @param scores input raw values and output corrected scores
 * @throws IOException if querying the position of, or seeking, {@code in} fails
 */
protected void applyCorrectionsBulk(
261+
float queryLowerInterval,
262+
float queryUpperInterval,
263+
int queryComponentSum,
264+
float queryAdditionalCorrection,
265+
VectorSimilarityFunction similarityFunction,
266+
float centroidDp,
267+
float[] scores
268+
) throws IOException {
269+
// Largest multiple of the vector lane count that fits in BULK_SIZE.
int limit = FLOAT_SPECIES.loopBound(BULK_SIZE);
270+
int i = 0;
271+
// The corrections are read from the memory segment at the input's current position.
long offset = in.getFilePointer();
272+
// Query-side scalars, computed once and reused for every lane.
float ay = queryLowerInterval;
273+
float ly = (queryUpperInterval - ay) * SEVEN_BIT_SCALE;
274+
float y1 = queryComponentSum;
275+
for (; i < limit; i += FLOAT_SPECIES.length()) {
276+
// First array: one float per vector, starting at offset.
var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
277+
// Second array: (value - ax) * SEVEN_BIT_SCALE, mirroring how ly is derived from the query intervals.
var lx = FloatVector.fromMemorySegment(
278+
FLOAT_SPECIES,
279+
memorySegment,
280+
offset + 4 * BULK_SIZE + i * Float.BYTES,
281+
ByteOrder.LITTLE_ENDIAN
282+
).sub(ax).mul(SEVEN_BIT_SCALE);
283+
// Third array: ints, converted lane-wise to float so they can be multiplied with the float vectors.
var targetComponentSums = IntVector.fromMemorySegment(
284+
INT_SPECIES,
285+
memorySegment,
286+
offset + 8 * BULK_SIZE + i * Integer.BYTES,
287+
ByteOrder.LITTLE_ENDIAN
288+
).convert(VectorOperators.I2F, 0);
289+
// Fourth array: one float per vector.
var additionalCorrections = FloatVector.fromMemorySegment(
290+
FLOAT_SPECIES,
291+
memorySegment,
292+
offset + 12 * BULK_SIZE + i * Float.BYTES,
293+
ByteOrder.LITTLE_ENDIAN
294+
);
295+
// The raw values to correct are taken from the scores array itself.
var qcDist = FloatVector.fromArray(FLOAT_SPECIES, scores, i);
296+
// ax * ay * dimensions + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly *
297+
// qcDist;
298+
var res1 = ax.mul(ay).mul(dimensions);
299+
var res2 = lx.mul(ay).mul(targetComponentSums);
300+
var res3 = ax.mul(ly).mul(y1);
301+
var res4 = lx.mul(ly).mul(qcDist);
302+
var res = res1.add(res2).add(res3).add(res4);
303+
// For euclidean, we need to invert the score and apply the additional correction, which is
304+
// assumed to be the squared l2norm of the centroid centered vectors.
305+
if (similarityFunction == EUCLIDEAN) {
306+
res = res.mul(-2).add(additionalCorrections).add(queryAdditionalCorrection).add(1f);
307+
// score = max(1 / res, 0)
res = FloatVector.broadcast(FLOAT_SPECIES, 1).div(res).max(0);
308+
res.intoArray(scores, i);
309+
} else {
310+
// For cosine and max inner product, we need to apply the additional correction, which is
311+
// assumed to be the non-centered dot-product between the vector and the centroid
312+
res = res.add(queryAdditionalCorrection).add(additionalCorrections).sub(centroidDp);
313+
if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
314+
res.intoArray(scores, i);
315+
// not sure how to do it better
// The scaling is applied per element with the scalar VectorUtil helper after the vector store.
316+
for (int j = 0; j < FLOAT_SPECIES.length(); j++) {
317+
scores[i + j] = VectorUtil.scaleMaxInnerProductScore(scores[i + j]);
318+
}
319+
} else {
320+
// score = max((res + 1) / 2, 0)
res = res.add(1f).mul(0.5f).max(0);
321+
res.intoArray(scores, i);
322+
}
323+
}
324+
}
325+
// Advance the input past the four BULK_SIZE-entry correction arrays (4 bytes per entry).
in.seek(offset + 16L * BULK_SIZE);
326+
}
252327
}

libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/MemorySegmentES92Int7VectorsScorer.java

Lines changed: 3 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -8,24 +8,15 @@
88
*/
99
package org.elasticsearch.simdvec.internal;
1010

11-
import jdk.incubator.vector.FloatVector;
12-
import jdk.incubator.vector.IntVector;
13-
import jdk.incubator.vector.VectorOperators;
14-
1511
import org.apache.lucene.index.VectorSimilarityFunction;
1612
import org.apache.lucene.store.IndexInput;
17-
import org.apache.lucene.util.VectorUtil;
1813
import org.elasticsearch.nativeaccess.NativeAccess;
1914

2015
import java.io.IOException;
2116
import java.lang.foreign.MemorySegment;
22-
import java.nio.ByteOrder;
23-
24-
import static org.apache.lucene.index.VectorSimilarityFunction.EUCLIDEAN;
25-
import static org.apache.lucene.index.VectorSimilarityFunction.MAXIMUM_INNER_PRODUCT;
2617

2718
/** Native / panamized scorer for 7-bit quantized vectors stored as an {@link IndexInput}. **/
28-
public final class MemorySegmentES92Int7VectorsScorer extends MemorySegmentES92FallBackInt7VectorsScorer {
19+
public final class MemorySegmentES92Int7VectorsScorer extends MemorySegmentES92PanamaInt7VectorsScorer {
2920

3021
private static final boolean NATIVE_SUPPORTED = NativeAccess.instance().getVectorSimilarityFunctions().isPresent();
3122

@@ -44,7 +35,7 @@ public long int7DotProduct(byte[] q) throws IOException {
4435
if (NATIVE_SUPPORTED) {
4536
return nativeInt7DotProduct(q);
4637
} else {
47-
return fallbackInt7DotProduct(q);
38+
return panamaInt7DotProduct(q);
4839
}
4940

5041
}
@@ -66,7 +57,7 @@ public void int7DotProductBulk(byte[] q, int count, float[] scores) throws IOExc
6657
scores[i] = nativeInt7DotProduct(q);
6758
}
6859
} else {
69-
fallbackInt7DotProductBulk(q, count, scores);
60+
panamaInt7DotProductBulk(q, count, scores);
7061
}
7162
}
7263

@@ -92,72 +83,4 @@ public void scoreBulk(
9283
scores
9384
);
9485
}
95-
96-
private void applyCorrectionsBulk(
97-
float queryLowerInterval,
98-
float queryUpperInterval,
99-
int queryComponentSum,
100-
float queryAdditionalCorrection,
101-
VectorSimilarityFunction similarityFunction,
102-
float centroidDp,
103-
float[] scores
104-
) throws IOException {
105-
int limit = FLOAT_SPECIES.loopBound(BULK_SIZE);
106-
int i = 0;
107-
long offset = in.getFilePointer();
108-
float ay = queryLowerInterval;
109-
float ly = (queryUpperInterval - ay) * SEVEN_BIT_SCALE;
110-
float y1 = queryComponentSum;
111-
for (; i < limit; i += FLOAT_SPECIES.length()) {
112-
var ax = FloatVector.fromMemorySegment(FLOAT_SPECIES, memorySegment, offset + i * Float.BYTES, ByteOrder.LITTLE_ENDIAN);
113-
var lx = FloatVector.fromMemorySegment(
114-
FLOAT_SPECIES,
115-
memorySegment,
116-
offset + 4 * BULK_SIZE + i * Float.BYTES,
117-
ByteOrder.LITTLE_ENDIAN
118-
).sub(ax).mul(SEVEN_BIT_SCALE);
119-
var targetComponentSums = IntVector.fromMemorySegment(
120-
INT_SPECIES,
121-
memorySegment,
122-
offset + 8 * BULK_SIZE + i * Integer.BYTES,
123-
ByteOrder.LITTLE_ENDIAN
124-
).convert(VectorOperators.I2F, 0);
125-
var additionalCorrections = FloatVector.fromMemorySegment(
126-
FLOAT_SPECIES,
127-
memorySegment,
128-
offset + 12 * BULK_SIZE + i * Float.BYTES,
129-
ByteOrder.LITTLE_ENDIAN
130-
);
131-
var qcDist = FloatVector.fromArray(FLOAT_SPECIES, scores, i);
132-
// ax * ay * dimensions + ay * lx * (float) targetComponentSum + ax * ly * y1 + lx * ly *
133-
// qcDist;
134-
var res1 = ax.mul(ay).mul(dimensions);
135-
var res2 = lx.mul(ay).mul(targetComponentSums);
136-
var res3 = ax.mul(ly).mul(y1);
137-
var res4 = lx.mul(ly).mul(qcDist);
138-
var res = res1.add(res2).add(res3).add(res4);
139-
// For euclidean, we need to invert the score and apply the additional correction, which is
140-
// assumed to be the squared l2norm of the centroid centered vectors.
141-
if (similarityFunction == EUCLIDEAN) {
142-
res = res.mul(-2).add(additionalCorrections).add(queryAdditionalCorrection).add(1f);
143-
res = FloatVector.broadcast(FLOAT_SPECIES, 1).div(res).max(0);
144-
res.intoArray(scores, i);
145-
} else {
146-
// For cosine and max inner product, we need to apply the additional correction, which is
147-
// assumed to be the non-centered dot-product between the vector and the centroid
148-
res = res.add(queryAdditionalCorrection).add(additionalCorrections).sub(centroidDp);
149-
if (similarityFunction == MAXIMUM_INNER_PRODUCT) {
150-
res.intoArray(scores, i);
151-
// not sure how to do it better
152-
for (int j = 0; j < FLOAT_SPECIES.length(); j++) {
153-
scores[i + j] = VectorUtil.scaleMaxInnerProductScore(scores[i + j]);
154-
}
155-
} else {
156-
res = res.add(1f).mul(0.5f).max(0);
157-
res.intoArray(scores, i);
158-
}
159-
}
160-
}
161-
in.seek(offset + 16L * BULK_SIZE);
162-
}
16386
}

0 commit comments

Comments
 (0)