Skip to content

Commit 329007b

Browse files
authored
Merge branch 'main' into exphisto-null-sum
2 parents d865d06 + 0b3b48a commit 329007b

File tree

47 files changed

+414
-164
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

47 files changed

+414
-164
lines changed

benchmarks/src/main/java/org/elasticsearch/benchmark/_nightly/esql/QueryPlanningBenchmark.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
import java.util.LinkedHashMap;
5353
import java.util.Locale;
5454
import java.util.Map;
55+
import java.util.Set;
5556
import java.util.concurrent.TimeUnit;
5657

5758
import static java.util.Collections.emptyMap;
@@ -101,7 +102,7 @@ public void setup() {
101102
mapping.put("field" + i, new EsField("field-" + i, TEXT, emptyMap(), true, EsField.TimeSeriesFieldType.NONE));
102103
}
103104

104-
var esIndex = new EsIndex("test", mapping, Map.of("test", IndexMode.STANDARD));
105+
var esIndex = new EsIndex("test", mapping, Map.of("test", IndexMode.STANDARD), Set.of());
105106

106107
var functionRegistry = new EsqlFunctionRegistry();
107108

libs/native/libraries/build.gradle

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ configurations {
1919
}
2020

2121
var zstdVersion = "1.5.5"
22-
var vecVersion = "1.0.13"
22+
var vecVersion = "1.0.15"
2323

