Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions gradle/verification-metadata.xml
Original file line number Diff line number Diff line change
Expand Up @@ -1199,9 +1199,9 @@
<sha256 value="64fab42f17bf8e0efb193dd34da716ef7abb7515234036119df1776b808dc066" origin="Generated by Gradle"/>
</artifact>
</component>
<component group="com.nvidia.cuvs" name="cuvs-java" version="25.10.0">
<artifact name="cuvs-java-25.10.0.jar">
<sha256 value="2d4cca3b6b6c7c4d3c1a2f57c00cb55f7a45699536ad356699f32bcea7714539" origin="Generated by Gradle"/>
<component group="com.nvidia.cuvs" name="cuvs-java" version="25.10.0-815d86dd">
<artifact name="cuvs-java-25.10.0-815d86dd.jar">
<sha256 value="b15a5f63b7cc2349444ee5470dfe7a316ccd11b6fcc4be3dd4b11aaeb2ae65fe" origin="Generated by Gradle"/>
</artifact>
</component>
<component group="com.perforce" name="p4java" version="2015.2.1365273">
Expand Down
2 changes: 1 addition & 1 deletion x-pack/plugin/gpu/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ repositories {
dependencies {
compileOnly project(path: xpackModule('core'))
compileOnly project(':server')
implementation('com.nvidia.cuvs:cuvs-java:25.10.0') {
implementation('com.nvidia.cuvs:cuvs-java:25.10.0-815d86dd') {
changing = true // Ensure that we get updates even when the version number doesn't change. We can remove this once things stabilize
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We probably don't need this anymore as we have the -hashsuffix now, right @brianseeders ?

}
testImplementation(testArtifact(project(xpackModule('core'))))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,17 @@ public CuVSMatrix.Builder<CuVSHostMatrix> newHostMatrixBuilder(long l, long l1,
return delegate.newHostMatrixBuilder(l, l1, dataType);
}

@Override
public CuVSMatrix.Builder<CuVSHostMatrix> newHostMatrixBuilder(
long size,
long columns,
int rowStride,
int columnStride,
CuVSMatrix.DataType dataType
) {
return delegate.newHostMatrixBuilder(size, columns, rowStride, columnStride, dataType);
}

@Override
public CuVSMatrix.Builder<CuVSDeviceMatrix> newDeviceMatrixBuilder(
CuVSResources cuVSResources,
Expand Down Expand Up @@ -65,6 +76,11 @@ public MethodHandle newNativeMatrixBuilder() {
return delegate.newNativeMatrixBuilder();
}

@Override
public MethodHandle newNativeMatrixBuilderWithStrides() {
return delegate.newNativeMatrixBuilderWithStrides();
}

@Override
public CuVSMatrix newMatrixFromArray(float[][] floats) {
return delegate.newMatrixFromArray(floats);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,15 @@ static DatasetUtils getInstance() {
/** Returns a Dataset over the vectors of type {@code dataType} in the input. */
CuVSMatrix fromInput(MemorySegmentAccessInput input, int numVectors, int dims, CuVSMatrix.DataType dataType) throws IOException;

CuVSMatrix fromInput(
MemorySegmentAccessInput input,
int numVectors,
int dims,
int rowStride,
int columnStride,
CuVSMatrix.DataType dataType
) throws IOException;

/** Returns a Dataset over an input slice */
CuVSMatrix fromSlice(MemorySegmentAccessInput input, long pos, long len, int numVectors, int dims, CuVSMatrix.DataType dataType)
throws IOException;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ public class DatasetUtilsImpl implements DatasetUtils {
private static final DatasetUtils INSTANCE = new DatasetUtilsImpl();

private static final MethodHandle createDataset$mh = CuVSProvider.provider().newNativeMatrixBuilder();
private static final MethodHandle createDatasetWithStrides$mh = CuVSProvider.provider().newNativeMatrixBuilderWithStrides();

static DatasetUtils getInstance() {
return INSTANCE;
Expand All @@ -40,6 +41,27 @@ static CuVSMatrix fromMemorySegment(MemorySegment memorySegment, int size, int d
}
}

static CuVSMatrix fromMemorySegment(
MemorySegment memorySegment,
int size,
int dimensions,
int rowStride,
int columnStride,
CuVSMatrix.DataType dataType
) {
try {
return (CuVSMatrix) createDatasetWithStrides$mh.invokeExact(memorySegment, size, dimensions, rowStride, columnStride, dataType);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As we don't need columnStride and it's not really supported (yet) in cuvs-java, I think you might want to pass -1 here like you did in 14088bd. So you can remove all the int columnStride params from all these functions and just have rowStride.
But up to you (I'm fine either way).

} catch (Throwable e) {
if (e instanceof Error err) {
throw err;
} else if (e instanceof RuntimeException re) {
throw re;
} else {
throw new RuntimeException(e);
}
}
}

private DatasetUtilsImpl() {}

@Override
Expand All @@ -50,6 +72,21 @@ public CuVSMatrix fromInput(MemorySegmentAccessInput input, int numVectors, int
return createCuVSMatrix(input, 0L, input.length(), numVectors, dims, dataType);
}

@Override
public CuVSMatrix fromInput(
MemorySegmentAccessInput input,
int numVectors,
int dims,
int rowStride,
int columnStride,
CuVSMatrix.DataType dataType
) throws IOException {
if (numVectors < 0 || dims < 0) {
throwIllegalArgumentException(numVectors, dims);
}
return createCuVSMatrix(input, 0L, input.length(), numVectors, dims, rowStride, columnStride, dataType);
}

@Override
public CuVSMatrix fromSlice(MemorySegmentAccessInput input, long pos, long len, int numVectors, int dims, CuVSMatrix.DataType dataType)
throws IOException {
Expand All @@ -76,6 +113,25 @@ private static CuVSMatrix createCuVSMatrix(
return fromMemorySegment(ms, numVectors, dims, dataType);
}

private static CuVSMatrix createCuVSMatrix(
MemorySegmentAccessInput input,
long pos,
long len,
int numVectors,
int dims,
int rowStride,
int columnStride,
CuVSMatrix.DataType dataType
) throws IOException {
MemorySegment ms = input.segmentSliceOrNull(pos, len);
assert ms != null;
final int byteSize = dataType == CuVSMatrix.DataType.FLOAT ? Float.BYTES : Byte.BYTES;
if (((long) numVectors * rowStride * byteSize) > ms.byteSize()) {
throwIllegalArgumentException(ms, numVectors, dims);
}
return fromMemorySegment(ms, numVectors, dims, rowStride, columnStride, dataType);
}

static void throwIllegalArgumentException(MemorySegment ms, int numVectors, int dims) {
var s = "segment of size [" + ms.byteSize() + "] too small for expected " + numVectors + " float vectors of " + dims + " dims";
throw new IllegalArgumentException(s);
Expand Down