Skip to content

Commit 3369521

Browse files
committed
WIP: scoring on CPP side - fix signature to have pitch
1 parent 639bbf0 commit 3369521

File tree

7 files changed

+259
-29
lines changed

7 files changed

+259
-29
lines changed

libs/native/src/main/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ public final class JdkVectorLibrary implements VectorLibrary {
6262
);
6363
dot7uBulkWithOffsets$mh = downcallHandle(
6464
"vec_dot7u_bulk_offsets_2",
65-
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, JAVA_INT, ADDRESS, JAVA_INT, ADDRESS),
65+
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, JAVA_INT, JAVA_INT, ADDRESS, JAVA_INT, JAVA_FLOAT, ADDRESS),
6666
LinkerHelperUtil.critical()
6767
);
6868
sqr7u$mh = downcallHandle(
@@ -177,11 +177,13 @@ static void dotProduct7uBulkWithOffsets(
177177
MemorySegment a,
178178
MemorySegment b,
179179
int length,
180+
int pitch,
180181
MemorySegment offsets,
181182
int count,
183+
float scoreCorrection,
182184
MemorySegment result
183185
) {
184-
dot7uBulkWithOffsets(a, b, length, offsets, count, result);
186+
dot7uBulkWithOffsets(a, b, length, pitch, offsets, count, scoreCorrection, result);
185187
}
186188

187189
/**
@@ -264,12 +266,14 @@ private static void dot7uBulkWithOffsets(
264266
MemorySegment a,
265267
MemorySegment b,
266268
int length,
269+
int pitch,
267270
MemorySegment offsets,
268271
int count,
272+
float scoreCorrection,
269273
MemorySegment result
270274
) {
271275
try {
272-
JdkVectorLibrary.dot7uBulkWithOffsets$mh.invokeExact(a, b, length, offsets, count, result);
276+
JdkVectorLibrary.dot7uBulkWithOffsets$mh.invokeExact(a, b, length, pitch, offsets, count, scoreCorrection, result);
273277
} catch (Throwable t) {
274278
throw new AssertionError(t);
275279
}

libs/native/src/test/java/org/elasticsearch/nativeaccess/jdk/JDKVectorLibraryInt7uTests.java

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,8 @@ public void testInt7uBulkWithOffsets() {
137137
var nativeQuerySeg = segment.asSlice((long) queryOrd * dims, dims);
138138
var bulkScoresSeg = arena.allocate((long) numVecs * Float.BYTES);
139139
var offsetsSeg = arena.allocate((long) numVecs * Integer.BYTES);
140-
dotProduct7uBulkWithOffsets(segment, nativeQuerySeg, dims, offsetsSeg, numVecs, bulkScoresSeg);
140+
// TODO: test pitch
141+
dotProduct7uBulkWithOffsets(segment, nativeQuerySeg, dims, dims, offsetsSeg, numVecs, 1.0f, bulkScoresSeg);
141142
assertScoresEquals(expectedScores, bulkScoresSeg);
142143

143144
if (supportsHeapSegments()) {
@@ -146,8 +147,10 @@ public void testInt7uBulkWithOffsets() {
146147
segment,
147148
nativeQuerySeg,
148149
dims,
150+
dims,
149151
MemorySegment.ofArray(offsets),
150152
numVecs,
153+
1.0f,
151154
MemorySegment.ofArray(bulkScores)
152155
);
153156
assertArrayEquals(expectedScores, bulkScores, 0f);
@@ -239,9 +242,18 @@ void dotProduct7uBulk(MemorySegment a, MemorySegment b, int dims, int count, Mem
239242
}
240243
}
241244

242-
void dotProduct7uBulkWithOffsets(MemorySegment a, MemorySegment b, int dims, MemorySegment offsets, int count, MemorySegment result) {
245+
void dotProduct7uBulkWithOffsets(
246+
MemorySegment a,
247+
MemorySegment b,
248+
int dims,
249+
int pitch,
250+
MemorySegment offsets,
251+
int count,
252+
float scoreCorrection,
253+
MemorySegment result
254+
) {
243255
try {
244-
getVectorDistance().dotProductHandle7uBulkWithOffsets().invokeExact(a, b, dims, count, result);
256+
getVectorDistance().dotProductHandle7uBulkWithOffsets().invokeExact(a, b, dims, pitch, offsets, count, scoreCorrection, result);
245257
} catch (Throwable e) {
246258
if (e instanceof Error err) {
247259
throw err;

libs/simdvec/native/src/vec/c/aarch64/vec_1.cpp

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -100,17 +100,42 @@ EXPORT int32_t vec_dot7u(const int8_t* a, const int8_t* b, const int32_t dims) {
100100
return res;
101101
}
102102

103+
static inline f32_t adjust(f32_t raw_score, f32_t score_correction, f32_t f32_t first_offset, f32_t second_offset) {
104+
auto adjusted_score = raw_score * score_correction + first_offset + second_offset;
105+
return std::max((1.0f + adjusted_score) / 2.0f, 0.0f);
106+
}
107+
108+
static inline int_bits_to_float(const int32_t v) {
109+
union {
110+
int i;
111+
float f;
112+
} u;
113+
u.i = (long)v;
114+
return (f32_t)u.f;
115+
}
116+
103117
template <int32_t(*mapper)(const int32_t, const int32_t*)>
104-
static inline void dot7u_inner_bulk(const int8_t* a, const int8_t* b, const int32_t dims, const int32_t* offsets, const int32_t count, f32_t* results) {
118+
static inline void dot7u_inner_bulk(
119+
const int8_t* a,
120+
const int8_t* b,
121+
const int32_t dims,
122+
const int32_t pitch,
123+
const int32_t* offsets,
124+
const int32_t count,
125+
const f32_t score_correction,
126+
f32_t* results
127+
) {
105128
size_t blk = dims & ~15;
106129
size_t c = 0;
107130

131+
// f32_t first_offset = int_bits_to_float(*((const int32_t*)(b + dims)));
132+
108133
// Process 4 vectors at a time
109134
for (; c + 3 < count; c += 4) {
110-
const int8_t* a0 = a + mapper(c, offsets) * dims;
111-
const int8_t* a1 = a + mapper(c + 1, offsets) * dims;
112-
const int8_t* a2 = a + mapper(c + 2, offsets) * dims;
113-
const int8_t* a3 = a + mapper(c + 3, offsets) * dims;
135+
const int8_t* a0 = a + mapper(c, offsets) * pitch;
136+
const int8_t* a1 = a + mapper(c + 1, offsets) * pitch;
137+
const int8_t* a2 = a + mapper(c + 2, offsets) * pitch;
138+
const int8_t* a3 = a + mapper(c + 3, offsets) * pitch;
114139

115140
int32x4_t acc0 = vdupq_n_s32(0);
116141
int32x4_t acc1 = vdupq_n_s32(0);
@@ -167,6 +192,7 @@ static inline void dot7u_inner_bulk(const int8_t* a, const int8_t* b, const int3
167192
acc_scalar3 += a3[t] * bb;
168193
}
169194
}
195+
// f32_t second_offset_0 = int_bits_to_float(*((const int32_t*)(a0 + dims)));
170196
results[c + 0] = (f32_t)acc_scalar0;
171197
results[c + 1] = (f32_t)acc_scalar1;
172198
results[c + 2] = (f32_t)acc_scalar2;
@@ -175,7 +201,7 @@ static inline void dot7u_inner_bulk(const int8_t* a, const int8_t* b, const int3
175201

176202
// Tail-handling: remaining 0..3 vectors
177203
for (; c < count; c++) {
178-
const int8_t* a0 = a + mapper(c, offsets) * dims;
204+
const int8_t* a0 = a + mapper(c, offsets) * pitch;
179205
results[c] = (f32_t)vec_dot7u(a0, b, dims);
180206
}
181207
}
@@ -189,12 +215,20 @@ static inline int index(const int32_t i, const int32_t* offsets) {
189215
}
190216

191217
EXPORT void vec_dot7u_bulk(const int8_t* a, const int8_t* b, const int32_t dims, const int32_t count, f32_t* results) {
192-
dot7u_inner_bulk<identity>(a, b, dims, NULL, count, results);
218+
dot7u_inner_bulk<identity>(a, b, dims, dims, NULL, count, 1.0f, results);
193219
}
194220

195221

196-
EXPORT void vec_dot7u_bulk_offsets(const int8_t* a, const int8_t* b, const int32_t dims, const int32_t* offsets, const int32_t count, f32_t* results) {
197-
dot7u_inner_bulk<index>(a, b, dims, offsets, count, results);
222+
EXPORT void vec_dot7u_bulk_offsets(
223+
const int8_t* a,
224+
const int8_t* b,
225+
const int32_t dims,
226+
const int32_t pitch,
227+
const int32_t* offsets,
228+
const int32_t count,
229+
const f32_t score_correction,
230+
f32_t* results) {
231+
dot7u_inner_bulk<index>(a, b, dims, pitch, offsets, count, score_correction, results);
198232
}
199233

200234
static inline int32_t sqr7u_inner(int8_t *a, int8_t *b, const int32_t dims) {

libs/simdvec/native/src/vec/c/amd64/vec_1.cpp

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -158,12 +158,21 @@ EXPORT int32_t vec_dot7u(const int8_t* a, const int8_t* b, const int32_t dims) {
158158
}
159159

160160
template <int32_t(*mapper)(int32_t, const int32_t*)>
161-
static inline void dot7u_inner_bulk(const int8_t* a, const int8_t* b, const int32_t dims, const int32_t* offsets, const int32_t count, f32_t* results) {
161+
static inline void dot7u_inner_bulk(
162+
const int8_t* a,
163+
const int8_t* b,
164+
const int32_t dims,
165+
const int32_t pitch,
166+
const int32_t* offsets,
167+
const int32_t count,
168+
const f32_t score_correction, // TODO
169+
f32_t* results
170+
) {
162171
int32_t res = 0;
163172
if (dims > STRIDE_BYTES_LEN) {
164173
const int limit = dims & ~(STRIDE_BYTES_LEN - 1);
165174
for (int32_t c = 0; c < count; c++) {
166-
const int8_t* a0 = a + mapper(c, offsets) * dims;
175+
const int8_t* a0 = a + mapper(c, offsets) * pitch;
167176
int i = limit;
168177
res = dot7u_inner(a, b, i);
169178
for (; i < dims; i++) {
@@ -173,7 +182,7 @@ static inline void dot7u_inner_bulk(const int8_t* a, const int8_t* b, const int3
173182
}
174183
} else {
175184
for (int32_t c = 0; c < count; c++) {
176-
const int8_t* a0 = a + mapper(c, offsets) * dims;
185+
const int8_t* a0 = a + mapper(c, offsets) * pitch;
177186
res = 0;
178187
for (int32_t i = 0; i < dims; i++) {
179188
res += a0[i] * b[i];
@@ -192,12 +201,20 @@ static inline int index(const int32_t i, const int32_t* offsets) {
192201
}
193202

194203
EXPORT void vec_dot7u_bulk(const int8_t* a, const int8_t* b, const int32_t dims, const int32_t count, f32_t* results) {
195-
dot7u_inner_bulk<identity>(a, b, dims, NULL, count, results);
204+
dot7u_inner_bulk<identity>(a, b, dims, dims, NULL, count, 1.0f, results);
196205
}
197206

198207

199-
EXPORT void vec_dot7u_bulk_offsets(const int8_t* a, const int8_t* b, const int32_t dims, const int32_t* offsets, const int32_t count, f32_t* results) {
200-
dot7u_inner_bulk<index>(a, b, dims, offsets, count, results);
208+
EXPORT void vec_dot7u_bulk_offsets(
209+
const int8_t* a,
210+
const int8_t* b,
211+
const int32_t dims,
212+
const int32_t pitch,
213+
const int32_t* offsets,
214+
const int32_t count,
215+
const f32_t score_correction,
216+
f32_t* results) {
217+
dot7u_inner_bulk<index>(a, b, dims, pitch, offsets, count, score_correction, results);
201218
}
202219

203220
static inline int32_t sqr7u_inner(int8_t *a, int8_t *b, const int32_t dims) {

libs/simdvec/native/src/vec/c/amd64/vec_2.cpp

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -124,12 +124,21 @@ EXPORT int32_t vec_dot7u_2(const int8_t* a, const int8_t* b, const int32_t dims)
124124
}
125125

126126
template <int32_t(*mapper)(int32_t, const int32_t*)>
127-
static inline void dot7u_inner_bulk(const int8_t* a, const int8_t* b, const int32_t dims, const int32_t* offsets, const int32_t count, f32_t* results) {
127+
static inline void dot7u_inner_bulk(
128+
const int8_t* a,
129+
const int8_t* b,
130+
const int32_t dims,
131+
const int32_t pitch,
132+
const int32_t* offsets,
133+
const int32_t count,
134+
const f32_t score_correction, // TODO
135+
f32_t* results
136+
) {
128137
int32_t res = 0;
129138
if (dims > STRIDE_BYTES_LEN) {
130139
const int limit = dims & ~(STRIDE_BYTES_LEN - 1);
131140
for (int32_t c = 0; c < count; c++) {
132-
const int8_t* a0 = a + mapper(c, offsets) * dims;
141+
const int8_t* a0 = a + mapper(c, offsets) * pitch;
133142
int i = limit;
134143
res = dot7u_inner_avx512(a, b, i);
135144
for (; i < dims; i++) {
@@ -139,7 +148,7 @@ static inline void dot7u_inner_bulk(const int8_t* a, const int8_t* b, const int3
139148
}
140149
} else {
141150
for (int32_t c = 0; c < count; c++) {
142-
const int8_t* a0 = a + mapper(c, offsets) * dims;
151+
const int8_t* a0 = a + mapper(c, offsets) * pitch;
143152
res = 0;
144153
for (int32_t i = 0; i < dims; i++) {
145154
res += a0[i] * b[i];
@@ -157,13 +166,21 @@ static inline int index(const int32_t i, const int32_t* offsets) {
157166
return offsets[i];
158167
}
159168

160-
EXPORT void vec_dot7u_bulk_2(const int8_t* a, const int8_t* b, const int32_t dims, const int32_t count, f32_t* results) {
161-
dot7u_inner_bulk<identity>(a, b, dims, NULL, count, results);
169+
EXPORT void vec_dot7u_bulk(const int8_t* a, const int8_t* b, const int32_t dims, const int32_t count, f32_t* results) {
170+
dot7u_inner_bulk<identity>(a, b, dims, dims, NULL, count, 1.0f, results);
162171
}
163172

164173

165-
EXPORT void vec_dot7u_bulk_offsets_2(const int8_t* a, const int8_t* b, const int32_t dims, const int32_t* offsets, const int32_t count, f32_t* results) {
166-
dot7u_inner_bulk<index>(a, b, dims, offsets, count, results);
174+
EXPORT void vec_dot7u_bulk_offsets(
175+
const int8_t* a,
176+
const int8_t* b,
177+
const int32_t dims,
178+
const int32_t pitch,
179+
const int32_t* offsets,
180+
const int32_t count,
181+
const f32_t score_correction,
182+
f32_t* results) {
183+
dot7u_inner_bulk<index>(a, b, dims, pitch, offsets, count, score_correction, results);
167184
}
168185

169186
template<int offsetRegs>

0 commit comments

Comments
 (0)