Skip to content

Commit ae8737b

Browse files
GPU Cuvs snapshot update (elastic#136057)
This updates GPU Plugin to a newer CUVS version Co-authored-by: Brian Seeders <[email protected]>
1 parent 17b2ab6 commit ae8737b

File tree

5 files changed

+85
-6
lines changed

5 files changed

+85
-6
lines changed

gradle/verification-metadata.xml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1199,9 +1199,9 @@
11991199
<sha256 value="64fab42f17bf8e0efb193dd34da716ef7abb7515234036119df1776b808dc066" origin="Generated by Gradle"/>
12001200
</artifact>
12011201
</component>
1202-
<component group="com.nvidia.cuvs" name="cuvs-java" version="25.10.0">
1203-
<artifact name="cuvs-java-25.10.0.jar">
1204-
<sha256 value="2d4cca3b6b6c7c4d3c1a2f57c00cb55f7a45699536ad356699f32bcea7714539" origin="Generated by Gradle"/>
1202+
<component group="com.nvidia.cuvs" name="cuvs-java" version="25.10.0-815d86dd">
1203+
<artifact name="cuvs-java-25.10.0-815d86dd.jar">
1204+
<sha256 value="b15a5f63b7cc2349444ee5470dfe7a316ccd11b6fcc4be3dd4b11aaeb2ae65fe" origin="Generated by Gradle"/>
12051205
</artifact>
12061206
</component>
12071207
<component group="com.perforce" name="p4java" version="2015.2.1365273">

x-pack/plugin/gpu/build.gradle

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,7 @@ repositories {
2222
dependencies {
2323
compileOnly project(path: xpackModule('core'))
2424
compileOnly project(':server')
25-
implementation('com.nvidia.cuvs:cuvs-java:25.10.0') {
26-
changing = true // Ensure that we get updates even when the version number doesn't change. We can remove this once things stabilize
27-
}
25+
implementation('com.nvidia.cuvs:cuvs-java:25.10.0-815d86dd')
2826
testImplementation(testArtifact(project(xpackModule('core'))))
2927
testImplementation(testArtifact(project(':server')))
3028
yamlRestTestImplementation(project(xpackModule('gpu')))

x-pack/plugin/gpu/src/internalClusterTest/java/org/elasticsearch/plugin/gpu/CuVSProviderDelegate.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,17 @@ public CuVSMatrix.Builder<CuVSHostMatrix> newHostMatrixBuilder(long l, long l1,
3838
return delegate.newHostMatrixBuilder(l, l1, dataType);
3939
}
4040

41+
@Override
42+
public CuVSMatrix.Builder<CuVSHostMatrix> newHostMatrixBuilder(
43+
long size,
44+
long columns,
45+
int rowStride,
46+
int columnStride,
47+
CuVSMatrix.DataType dataType
48+
) {
49+
return delegate.newHostMatrixBuilder(size, columns, rowStride, columnStride, dataType);
50+
}
51+
4152
@Override
4253
public CuVSMatrix.Builder<CuVSDeviceMatrix> newDeviceMatrixBuilder(
4354
CuVSResources cuVSResources,
@@ -65,6 +76,11 @@ public MethodHandle newNativeMatrixBuilder() {
6576
return delegate.newNativeMatrixBuilder();
6677
}
6778

79+
@Override
80+
public MethodHandle newNativeMatrixBuilderWithStrides() {
81+
return delegate.newNativeMatrixBuilderWithStrides();
82+
}
83+
6884
@Override
6985
public CuVSMatrix newMatrixFromArray(float[][] floats) {
7086
return delegate.newMatrixFromArray(floats);

x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/DatasetUtils.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,15 @@ static DatasetUtils getInstance() {
2222
/** Returns a Dataset over the vectors of type {@code dataType} in the input. */
2323
CuVSMatrix fromInput(MemorySegmentAccessInput input, int numVectors, int dims, CuVSMatrix.DataType dataType) throws IOException;
2424

25+
CuVSMatrix fromInput(
26+
MemorySegmentAccessInput input,
27+
int numVectors,
28+
int dims,
29+
int rowStride,
30+
int columnStride,
31+
CuVSMatrix.DataType dataType
32+
) throws IOException;
33+
2534
/** Returns a Dataset over an input slice */
2635
CuVSMatrix fromSlice(MemorySegmentAccessInput input, long pos, long len, int numVectors, int dims, CuVSMatrix.DataType dataType)
2736
throws IOException;

x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/DatasetUtilsImpl.java

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ public class DatasetUtilsImpl implements DatasetUtils {
2121
private static final DatasetUtils INSTANCE = new DatasetUtilsImpl();
2222

2323
private static final MethodHandle createDataset$mh = CuVSProvider.provider().newNativeMatrixBuilder();
24+
private static final MethodHandle createDatasetWithStrides$mh = CuVSProvider.provider().newNativeMatrixBuilderWithStrides();
2425

2526
static DatasetUtils getInstance() {
2627
return INSTANCE;
@@ -40,6 +41,27 @@ static CuVSMatrix fromMemorySegment(MemorySegment memorySegment, int size, int d
4041
}
4142
}
4243

44+
static CuVSMatrix fromMemorySegment(
45+
MemorySegment memorySegment,
46+
int size,
47+
int dimensions,
48+
int rowStride,
49+
int columnStride,
50+
CuVSMatrix.DataType dataType
51+
) {
52+
try {
53+
return (CuVSMatrix) createDatasetWithStrides$mh.invokeExact(memorySegment, size, dimensions, rowStride, columnStride, dataType);
54+
} catch (Throwable e) {
55+
if (e instanceof Error err) {
56+
throw err;
57+
} else if (e instanceof RuntimeException re) {
58+
throw re;
59+
} else {
60+
throw new RuntimeException(e);
61+
}
62+
}
63+
}
64+
4365
private DatasetUtilsImpl() {}
4466

4567
@Override
@@ -50,6 +72,21 @@ public CuVSMatrix fromInput(MemorySegmentAccessInput input, int numVectors, int
5072
return createCuVSMatrix(input, 0L, input.length(), numVectors, dims, dataType);
5173
}
5274

75+
@Override
76+
public CuVSMatrix fromInput(
77+
MemorySegmentAccessInput input,
78+
int numVectors,
79+
int dims,
80+
int rowStride,
81+
int columnStride,
82+
CuVSMatrix.DataType dataType
83+
) throws IOException {
84+
if (numVectors < 0 || dims < 0) {
85+
throwIllegalArgumentException(numVectors, dims);
86+
}
87+
return createCuVSMatrix(input, 0L, input.length(), numVectors, dims, rowStride, columnStride, dataType);
88+
}
89+
5390
@Override
5491
public CuVSMatrix fromSlice(MemorySegmentAccessInput input, long pos, long len, int numVectors, int dims, CuVSMatrix.DataType dataType)
5592
throws IOException {
@@ -76,6 +113,25 @@ private static CuVSMatrix createCuVSMatrix(
76113
return fromMemorySegment(ms, numVectors, dims, dataType);
77114
}
78115

116+
private static CuVSMatrix createCuVSMatrix(
117+
MemorySegmentAccessInput input,
118+
long pos,
119+
long len,
120+
int numVectors,
121+
int dims,
122+
int rowStride,
123+
int columnStride,
124+
CuVSMatrix.DataType dataType
125+
) throws IOException {
126+
MemorySegment ms = input.segmentSliceOrNull(pos, len);
127+
assert ms != null;
128+
final int byteSize = dataType == CuVSMatrix.DataType.FLOAT ? Float.BYTES : Byte.BYTES;
129+
if (((long) numVectors * rowStride * byteSize) > ms.byteSize()) {
130+
throwIllegalArgumentException(ms, numVectors, dims);
131+
}
132+
return fromMemorySegment(ms, numVectors, dims, rowStride, columnStride, dataType);
133+
}
134+
79135
static void throwIllegalArgumentException(MemorySegment ms, int numVectors, int dims) {
80136
var s = "segment of size [" + ms.byteSize() + "] too small for expected " + numVectors + " float vectors of " + dims + " dims";
81137
throw new IllegalArgumentException(s);

0 commit comments

Comments
 (0)