diff --git a/gradle/verification-metadata.xml b/gradle/verification-metadata.xml
index 5b5c9aea68ac3..334b488b16c11 100644
--- a/gradle/verification-metadata.xml
+++ b/gradle/verification-metadata.xml
@@ -1199,9 +1199,9 @@
-
-
-
+
+
+
diff --git a/x-pack/plugin/gpu/build.gradle b/x-pack/plugin/gpu/build.gradle
index 3b9330371fc47..cb52963f05546 100644
--- a/x-pack/plugin/gpu/build.gradle
+++ b/x-pack/plugin/gpu/build.gradle
@@ -22,9 +22,7 @@ repositories {
dependencies {
compileOnly project(path: xpackModule('core'))
compileOnly project(':server')
- implementation('com.nvidia.cuvs:cuvs-java:25.10.0') {
- changing = true // Ensure that we get updates even when the version number doesn't change. We can remove this once things stabilize
- }
+ implementation('com.nvidia.cuvs:cuvs-java:25.10.0-815d86dd')
testImplementation(testArtifact(project(xpackModule('core'))))
testImplementation(testArtifact(project(':server')))
yamlRestTestImplementation(project(xpackModule('gpu')))
diff --git a/x-pack/plugin/gpu/src/internalClusterTest/java/org/elasticsearch/plugin/gpu/CuVSProviderDelegate.java b/x-pack/plugin/gpu/src/internalClusterTest/java/org/elasticsearch/plugin/gpu/CuVSProviderDelegate.java
index d0f8e85ef6070..c4f17dca68adf 100644
--- a/x-pack/plugin/gpu/src/internalClusterTest/java/org/elasticsearch/plugin/gpu/CuVSProviderDelegate.java
+++ b/x-pack/plugin/gpu/src/internalClusterTest/java/org/elasticsearch/plugin/gpu/CuVSProviderDelegate.java
@@ -38,6 +38,17 @@ public CuVSMatrix.Builder newHostMatrixBuilder(long l, long l1,
return delegate.newHostMatrixBuilder(l, l1, dataType);
}
+ @Override
+ public CuVSMatrix.Builder newHostMatrixBuilder(
+ long size,
+ long columns,
+ int rowStride,
+ int columnStride,
+ CuVSMatrix.DataType dataType
+ ) {
+ return delegate.newHostMatrixBuilder(size, columns, rowStride, columnStride, dataType);
+ }
+
@Override
public CuVSMatrix.Builder newDeviceMatrixBuilder(
CuVSResources cuVSResources,
@@ -65,6 +76,11 @@ public MethodHandle newNativeMatrixBuilder() {
return delegate.newNativeMatrixBuilder();
}
+ @Override
+ public MethodHandle newNativeMatrixBuilderWithStrides() {
+ return delegate.newNativeMatrixBuilderWithStrides();
+ }
+
@Override
public CuVSMatrix newMatrixFromArray(float[][] floats) {
return delegate.newMatrixFromArray(floats);
diff --git a/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/DatasetUtils.java b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/DatasetUtils.java
index 3a9fcb2c68cd8..dc86f189cf585 100644
--- a/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/DatasetUtils.java
+++ b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/DatasetUtils.java
@@ -22,6 +22,15 @@ static DatasetUtils getInstance() {
/** Returns a Dataset over the vectors of type {@code dataType} in the input. */
CuVSMatrix fromInput(MemorySegmentAccessInput input, int numVectors, int dims, CuVSMatrix.DataType dataType) throws IOException;
+ CuVSMatrix fromInput(
+ MemorySegmentAccessInput input,
+ int numVectors,
+ int dims,
+ int rowStride,
+ int columnStride,
+ CuVSMatrix.DataType dataType
+ ) throws IOException;
+
/** Returns a Dataset over an input slice */
CuVSMatrix fromSlice(MemorySegmentAccessInput input, long pos, long len, int numVectors, int dims, CuVSMatrix.DataType dataType)
throws IOException;
diff --git a/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/DatasetUtilsImpl.java b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/DatasetUtilsImpl.java
index 0dfb0960cebbe..9dd1d667eac28 100644
--- a/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/DatasetUtilsImpl.java
+++ b/x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/DatasetUtilsImpl.java
@@ -21,6 +21,7 @@ public class DatasetUtilsImpl implements DatasetUtils {
private static final DatasetUtils INSTANCE = new DatasetUtilsImpl();
private static final MethodHandle createDataset$mh = CuVSProvider.provider().newNativeMatrixBuilder();
+ private static final MethodHandle createDatasetWithStrides$mh = CuVSProvider.provider().newNativeMatrixBuilderWithStrides();
static DatasetUtils getInstance() {
return INSTANCE;
@@ -40,6 +41,27 @@ static CuVSMatrix fromMemorySegment(MemorySegment memorySegment, int size, int d
}
}
+ static CuVSMatrix fromMemorySegment(
+ MemorySegment memorySegment,
+ int size,
+ int dimensions,
+ int rowStride,
+ int columnStride,
+ CuVSMatrix.DataType dataType
+ ) {
+ try {
+ return (CuVSMatrix) createDatasetWithStrides$mh.invokeExact(memorySegment, size, dimensions, rowStride, columnStride, dataType);
+ } catch (Throwable e) {
+ if (e instanceof Error err) {
+ throw err;
+ } else if (e instanceof RuntimeException re) {
+ throw re;
+ } else {
+ throw new RuntimeException(e);
+ }
+ }
+ }
+
private DatasetUtilsImpl() {}
@Override
@@ -50,6 +72,21 @@ public CuVSMatrix fromInput(MemorySegmentAccessInput input, int numVectors, int
return createCuVSMatrix(input, 0L, input.length(), numVectors, dims, dataType);
}
+ @Override
+ public CuVSMatrix fromInput(
+ MemorySegmentAccessInput input,
+ int numVectors,
+ int dims,
+ int rowStride,
+ int columnStride,
+ CuVSMatrix.DataType dataType
+ ) throws IOException {
+ if (numVectors < 0 || dims < 0) {
+ throwIllegalArgumentException(numVectors, dims);
+ }
+ return createCuVSMatrix(input, 0L, input.length(), numVectors, dims, rowStride, columnStride, dataType);
+ }
+
@Override
public CuVSMatrix fromSlice(MemorySegmentAccessInput input, long pos, long len, int numVectors, int dims, CuVSMatrix.DataType dataType)
throws IOException {
@@ -76,6 +113,25 @@ private static CuVSMatrix createCuVSMatrix(
return fromMemorySegment(ms, numVectors, dims, dataType);
}
+ private static CuVSMatrix createCuVSMatrix(
+ MemorySegmentAccessInput input,
+ long pos,
+ long len,
+ int numVectors,
+ int dims,
+ int rowStride,
+ int columnStride,
+ CuVSMatrix.DataType dataType
+ ) throws IOException {
+ MemorySegment ms = input.segmentSliceOrNull(pos, len);
+ assert ms != null;
+ final int byteSize = dataType == CuVSMatrix.DataType.FLOAT ? Float.BYTES : Byte.BYTES;
+ if (((long) numVectors * rowStride * byteSize) > ms.byteSize()) {
+ throwIllegalArgumentException(ms, numVectors, dims);
+ }
+ return fromMemorySegment(ms, numVectors, dims, rowStride, columnStride, dataType);
+ }
+
static void throwIllegalArgumentException(MemorySegment ms, int numVectors, int dims) {
var s = "segment of size [" + ms.byteSize() + "] too small for expected " + numVectors + " float vectors of " + dims + " dims";
throw new IllegalArgumentException(s);