2424
repositories {
2525
exclusiveContent {

libs/native/src/main/java/org/elasticsearch/nativeaccess/VectorSimilarityFunctions.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,23 @@ public interface VectorSimilarityFunctions {
3030
*/
3131
MethodHandle dotProductHandle7u();
3232

33+
/**
34+
* Produces a method handle which computes the dot product of several byte (unsigned
35+
* int7) vectors. This bulk operation can be used to compute the dot product between a
36+
* single query vector and a number of other vectors.
37+
*
38+
* <p> Unsigned int7 byte vectors have values in the range of 0 to 127 (inclusive).
39+
*
40+
* <p> The type of the method handle will have {@code void} as return type. The type of
41+
* its first and second arguments will be {@code MemorySegment}, the former contains the
42+
* vector data bytes for several vectors, while the latter contains just a single vector. The
43+
* type of the third argument is an int, representing the dimensions of each vector. The
44+
* type of the fourth argument is an int, representing the number of vectors in the
45+
* first argument. The type of the final argument is a MemorySegment, into which the
46+
* computed dot product float values will be stored.
47+
*/
48+
MethodHandle dotProductHandle7uBulk();
49+
3350
/**
3451
* Produces a method handle returning the square distance of byte (unsigned int7) vectors.
3552
*

libs/native/src/main/java/org/elasticsearch/nativeaccess/jdk/JdkVectorLibrary.java

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ public final class JdkVectorLibrary implements VectorLibrary {
3232
static final Logger logger = LogManager.getLogger(JdkVectorLibrary.class);
3333

3434
static final MethodHandle dot7u$mh;
35+
static final MethodHandle dot7uBulk$mh;
3536
static final MethodHandle sqr7u$mh;
3637
static final MethodHandle cosf32$mh;
3738
static final MethodHandle dotf32$mh;
@@ -53,6 +54,11 @@ public final class JdkVectorLibrary implements VectorLibrary {
5354
FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT),
5455
LinkerHelperUtil.critical()
5556
);
57+
dot7uBulk$mh = downcallHandle(
58+
"dot7u_bulk_2",
59+
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, JAVA_INT, JAVA_INT, ADDRESS),
60+
LinkerHelperUtil.critical()
61+
);
5662
sqr7u$mh = downcallHandle(
5763
"sqr7u_2",
5864
FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT),
@@ -79,6 +85,11 @@ public final class JdkVectorLibrary implements VectorLibrary {
7985
FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT),
8086
LinkerHelperUtil.critical()
8187
);
88+
dot7uBulk$mh = downcallHandle(
89+
"dot7u_bulk",
90+
FunctionDescriptor.ofVoid(ADDRESS, ADDRESS, JAVA_INT, JAVA_INT, ADDRESS),
91+
LinkerHelperUtil.critical()
92+
);
8293
sqr7u$mh = downcallHandle(
8394
"sqr7u",
8495
FunctionDescriptor.of(JAVA_INT, ADDRESS, ADDRESS, JAVA_INT),
@@ -108,6 +119,7 @@ public final class JdkVectorLibrary implements VectorLibrary {
108119
enable them in your OS/Hypervisor/VM/container""");
109120
}
110121
dot7u$mh = null;
122+
dot7uBulk$mh = null;
111123
sqr7u$mh = null;
112124
cosf32$mh = null;
113125
dotf32$mh = null;
@@ -142,6 +154,13 @@ static int dotProduct7u(MemorySegment a, MemorySegment b, int length) {
142154
return dot7u(a, b, length);
143155
}
144156

157+
static void dotProduct7uBulk(MemorySegment a, MemorySegment b, int length, int count, MemorySegment result) {
158+
Objects.checkFromIndexSize(0, length * count, (int) a.byteSize());
159+
Objects.checkFromIndexSize(0, length, (int) b.byteSize());
160+
Objects.checkFromIndexSize(0, count * Float.BYTES, (int) result.byteSize());
161+
dot7uBulk(a, b, length, count, result);
162+
}
163+
145164
/**
146165
* Computes the square distance of given unsigned int7 byte vectors.
147166
*
@@ -210,6 +229,14 @@ private static int dot7u(MemorySegment a, MemorySegment b, int length) {
210229
}
211230
}
212231

232+
private static void dot7uBulk(MemorySegment a, MemorySegment b, int length, int count, MemorySegment result) {
233+
try {
234+
JdkVectorLibrary.dot7uBulk$mh.invokeExact(a, b, length, count, result);
235+
} catch (Throwable t) {
236+
throw new AssertionError(t);
237+
}
238+
}
239+
213240
private static int sqr7u(MemorySegment a, MemorySegment b, int length) {
214241
try {
215242
return (int) JdkVectorLibrary.sqr7u$mh.invokeExact(a, b, length);
@@ -243,6 +270,7 @@ private static float sqrf32(MemorySegment a, MemorySegment b, int length) {
243270
}
244271

245272
static final MethodHandle DOT_HANDLE_7U;
273+
static final MethodHandle DOT_HANDLE_7U_BULK;
246274
static final MethodHandle SQR_HANDLE_7U;
247275
static final MethodHandle COS_HANDLE_FLOAT32;
248276
static final MethodHandle DOT_HANDLE_FLOAT32;
@@ -255,6 +283,9 @@ private static float sqrf32(MemorySegment a, MemorySegment b, int length) {
255283
DOT_HANDLE_7U = lookup.findStatic(JdkVectorSimilarityFunctions.class, "dotProduct7u", mt);
256284
SQR_HANDLE_7U = lookup.findStatic(JdkVectorSimilarityFunctions.class, "squareDistance7u", mt);
257285

286+
mt = MethodType.methodType(void.class, MemorySegment.class, MemorySegment.class, int.class, int.class, MemorySegment.class);
287+
DOT_HANDLE_7U_BULK = lookup.findStatic(JdkVectorSimilarityFunctions.class, "dotProduct7uBulk", mt);
288+
258289
mt = MethodType.methodType(float.class, MemorySegment.class, MemorySegment.class, int.class);
259290
COS_HANDLE_FLOAT32 = lookup.findStatic(JdkVectorSimilarityFunctions.class, "cosineF32", mt);
260291
DOT_HANDLE_FLOAT32 = lookup.findStatic(JdkVectorSimilarityFunctions.class, "dotProductF32", mt);
@@ -269,6 +300,11 @@ public MethodHandle dotProductHandle7u() {
269300
return DOT_HANDLE_7U;
270301
}
271302

303+
@Override
304+
public MethodHandle dotProductHandle7uBulk() {
305+
return DOT_HANDLE_7U_BULK;
306+
}
307+
272308
@Override
273309
public MethodHandle squareDistanceHandle7u() {
274310
return SQR_HANDLE_7U;

libs/native/src/test/java/org/elasticsearch/nativeaccess/jdk/JDKVectorLibraryInt7uTests.java

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import java.lang.foreign.MemorySegment;
1919

20+
import static java.lang.foreign.ValueLayout.JAVA_FLOAT_UNALIGNED;
2021
import static org.hamcrest.Matchers.containsString;
2122

2223
public class JDKVectorLibraryInt7uTests extends VectorSimilarityFunctionsTests {
@@ -71,6 +72,11 @@ public void testInt7BinaryVectors() {
7172
assertEquals(expected, dotProduct7u(heapSeg1, heapSeg2, dims));
7273
assertEquals(expected, dotProduct7u(nativeSeg1, heapSeg2, dims));
7374
assertEquals(expected, dotProduct7u(heapSeg1, nativeSeg2, dims));
75+
76+
// trivial bulk with a single vector
77+
float[] bulkScore = new float[1];
78+
dotProduct7uBulk(nativeSeg1, nativeSeg2, dims, 1, MemorySegment.ofArray(bulkScore));
79+
assertEquals(expected, bulkScore[0], 0f);
7480
}
7581

7682
// square distance
@@ -86,6 +92,32 @@ public void testInt7BinaryVectors() {
8692
}
8793
}
8894

95+
public void testInt7uBulk() {
96+
assumeTrue(notSupportedMsg(), supported());
97+
final int dims = size;
98+
final int numVecs = randomIntBetween(2, 101);
99+
var values = new byte[numVecs][dims];
100+
var segment = arena.allocate((long) dims * numVecs);
101+
for (int i = 0; i < numVecs; i++) {
102+
randomBytesBetween(values[i], MIN_INT7_VALUE, MAX_INT7_VALUE);
103+
MemorySegment.copy(MemorySegment.ofArray(values[i]), 0L, segment, (long) i * dims, dims);
104+
}
105+
int queryOrd = randomInt(numVecs - 1);
106+
float[] expectedScores = new float[numVecs];
107+
dotProductBulkScalar(values[queryOrd], values, expectedScores);
108+
109+
var nativeQuerySeg = segment.asSlice((long) queryOrd * dims, dims);
110+
var bulkScoresSeg = arena.allocate((long) numVecs * Float.BYTES);
111+
dotProduct7uBulk(segment, nativeQuerySeg, dims, numVecs, bulkScoresSeg);
112+
assertScoresEquals(expectedScores, bulkScoresSeg);
113+
114+
if (supportsHeapSegments()) {
115+
float[] bulkScores = new float[numVecs];
116+
dotProduct7uBulk(segment, nativeQuerySeg, dims, numVecs, MemorySegment.ofArray(bulkScores));
117+
assertArrayEquals(expectedScores, bulkScores, 0f);
118+
}
119+
}
120+
89121
public void testIllegalDims() {
90122
assumeTrue(notSupportedMsg(), supported());
91123
var segment = arena.allocate((long) size * 3);
@@ -109,6 +141,26 @@ public void testIllegalDims() {
109141
assertThat(e6.getMessage(), containsString("out of bounds for length"));
110142
}
111143

144+
public void testBulkIllegalDims() {
145+
assumeTrue(notSupportedMsg(), supported());
146+
var segA = arena.allocate((long) size * 3);
147+
var segB = arena.allocate((long) size * 3);
148+
var segS = arena.allocate((long) size * Float.BYTES);
149+
150+
var e1 = expectThrows(IOOBE, () -> dotProduct7uBulk(segA, segB, size, 4, segS));
151+
assertThat(e1.getMessage(), containsString("out of bounds for length"));
152+
153+
var e2 = expectThrows(IOOBE, () -> dotProduct7uBulk(segA, segB, size, -1, segS));
154+
assertThat(e2.getMessage(), containsString("out of bounds for length"));
155+
156+
var e3 = expectThrows(IOOBE, () -> dotProduct7uBulk(segA, segB, -1, 3, segS));
157+
assertThat(e3.getMessage(), containsString("out of bounds for length"));
158+
159+
var tooSmall = arena.allocate((long) 3 * Float.BYTES - 1);
160+
var e4 = expectThrows(IOOBE, () -> dotProduct7uBulk(segA, segB, size, 3, tooSmall));
161+
assertThat(e4.getMessage(), containsString("out of bounds for length"));
162+
}
163+
112164
int dotProduct7u(MemorySegment a, MemorySegment b, int length) {
113165
try {
114166
return (int) getVectorDistance().dotProductHandle7u().invokeExact(a, b, length);
@@ -137,6 +189,20 @@ int squareDistance7u(MemorySegment a, MemorySegment b, int length) {
137189
}
138190
}
139191

192+
void dotProduct7uBulk(MemorySegment a, MemorySegment b, int dims, int count, MemorySegment result) {
193+
try {
194+
getVectorDistance().dotProductHandle7uBulk().invokeExact(a, b, dims, count, result);
195+
} catch (Throwable e) {
196+
if (e instanceof Error err) {
197+
throw err;
198+
} else if (e instanceof RuntimeException re) {
199+
throw re;
200+
} else {
201+
throw new RuntimeException(e);
202+
}
203+
}
204+
}
205+
140206
/** Computes the dot product of the given vectors a and b. */
141207
static int dotProductScalar(byte[] a, byte[] b) {
142208
int res = 0;
@@ -156,4 +222,18 @@ static int squareDistanceScalar(byte[] a, byte[] b) {
156222
}
157223
return squareSum;
158224
}
225+
226+
static void dotProductBulkScalar(byte[] query, byte[][] data, float[] scores) {
227+
for (int i = 0; i < data.length; i++) {
228+
scores[i] = dotProductScalar(query, data[i]);
229+
}
230+
}
231+
232+
static void assertScoresEquals(float[] expectedScores, MemorySegment expectedScoresSeg) {
233+
assert expectedScores.length == (expectedScoresSeg.byteSize() / Float.BYTES);
234+
for (int i = 0; i < expectedScores.length; i++) {
235+
assertEquals(expectedScores[i], expectedScoresSeg.get(JAVA_FLOAT_UNALIGNED, i * Float.BYTES), 0f);
236+
}
237+
}
238+
159239
}

libs/simdvec/native/publish_vec_binaries.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ if [ -z "$ARTIFACTORY_API_KEY" ]; then
2020
exit 1;
2121
fi
2222

23-
VERSION="1.0.14"
23+
VERSION="1.0.15"
2424
ARTIFACTORY_REPOSITORY="${ARTIFACTORY_REPOSITORY:-https://artifactory.elastic.dev/artifactory/elasticsearch-native/}"
2525
TEMP=$(mktemp -d)
2626

libs/simdvec/native/src/vec/c/aarch64/vec.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,10 +95,10 @@ EXPORT int32_t dot7u(int8_t* a, int8_t* b, size_t dims) {
9595
return res;
9696
}
9797

98-
EXPORT void dot7u_bulk(int8_t* a, int8_t* b, size_t dims, size_t count, float_t* results) {
98+
EXPORT void dot7u_bulk(int8_t* a, const int8_t* b, const int32_t dims, const int32_t count, float_t* results) {
9999
int32_t res = 0;
100100
if (dims > DOT7U_STRIDE_BYTES_LEN) {
101-
int limit = dims & ~(DOT7U_STRIDE_BYTES_LEN - 1);
101+
const int limit = dims & ~(DOT7U_STRIDE_BYTES_LEN - 1);
102102
for (size_t c = 0; c < count; c++) {
103103
int i = limit;
104104
res = dot7u_inner(a, b, i);

libs/simdvec/native/src/vec/c/amd64/vec.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,10 +153,10 @@ EXPORT int32_t dot7u(int8_t* a, int8_t* b, size_t dims) {
153153
return res;
154154
}
155155

156-
EXPORT void dot7u_bulk(int8_t* a, int8_t* b, size_t dims, size_t count, float_t* results) {
156+
EXPORT void dot7u_bulk(int8_t* a, const int8_t* b, const int32_t dims, const int32_t count, float_t* results) {
157157
int32_t res = 0;
158158
if (dims > STRIDE_BYTES_LEN) {
159-
int limit = dims & ~(STRIDE_BYTES_LEN - 1);
159+
const int limit = dims & ~(STRIDE_BYTES_LEN - 1);
160160
for (size_t c = 0; c < count; c++) {
161161
int i = limit;
162162
res = dot7u_inner(a, b, i);

libs/simdvec/native/src/vec/c/amd64/vec_2.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ inline __m512i fma8(__m512i acc, const int8_t* p1, const int8_t* p2) {
4747
return _mm512_add_epi32(_mm512_madd_epi16(ones, dot), acc);
4848
}
4949

50-
static inline int32_t dot7u_inner_avx512(int8_t* a, int8_t* b, size_t dims) {
50+
static inline int32_t dot7u_inner_avx512(int8_t* a, const int8_t* b, size_t dims) {
5151
constexpr int stride8 = 8 * STRIDE_BYTES_LEN;
5252
constexpr int stride4 = 4 * STRIDE_BYTES_LEN;
5353
const int8_t* p1 = a;
@@ -115,10 +115,10 @@ EXPORT int32_t dot7u_2(int8_t* a, int8_t* b, size_t dims) {
115115
}
116116

117117
extern "C"
118-
EXPORT void dot7u_bulk_2(int8_t* a, int8_t* b, size_t dims, size_t count, float_t* results) {
118+
EXPORT void dot7u_bulk_2(int8_t* a, const int8_t* b, const int32_t dims, const int32_t count, float_t* results) {
119119
int32_t res = 0;
120120
if (dims > STRIDE_BYTES_LEN) {
121-
int limit = dims & ~(STRIDE_BYTES_LEN - 1);
121+
const int limit = dims & ~(STRIDE_BYTES_LEN - 1);
122122
for (size_t c = 0; c < count; c++) {
123123
int i = limit;
124124
res = dot7u_inner_avx512(a, b, i);

libs/simdvec/native/src/vec/headers/vec.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ EXPORT int vec_caps();
1919

2020
EXPORT int32_t dot7u(int8_t* a, int8_t* b, size_t dims);
2121

22-
EXPORT void dot7u_bulk(int8_t* a, int8_t* b, size_t dims, size_t count, float_t* results);
22+
EXPORT void dot7u_bulk(int8_t* a, const int8_t* b, const int32_t dims, const int32_t count, float_t* results);
2323

2424
EXPORT int32_t sqr7u(int8_t *a, int8_t *b, size_t length);
2525

0 commit comments

Comments
 (0)