Skip to content

Commit 451b924

Browse files
committed
Fixes to function signatures and CPP code
1 parent 3369521 commit 451b924

File tree

8 files changed

+151
-73
lines changed

8 files changed

+151
-73
lines changed

libs/native/src/main/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ public final class JdkVectorLibrary implements VectorLibrary {
9898
);
9999
dot7uBulkWithOffsets$mh = downcallHandle(
100100
"vec_dot7u_bulk_offsets",
101-
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, JAVA_INT, ADDRESS, JAVA_INT, ADDRESS),
101+
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, JAVA_INT, JAVA_INT, ADDRESS, JAVA_INT, JAVA_FLOAT, ADDRESS),
102102
LinkerHelperUtil.critical()
103103
);
104104
sqr7u$mh = downcallHandle(
@@ -329,16 +329,21 @@ private static float sqrf32(MemorySegment a, MemorySegment b, int length) {
329329
mt = MethodType.methodType(void.class, MemorySegment.class, MemorySegment.class, int.class, int.class, MemorySegment.class);
330330
DOT_HANDLE_7U_BULK = lookup.findStatic(JdkVectorSimilarityFunctions.class, "dotProduct7uBulk", mt);
331331

332-
mt = MethodType.methodType(
333-
void.class,
334-
MemorySegment.class,
335-
MemorySegment.class,
336-
int.class,
337-
MemorySegment.class,
338-
int.class,
339-
MemorySegment.class
332+
DOT_HANDLE_7U_BULK_WITH_OFFSETS = lookup.findStatic(
333+
JdkVectorSimilarityFunctions.class,
334+
"dotProduct7uBulkWithOffsets",
335+
MethodType.methodType(
336+
void.class,
337+
MemorySegment.class,
338+
MemorySegment.class,
339+
int.class,
340+
int.class,
341+
MemorySegment.class,
342+
int.class,
343+
float.class,
344+
MemorySegment.class
345+
)
340346
);
341-
DOT_HANDLE_7U_BULK_WITH_OFFSETS = lookup.findStatic(JdkVectorSimilarityFunctions.class, "dotProduct7uBulkWithOffsets", mt);
342347

343348
mt = MethodType.methodType(float.class, MemorySegment.class, MemorySegment.class, int.class);
344349
COS_HANDLE_FLOAT32 = lookup.findStatic(JdkVectorSimilarityFunctions.class, "cosineF32", mt);

libs/native/src/test/java/org/elasticsearch/nativeaccess/jdk/JDKVectorLibraryInt7uTests.java

Lines changed: 67 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import org.junit.BeforeClass;
1717

1818
import java.lang.foreign.MemorySegment;
19+
import java.lang.foreign.ValueLayout;
1920

2021
import static java.lang.foreign.ValueLayout.JAVA_FLOAT_UNALIGNED;
2122
import static org.hamcrest.Matchers.containsString;
@@ -123,6 +124,60 @@ public void testInt7uBulkWithOffsets() {
123124
final int dims = size;
124125
final int numVecs = randomIntBetween(2, 101);
125126
var offsets = new int[numVecs];
127+
var vectors = new byte[numVecs][dims];
128+
var vectorsSegment = arena.allocate((long) dims * numVecs);
129+
var offsetsSegment = arena.allocate((long) numVecs * Integer.BYTES);
130+
for (int i = 0; i < numVecs; i++) {
131+
offsets[i] = randomInt(numVecs - 1);
132+
offsetsSegment.setAtIndex(ValueLayout.JAVA_INT, i, offsets[i]);
133+
randomBytesBetween(vectors[i], MIN_INT7_VALUE, MAX_INT7_VALUE);
134+
MemorySegment.copy(vectors[i], 0, vectorsSegment, ValueLayout.JAVA_BYTE, (long) i * dims, dims);
135+
}
136+
int queryOrd = randomInt(numVecs - 1);
137+
float[] expectedScores = new float[numVecs];
138+
dotProductBulkWithOffsetsScalar(vectors[queryOrd], vectors, offsets, expectedScores);
139+
140+
var nativeQuerySeg = vectorsSegment.asSlice((long) queryOrd * dims, dims);
141+
var bulkScoresSeg = arena.allocate((long) numVecs * Float.BYTES);
142+
143+
dotProduct7uBulkWithOffsets(vectorsSegment, nativeQuerySeg, dims, dims, offsetsSegment, numVecs, 1.0f, bulkScoresSeg);
144+
assertScoresEquals(expectedScores, bulkScoresSeg);
145+
}
146+
147+
public void testInt7uBulkWithOffsetsAndPitch() {
148+
assumeTrue(notSupportedMsg(), supported());
149+
final int dims = size;
150+
final int numVecs = randomIntBetween(2, 101);
151+
var offsets = new int[numVecs];
152+
var vectors = new byte[numVecs][dims];
153+
154+
// Mimics extra data at the end
155+
var pitch = dims * Byte.BYTES + Float.BYTES;
156+
var vectorsSegment = arena.allocate((long) numVecs * pitch);
157+
var offsetsSegment = arena.allocate((long) numVecs * Integer.BYTES);
158+
for (int i = 0; i < numVecs; i++) {
159+
offsets[i] = randomInt(numVecs - 1);
160+
offsetsSegment.setAtIndex(ValueLayout.JAVA_INT, i, offsets[i]);
161+
randomBytesBetween(vectors[i], MIN_INT7_VALUE, MAX_INT7_VALUE);
162+
MemorySegment.copy(vectors[i], 0, vectorsSegment, ValueLayout.JAVA_BYTE, (long) i * pitch, dims);
163+
}
164+
int queryOrd = randomInt(numVecs - 1);
165+
float[] expectedScores = new float[numVecs];
166+
dotProductBulkWithOffsetsScalar(vectors[queryOrd], vectors, offsets, expectedScores);
167+
168+
var nativeQuerySeg = vectorsSegment.asSlice((long) queryOrd * pitch, pitch);
169+
var bulkScoresSeg = arena.allocate((long) numVecs * Float.BYTES);
170+
171+
dotProduct7uBulkWithOffsets(vectorsSegment, nativeQuerySeg, dims, pitch, offsetsSegment, numVecs, 1.0f, bulkScoresSeg);
172+
assertScoresEquals(expectedScores, bulkScoresSeg);
173+
}
174+
175+
public void testInt7uBulkWithOffsetsHeapSegments() {
176+
assumeTrue(notSupportedMsg(), supported());
177+
assumeTrue("Requires support for heap MemorySegments", supportsHeapSegments());
178+
final int dims = size;
179+
final int numVecs = randomIntBetween(2, 101);
180+
var offsets = new int[numVecs];
126181
var values = new byte[numVecs][dims];
127182
var segment = arena.allocate((long) dims * numVecs);
128183
for (int i = 0; i < numVecs; i++) {
@@ -135,26 +190,19 @@ public void testInt7uBulkWithOffsets() {
135190
dotProductBulkWithOffsetsScalar(values[queryOrd], values, offsets, expectedScores);
136191

137192
var nativeQuerySeg = segment.asSlice((long) queryOrd * dims, dims);
138-
var bulkScoresSeg = arena.allocate((long) numVecs * Float.BYTES);
139-
var offsetsSeg = arena.allocate((long) numVecs * Integer.BYTES);
140-
// TODO: test pitch
141-
dotProduct7uBulkWithOffsets(segment, nativeQuerySeg, dims, dims, offsetsSeg, numVecs, 1.0f, bulkScoresSeg);
142-
assertScoresEquals(expectedScores, bulkScoresSeg);
143193

144-
if (supportsHeapSegments()) {
145-
float[] bulkScores = new float[numVecs];
146-
dotProduct7uBulkWithOffsets(
147-
segment,
148-
nativeQuerySeg,
149-
dims,
150-
dims,
151-
MemorySegment.ofArray(offsets),
152-
numVecs,
153-
1.0f,
154-
MemorySegment.ofArray(bulkScores)
155-
);
156-
assertArrayEquals(expectedScores, bulkScores, 0f);
157-
}
194+
float[] bulkScores = new float[numVecs];
195+
dotProduct7uBulkWithOffsets(
196+
segment,
197+
nativeQuerySeg,
198+
dims,
199+
dims,
200+
MemorySegment.ofArray(offsets),
201+
numVecs,
202+
1.0f,
203+
MemorySegment.ofArray(bulkScores)
204+
);
205+
assertArrayEquals(expectedScores, bulkScores, 0f);
158206
}
159207

160208
public void testIllegalDims() {

libs/simdvec/native/src/vec/c/aarch64/vec_1.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -100,12 +100,12 @@ EXPORT int32_t vec_dot7u(const int8_t* a, const int8_t* b, const int32_t dims) {
100100
return res;
101101
}
102102

103-
static inline f32_t adjust(f32_t raw_score, f32_t score_correction, f32_t f32_t first_offset, f32_t second_offset) {
103+
static inline f32_t adjust(f32_t raw_score, f32_t score_correction, f32_t first_offset, f32_t second_offset) {
104104
auto adjusted_score = raw_score * score_correction + first_offset + second_offset;
105-
return std::max((1.0f + adjusted_score) / 2.0f, 0.0f);
105+
return fmaxf((1.0f + adjusted_score) / 2.0f, 0.0f);
106106
}
107107

108-
static inline int_bits_to_float(const int32_t v) {
108+
static inline f32_t int_bits_to_float(const int32_t v) {
109109
union {
110110
int i;
111111
float f;
@@ -114,7 +114,7 @@ static inline int_bits_to_float(const int32_t v) {
114114
return (f32_t)u.f;
115115
}
116116

117-
template <int32_t(*mapper)(const int32_t, const int32_t*)>
117+
template <int64_t(*mapper)(const int32_t, const int32_t*)>
118118
static inline void dot7u_inner_bulk(
119119
const int8_t* a,
120120
const int8_t* b,
@@ -206,11 +206,11 @@ static inline void dot7u_inner_bulk(
206206
}
207207
}
208208

209-
static inline int identity(const int32_t i, const int32_t* offsets) {
209+
static inline int64_t identity(const int32_t i, const int32_t* offsets) {
210210
return i;
211211
}
212212

213-
static inline int index(const int32_t i, const int32_t* offsets) {
213+
static inline int64_t index(const int32_t i, const int32_t* offsets) {
214214
return offsets[i];
215215
}
216216

libs/simdvec/native/src/vec/c/amd64/vec_1.cpp

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ EXPORT int32_t vec_dot7u(const int8_t* a, const int8_t* b, const int32_t dims) {
157157
return res;
158158
}
159159

160-
template <int32_t(*mapper)(int32_t, const int32_t*)>
160+
template <int64_t(*mapper)(int32_t, const int32_t*)>
161161
static inline void dot7u_inner_bulk(
162162
const int8_t* a,
163163
const int8_t* b,
@@ -168,22 +168,21 @@ static inline void dot7u_inner_bulk(
168168
const f32_t score_correction, // TODO
169169
f32_t* results
170170
) {
171-
int32_t res = 0;
172171
if (dims > STRIDE_BYTES_LEN) {
173172
const int limit = dims & ~(STRIDE_BYTES_LEN - 1);
174173
for (int32_t c = 0; c < count; c++) {
175-
const int8_t* a0 = a + mapper(c, offsets) * pitch;
174+
const int8_t* a0 = a + (mapper(c, offsets) * pitch);
176175
int i = limit;
177-
res = dot7u_inner(a, b, i);
176+
int32_t res = dot7u_inner(a0, b, i);
178177
for (; i < dims; i++) {
179178
res += a0[i] * b[i];
180179
}
181180
results[c] = (f32_t)res;
182181
}
183182
} else {
184183
for (int32_t c = 0; c < count; c++) {
185-
const int8_t* a0 = a + mapper(c, offsets) * pitch;
186-
res = 0;
184+
const int8_t* a0 = a + (mapper(c, offsets) * pitch);
185+
int32_t res = 0;
187186
for (int32_t i = 0; i < dims; i++) {
188187
res += a0[i] * b[i];
189188
}
@@ -192,11 +191,11 @@ static inline void dot7u_inner_bulk(
192191
}
193192
}
194193

195-
static inline int identity(const int32_t i, const int32_t* offsets) {
194+
static inline int64_t identity(const int32_t i, const int32_t* offsets) {
196195
return i;
197196
}
198197

199-
static inline int index(const int32_t i, const int32_t* offsets) {
198+
static inline int64_t index(const int32_t i, const int32_t* offsets) {
200199
return offsets[i];
201200
}
202201

libs/simdvec/native/src/vec/c/amd64/vec_2.cpp

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ EXPORT int32_t vec_dot7u_2(const int8_t* a, const int8_t* b, const int32_t dims)
123123
return res;
124124
}
125125

126-
template <int32_t(*mapper)(int32_t, const int32_t*)>
126+
template <int64_t(*mapper)(int32_t, const int32_t*)>
127127
static inline void dot7u_inner_bulk(
128128
const int8_t* a,
129129
const int8_t* b,
@@ -134,22 +134,21 @@ static inline void dot7u_inner_bulk(
134134
const f32_t score_correction, // TODO
135135
f32_t* results
136136
) {
137-
int32_t res = 0;
138137
if (dims > STRIDE_BYTES_LEN) {
139138
const int limit = dims & ~(STRIDE_BYTES_LEN - 1);
140139
for (int32_t c = 0; c < count; c++) {
141-
const int8_t* a0 = a + mapper(c, offsets) * pitch;
140+
const int8_t* a0 = a + (mapper(c, offsets) * pitch);
142141
int i = limit;
143-
res = dot7u_inner_avx512(a, b, i);
142+
int32_t res = dot7u_inner_avx512(a0, b, i);
144143
for (; i < dims; i++) {
145144
res += a0[i] * b[i];
146145
}
147146
results[c] = (f32_t)res;
148147
}
149148
} else {
150149
for (int32_t c = 0; c < count; c++) {
151-
const int8_t* a0 = a + mapper(c, offsets) * pitch;
152-
res = 0;
150+
const int8_t* a0 = a + (mapper(c, offsets) * pitch);
151+
int32_t res = 0;
153152
for (int32_t i = 0; i < dims; i++) {
154153
res += a0[i] * b[i];
155154
}
@@ -158,20 +157,20 @@ static inline void dot7u_inner_bulk(
158157
}
159158
}
160159

161-
static inline int identity(const int32_t i, const int32_t* offsets) {
160+
static inline int64_t identity(const int32_t i, const int32_t* offsets) {
162161
return i;
163162
}
164163

165-
static inline int index(const int32_t i, const int32_t* offsets) {
164+
static inline int64_t index(const int32_t i, const int32_t* offsets) {
166165
return offsets[i];
167166
}
168167

169-
EXPORT void vec_dot7u_bulk(const int8_t* a, const int8_t* b, const int32_t dims, const int32_t count, f32_t* results) {
168+
EXPORT void vec_dot7u_bulk_2(const int8_t* a, const int8_t* b, const int32_t dims, const int32_t count, f32_t* results) {
170169
dot7u_inner_bulk<identity>(a, b, dims, dims, NULL, count, 1.0f, results);
171170
}
172171

173172

174-
EXPORT void vec_dot7u_bulk_offsets(
173+
EXPORT void vec_dot7u_bulk_offsets_2(
175174
const int8_t* a,
176175
const int8_t* b,
177176
const int32_t dims,

libs/simdvec/native/src/vec/headers/vec.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,15 @@ EXPORT int32_t vec_dot7u(const int8_t* a, const int8_t* b, const int32_t dims);
4141

4242
EXPORT void vec_dot7u_bulk(const int8_t* a, const int8_t* b, const int32_t dims, const int32_t count, f32_t* results);
4343

44-
EXPORT void vec_dot7u_bulk_offsets(const int8_t* a, const int8_t* b, const int32_t dims, const int32_t* offsets, const int32_t count, f32_t* results);
44+
EXPORT void vec_dot7u_bulk_offsets(
45+
const int8_t* a,
46+
const int8_t* b,
47+
const int32_t dims,
48+
const int32_t pitch,
49+
const int32_t* offsets,
50+
const int32_t count,
51+
const f32_t score_correction,
52+
f32_t* results);
4553

4654
EXPORT int32_t vec_sqr7u(int8_t *a, int8_t *b, const int32_t length);
4755

0 commit comments

Comments
 (0)