Skip to content

Commit 39d7f79

Browse files
committed
Remove score correction from native signature
1 parent f232613 commit 39d7f79

File tree

9 files changed

+15
-68
lines changed

9 files changed

+15
-68
lines changed

libs/native/src/main/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ public final class JdkVectorLibrary implements VectorLibrary {
6262
);
6363
dot7uBulkWithOffsets$mh = downcallHandle(
6464
"vec_dot7u_bulk_offsets_2",
65-
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, JAVA_INT, JAVA_INT, ADDRESS, JAVA_INT, JAVA_FLOAT, ADDRESS),
65+
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, JAVA_INT, JAVA_INT, ADDRESS, JAVA_INT, ADDRESS),
6666
LinkerHelperUtil.critical()
6767
);
6868
sqr7u$mh = downcallHandle(
@@ -98,7 +98,7 @@ public final class JdkVectorLibrary implements VectorLibrary {
9898
);
9999
dot7uBulkWithOffsets$mh = downcallHandle(
100100
"vec_dot7u_bulk_offsets",
101-
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, JAVA_INT, JAVA_INT, ADDRESS, JAVA_INT, JAVA_FLOAT, ADDRESS),
101+
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, JAVA_INT, JAVA_INT, ADDRESS, JAVA_INT, ADDRESS),
102102
LinkerHelperUtil.critical()
103103
);
104104
sqr7u$mh = downcallHandle(
@@ -180,10 +180,9 @@ static void dotProduct7uBulkWithOffsets(
180180
int pitch,
181181
MemorySegment offsets,
182182
int count,
183-
float scoreCorrection,
184183
MemorySegment result
185184
) {
186-
dot7uBulkWithOffsets(a, b, length, pitch, offsets, count, scoreCorrection, result);
185+
dot7uBulkWithOffsets(a, b, length, pitch, offsets, count, result);
187186
}
188187

189188
/**
@@ -269,11 +268,10 @@ private static void dot7uBulkWithOffsets(
269268
int pitch,
270269
MemorySegment offsets,
271270
int count,
272-
float scoreCorrection,
273271
MemorySegment result
274272
) {
275273
try {
276-
JdkVectorLibrary.dot7uBulkWithOffsets$mh.invokeExact(a, b, length, pitch, offsets, count, scoreCorrection, result);
274+
JdkVectorLibrary.dot7uBulkWithOffsets$mh.invokeExact(a, b, length, pitch, offsets, count, result);
277275
} catch (Throwable t) {
278276
throw new AssertionError(t);
279277
}
@@ -340,7 +338,6 @@ private static float sqrf32(MemorySegment a, MemorySegment b, int length) {
340338
int.class,
341339
MemorySegment.class,
342340
int.class,
343-
float.class,
344341
MemorySegment.class
345342
)
346343
);

libs/native/src/test/java/org/elasticsearch/nativeaccess/jdk/JDKVectorLibraryInt7uTests.java

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ public void testInt7uBulkWithOffsets() {
140140
var nativeQuerySeg = vectorsSegment.asSlice((long) queryOrd * dims, dims);
141141
var bulkScoresSeg = arena.allocate((long) numVecs * Float.BYTES);
142142

143-
dotProduct7uBulkWithOffsets(vectorsSegment, nativeQuerySeg, dims, dims, offsetsSegment, numVecs, 1.0f, bulkScoresSeg);
143+
dotProduct7uBulkWithOffsets(vectorsSegment, nativeQuerySeg, dims, dims, offsetsSegment, numVecs, bulkScoresSeg);
144144
assertScoresEquals(expectedScores, bulkScoresSeg);
145145
}
146146

@@ -168,7 +168,7 @@ public void testInt7uBulkWithOffsetsAndPitch() {
168168
var nativeQuerySeg = vectorsSegment.asSlice((long) queryOrd * pitch, pitch);
169169
var bulkScoresSeg = arena.allocate((long) numVecs * Float.BYTES);
170170

171-
dotProduct7uBulkWithOffsets(vectorsSegment, nativeQuerySeg, dims, pitch, offsetsSegment, numVecs, 1.0f, bulkScoresSeg);
171+
dotProduct7uBulkWithOffsets(vectorsSegment, nativeQuerySeg, dims, pitch, offsetsSegment, numVecs, bulkScoresSeg);
172172
assertScoresEquals(expectedScores, bulkScoresSeg);
173173
}
174174

@@ -199,7 +199,6 @@ public void testInt7uBulkWithOffsetsHeapSegments() {
199199
dims,
200200
MemorySegment.ofArray(offsets),
201201
numVecs,
202-
1.0f,
203202
MemorySegment.ofArray(bulkScores)
204203
);
205204
assertArrayEquals(expectedScores, bulkScores, 0f);
@@ -297,11 +296,10 @@ void dotProduct7uBulkWithOffsets(
297296
int pitch,
298297
MemorySegment offsets,
299298
int count,
300-
float scoreCorrection,
301299
MemorySegment result
302300
) {
303301
try {
304-
getVectorDistance().dotProductHandle7uBulkWithOffsets().invokeExact(a, b, dims, pitch, offsets, count, scoreCorrection, result);
302+
getVectorDistance().dotProductHandle7uBulkWithOffsets().invokeExact(a, b, dims, pitch, offsets, count, result);
305303
} catch (Throwable e) {
306304
if (e instanceof Error err) {
307305
throw err;

libs/simdvec/native/src/vec/c/aarch64/vec_1.cpp

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -100,20 +100,6 @@ EXPORT int32_t vec_dot7u(const int8_t* a, const int8_t* b, const int32_t dims) {
100100
return res;
101101
}
102102

103-
static inline f32_t adjust(f32_t raw_score, f32_t score_correction, f32_t first_offset, f32_t second_offset) {
104-
auto adjusted_score = raw_score * score_correction + first_offset + second_offset;
105-
return fmaxf((1.0f + adjusted_score) / 2.0f, 0.0f);
106-
}
107-
108-
static inline f32_t int_bits_to_float(const int32_t v) {
109-
union {
110-
int i;
111-
float f;
112-
} u;
113-
u.i = (long)v;
114-
return (f32_t)u.f;
115-
}
116-
117103
template <int64_t(*mapper)(const int32_t, const int32_t*)>
118104
static inline void dot7u_inner_bulk(
119105
const int8_t* a,
@@ -122,7 +108,6 @@ static inline void dot7u_inner_bulk(
122108
const int32_t pitch,
123109
const int32_t* offsets,
124110
const int32_t count,
125-
const f32_t score_correction,
126111
f32_t* results
127112
) {
128113
size_t blk = dims & ~15;
@@ -215,7 +200,7 @@ static inline int64_t index(const int32_t i, const int32_t* offsets) {
215200
}
216201

217202
EXPORT void vec_dot7u_bulk(const int8_t* a, const int8_t* b, const int32_t dims, const int32_t count, f32_t* results) {
218-
dot7u_inner_bulk<identity>(a, b, dims, dims, NULL, count, 1.0f, results);
203+
dot7u_inner_bulk<identity>(a, b, dims, dims, NULL, count, results);
219204
}
220205

221206

@@ -226,9 +211,8 @@ EXPORT void vec_dot7u_bulk_offsets(
226211
const int32_t pitch,
227212
const int32_t* offsets,
228213
const int32_t count,
229-
const f32_t score_correction,
230214
f32_t* results) {
231-
dot7u_inner_bulk<index>(a, b, dims, pitch, offsets, count, score_correction, results);
215+
dot7u_inner_bulk<index>(a, b, dims, pitch, offsets, count, results);
232216
}
233217

234218
static inline int32_t sqr7u_inner(int8_t *a, int8_t *b, const int32_t dims) {

libs/simdvec/native/src/vec/c/amd64/vec_1.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,6 @@ static inline void dot7u_inner_bulk(
165165
const int32_t pitch,
166166
const int32_t* offsets,
167167
const int32_t count,
168-
const f32_t score_correction, // TODO
169168
f32_t* results
170169
) {
171170
if (dims > STRIDE_BYTES_LEN) {
@@ -211,9 +210,8 @@ EXPORT void vec_dot7u_bulk_offsets(
211210
const int32_t pitch,
212211
const int32_t* offsets,
213212
const int32_t count,
214-
const f32_t score_correction,
215213
f32_t* results) {
216-
dot7u_inner_bulk<index>(a, b, dims, pitch, offsets, count, score_correction, results);
214+
dot7u_inner_bulk<index>(a, b, dims, pitch, offsets, count, results);
217215
}
218216

219217
static inline int32_t sqr7u_inner(int8_t *a, int8_t *b, const int32_t dims) {

libs/simdvec/native/src/vec/c/amd64/vec_2.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,6 @@ static inline void dot7u_inner_bulk(
131131
const int32_t pitch,
132132
const int32_t* offsets,
133133
const int32_t count,
134-
const f32_t score_correction, // TODO
135134
f32_t* results
136135
) {
137136
if (dims > STRIDE_BYTES_LEN) {
@@ -177,9 +176,8 @@ EXPORT void vec_dot7u_bulk_offsets_2(
177176
const int32_t pitch,
178177
const int32_t* offsets,
179178
const int32_t count,
180-
const f32_t score_correction,
181179
f32_t* results) {
182-
dot7u_inner_bulk<index>(a, b, dims, pitch, offsets, count, score_correction, results);
180+
dot7u_inner_bulk<index>(a, b, dims, pitch, offsets, count, results);
183181
}
184182

185183
template<int offsetRegs>

libs/simdvec/native/src/vec/headers/vec.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ EXPORT void vec_dot7u_bulk_offsets(
4848
const int32_t pitch,
4949
const int32_t* offsets,
5050
const int32_t count,
51-
const f32_t score_correction,
5251
f32_t* results);
5352

5453
EXPORT int32_t vec_sqr7u(int8_t *a, int8_t *b, const int32_t length);

libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Int7SQVectorScorerSupplier.java

Lines changed: 2 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -69,15 +69,7 @@ final void bulkScoreFromOrds(int firstOrd, int[] ordinals, float[] scores, int n
6969
if (SUPPORTS_HEAP_SEGMENTS) {
7070
var ordinalsSeg = MemorySegment.ofArray(ordinals);
7171
var scoresSeg = MemorySegment.ofArray(scores);
72-
bulkScoreFromSegment(
73-
vectorsSeg,
74-
vectorLength,
75-
vectorPitch,
76-
firstOrd,
77-
ordinalsSeg,
78-
scoresSeg,
79-
numNodes
80-
);
72+
bulkScoreFromSegment(vectorsSeg, vectorLength, vectorPitch, firstOrd, ordinalsSeg, scoresSeg, numNodes);
8173
} else {
8274
try (var arena = Arena.ofConfined()) {
8375
var ordinalsMemorySegment = arena.allocate((long) numNodes * Integer.BYTES, 32);
@@ -251,16 +243,7 @@ protected void bulkScoreFromSegment(
251243
) {
252244
long firstByteOffset = (long) firstOrd * vectorPitch;
253245
var firstVector = vectors.asSlice(firstByteOffset, vectorPitch);
254-
Similarities.dotProduct7uBulkWithOffsets(
255-
vectors,
256-
firstVector,
257-
dims,
258-
vectorPitch,
259-
ordinals,
260-
numNodes,
261-
scoreCorrectionConstant,
262-
scores
263-
);
246+
Similarities.dotProduct7uBulkWithOffsets(vectors, firstVector, dims, vectorPitch, ordinals, numNodes, scores);
264247

265248
// Java-side adjustment
266249
var aOffset = Float.intBitsToFloat(vectors.asSlice(firstByteOffset + vectorLength, Float.BYTES).get(ValueLayout.JAVA_INT, 0));

libs/simdvec/src/main21/java/org/elasticsearch/simdvec/internal/Similarities.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,10 @@ static void dotProduct7uBulkWithOffsets(
7575
int pitch,
7676
MemorySegment offsets,
7777
int count,
78-
float scoreCorrection,
7978
MemorySegment scores
8079
) {
8180
try {
82-
DOT_HANDLE_7U_BULK_WITH_OFFSETS.invokeExact(a, b, length, pitch, offsets, count, scoreCorrection, scores);
81+
DOT_HANDLE_7U_BULK_WITH_OFFSETS.invokeExact(a, b, length, pitch, offsets, count, scores);
8382
} catch (Throwable e) {
8483
if (e instanceof Error err) {
8584
throw err;

libs/simdvec/src/main22/java/org/elasticsearch/simdvec/internal/Int7SQVectorScorer.java

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -122,16 +122,7 @@ public void bulkScore(int[] nodes, float[] scores, int numNodes) throws IOExcept
122122
var scoresSeg = MemorySegment.ofArray(scores);
123123

124124
var vectorPitch = vectorByteSize + Float.BYTES;
125-
dotProduct7uBulkWithOffsets(
126-
vectorsSeg,
127-
query,
128-
vectorByteSize,
129-
vectorPitch,
130-
ordinalsSeg,
131-
numNodes,
132-
scoreCorrectionConstant,
133-
scoresSeg
134-
);
125+
dotProduct7uBulkWithOffsets(vectorsSeg, query, vectorByteSize, vectorPitch, ordinalsSeg, numNodes, scoresSeg);
135126

136127
for (int i = 0; i < numNodes; ++i) {
137128
var dotProduct = scores[i];

0 commit comments

Comments
 (0)