Skip to content

Commit f232613

Browse files
committed
Add implementation for Java22 Int7SQVectorScorer
1 parent 451b924 commit f232613

File tree

3 files changed

+57
-34
lines changed

3 files changed

+57
-34
lines changed

benchmarks/src/main/java/org/elasticsearch/benchmark/vector/scorer/VectorScorerInt7uBulkBenchmark.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,10 +82,10 @@ public class VectorScorerInt7uBulkBenchmark {
8282

8383
// 128k is typically enough to not fit in L1 (core) cache for most processors;
8484
// 1.5M is typically enough to not fit in L2 (core) cache;
85-
// 40M is typically enough to not fit in L3 cache
86-
@Param({ "128000", "1500000", "30000000" })
85+
// 130M is enough to not fit in L3 cache
86+
@Param({ "128", "1500", "130000" })
8787
public int numVectors;
88-
public int numVectorsToScore = 20_000;
88+
public int numVectorsToScore;
8989

9090
Path path;
9191
Directory dir;
@@ -107,6 +107,7 @@ public class VectorScorerInt7uBulkBenchmark {
107107

108108
@Setup(Level.Trial)
109109
public void setup() throws IOException {
110+
numVectorsToScore = Math.min(numVectors, 20_000);
110111
factory = getScorerFactoryOrDie();
111112

112113
var random = ThreadLocalRandom.current();

libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Int7SQVectorScorerSupplier.java

Lines changed: 20 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,8 @@ protected final void checkOrdinal(int ord) {
6060
}
6161

6262
final void bulkScoreFromOrds(int firstOrd, int[] ordinals, float[] scores, int numNodes) throws IOException {
63-
// throw new UnsupportedOperationException("not implemented");
64-
65-
MemorySegment segment = input.segmentSliceOrNull(0, input.length());
66-
if (segment == null) {
63+
MemorySegment vectorsSeg = input.segmentSliceOrNull(0, input.length());
64+
if (vectorsSeg == null) {
6765
bulkFallbackScore(firstOrd, ordinals, scores, numNodes);
6866
} else {
6967
final int vectorLength = dims;
@@ -72,7 +70,7 @@ final void bulkScoreFromOrds(int firstOrd, int[] ordinals, float[] scores, int n
7270
var ordinalsSeg = MemorySegment.ofArray(ordinals);
7371
var scoresSeg = MemorySegment.ofArray(scores);
7472
bulkScoreFromSegment(
75-
segment,
73+
vectorsSeg,
7674
vectorLength,
7775
vectorPitch,
7876
firstOrd,
@@ -87,7 +85,7 @@ final void bulkScoreFromOrds(int firstOrd, int[] ordinals, float[] scores, int n
8785
MemorySegment.copy(ordinals, 0, ordinalsMemorySegment, ValueLayout.JAVA_INT, 0, numNodes);
8886

8987
bulkScoreFromSegment(
90-
segment,
88+
vectorsSeg,
9189
vectorLength,
9290
vectorPitch,
9391
firstOrd,
@@ -99,17 +97,6 @@ final void bulkScoreFromOrds(int firstOrd, int[] ordinals, float[] scores, int n
9997
MemorySegment.copy(scoresMemorySegment, ValueLayout.JAVA_FLOAT, 0, scores, 0, numNodes);
10098
}
10199
}
102-
103-
long firstByteOffset = (long) firstOrd * vectorPitch;
104-
var aOffset = Float.intBitsToFloat(input.readInt(firstByteOffset + vectorLength));
105-
for (int i = 0; i < numNodes; ++i) {
106-
var dotProduct = scores[i];
107-
var secondOrd = ordinals[i];
108-
long secondByteOffset = (long) secondOrd * vectorPitch;
109-
var bOffset = Float.intBitsToFloat(input.readInt(secondByteOffset + vectorLength));
110-
float adjustedDistance = dotProduct * scoreCorrectionConstant + aOffset + bOffset;
111-
scores[i] = Math.max((1 + adjustedDistance) / 2, 0f);
112-
}
113100
}
114101
}
115102

@@ -149,12 +136,14 @@ protected void bulkScoreFromSegment(
149136
var a = vectors.asSlice(firstByteOffset, vectorLength);
150137
var aOffset = Float.intBitsToFloat(vectors.asSlice(firstByteOffset + vectorLength, Float.BYTES).get(ValueLayout.JAVA_INT, 0));
151138
for (int i = 0; i < numNodes; ++i) {
152-
var secondOrd = ordinals.get(ValueLayout.JAVA_INT, i);
139+
var secondOrd = ordinals.getAtIndex(ValueLayout.JAVA_INT, i);
153140
long secondByteOffset = (long) secondOrd * vectorPitch;
154141
var b = vectors.asSlice(secondByteOffset, vectorLength);
155-
var bOffset = Float.intBitsToFloat(vectors.asSlice(secondByteOffset + vectorLength, Float.BYTES).get(ValueLayout.JAVA_INT, 0));
142+
var bOffset = Float.intBitsToFloat(
143+
vectors.asSlice(secondByteOffset + vectorLength, Float.BYTES).getAtIndex(ValueLayout.JAVA_INT, 0)
144+
);
156145
var score = scoreFromSegments(a, aOffset, b, bOffset);
157-
scores.set(ValueLayout.JAVA_FLOAT, i, score);
146+
scores.setAtIndex(ValueLayout.JAVA_FLOAT, i, score);
158147
}
159148
}
160149

@@ -274,17 +263,17 @@ protected void bulkScoreFromSegment(
274263
);
275264

276265
// Java-side adjustment
277-
// var aOffset = Float.intBitsToFloat(vectors.asSlice(firstByteOffset + vectorLength, Float.BYTES).get(ValueLayout.JAVA_INT, 0));
278-
// for (int i = 0; i < numNodes; ++i) {
279-
// var dotProduct = scores.getAtIndex(ValueLayout.JAVA_FLOAT, i);
280-
// var secondOrd = ordinals.getAtIndex(ValueLayout.JAVA_INT, i);
281-
// long secondByteOffset = (long) secondOrd * vectorPitch;
282-
// var bOffset = Float.intBitsToFloat(
283-
// vectors.asSlice(secondByteOffset + vectorLength, Float.BYTES).getAtIndex(ValueLayout.JAVA_INT, 0)
284-
// );
285-
// float adjustedDistance = dotProduct * scoreCorrectionConstant + aOffset + bOffset;
286-
// scores.setAtIndex(ValueLayout.JAVA_FLOAT, i, Math.max((1 + adjustedDistance) / 2, 0f));
287-
// }
266+
var aOffset = Float.intBitsToFloat(vectors.asSlice(firstByteOffset + vectorLength, Float.BYTES).get(ValueLayout.JAVA_INT, 0));
267+
for (int i = 0; i < numNodes; ++i) {
268+
var dotProduct = scores.getAtIndex(ValueLayout.JAVA_FLOAT, i);
269+
var secondOrd = ordinals.getAtIndex(ValueLayout.JAVA_INT, i);
270+
long secondByteOffset = (long) secondOrd * vectorPitch;
271+
var bOffset = Float.intBitsToFloat(
272+
vectors.asSlice(secondByteOffset + vectorLength, Float.BYTES).getAtIndex(ValueLayout.JAVA_INT, 0)
273+
);
274+
float adjustedDistance = dotProduct * scoreCorrectionConstant + aOffset + bOffset;
275+
scores.setAtIndex(ValueLayout.JAVA_FLOAT, i, Math.max((1 + adjustedDistance) / 2, 0f));
276+
}
288277
}
289278

290279
@Override

libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/Int7SQVectorScorer.java

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import java.util.Optional;
2424

2525
import static org.elasticsearch.simdvec.internal.Similarities.dotProduct7u;
26+
import static org.elasticsearch.simdvec.internal.Similarities.dotProduct7uBulkWithOffsets;
2627
import static org.elasticsearch.simdvec.internal.Similarities.squareDistance7u;
2728

2829
public abstract sealed class Int7SQVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer {
@@ -110,6 +111,38 @@ public float score(int node) throws IOException {
110111
float adjustedDistance = dotProduct * scoreCorrectionConstant + queryCorrection + nodeCorrection;
111112
return Math.max((1 + adjustedDistance) / 2, 0f);
112113
}
114+
115+
@Override
116+
public void bulkScore(int[] nodes, float[] scores, int numNodes) throws IOException {
117+
MemorySegment vectorsSeg = input.segmentSliceOrNull(0, input.length());
118+
if (vectorsSeg == null) {
119+
super.bulkScore(nodes, scores, numNodes);
120+
} else {
121+
var ordinalsSeg = MemorySegment.ofArray(nodes);
122+
var scoresSeg = MemorySegment.ofArray(scores);
123+
124+
var vectorPitch = vectorByteSize + Float.BYTES;
125+
dotProduct7uBulkWithOffsets(
126+
vectorsSeg,
127+
query,
128+
vectorByteSize,
129+
vectorPitch,
130+
ordinalsSeg,
131+
numNodes,
132+
scoreCorrectionConstant,
133+
scoresSeg
134+
);
135+
136+
for (int i = 0; i < numNodes; ++i) {
137+
var dotProduct = scores[i];
138+
var secondOrd = nodes[i];
139+
long secondByteOffset = (long) secondOrd * vectorPitch;
140+
var nodeCorrection = Float.intBitsToFloat(input.readInt(secondByteOffset + vectorByteSize));
141+
float adjustedDistance = dotProduct * scoreCorrectionConstant + queryCorrection + nodeCorrection;
142+
scores[i] = Math.max((1 + adjustedDistance) / 2, 0f);
143+
}
144+
}
145+
}
113146
}
114147

115148
public static final class EuclideanScorer extends Int7SQVectorScorer {

0 commit comments

Comments
 (0)