Skip to content

Commit ee109f4

Browse files
Add method with strides
1 parent 2b65b21 commit ee109f4

File tree

3 files changed

+80
-0
lines changed

3 files changed

+80
-0
lines changed

x-pack/plugin/gpu/src/internalClusterTest/java/org/elasticsearch/plugin/gpu/CuVSProviderDelegate.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,17 @@ public CuVSMatrix.Builder<CuVSHostMatrix> newHostMatrixBuilder(long l, long l1,
3838
return delegate.newHostMatrixBuilder(l, l1, dataType);
3939
}
4040

41+
@Override
42+
public CuVSMatrix.Builder<CuVSHostMatrix> newHostMatrixBuilder(
43+
long size,
44+
long columns,
45+
int rowStride,
46+
int columnStride,
47+
CuVSMatrix.DataType dataType
48+
) {
49+
return delegate.newHostMatrixBuilder(size, columns, rowStride, columnStride, dataType);
50+
}
51+
4152
@Override
4253
public CuVSMatrix.Builder<CuVSDeviceMatrix> newDeviceMatrixBuilder(
4354
CuVSResources cuVSResources,
@@ -65,6 +76,11 @@ public MethodHandle newNativeMatrixBuilder() {
6576
return delegate.newNativeMatrixBuilder();
6677
}
6778

79+
@Override
80+
public MethodHandle newNativeMatrixBuilderWithStrides() {
81+
return delegate.newNativeMatrixBuilderWithStrides();
82+
}
83+
6884
@Override
6985
public CuVSMatrix newMatrixFromArray(float[][] floats) {
7086
return delegate.newMatrixFromArray(floats);

x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/DatasetUtils.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,15 @@ static DatasetUtils getInstance() {
2222
/** Returns a Dataset over the vectors of type {@code dataType} in the input. */
2323
CuVSMatrix fromInput(MemorySegmentAccessInput input, int numVectors, int dims, CuVSMatrix.DataType dataType) throws IOException;
2424

25+
CuVSMatrix fromInput(
26+
MemorySegmentAccessInput input,
27+
int numVectors,
28+
int dims,
29+
int rowStride,
30+
int columnStride,
31+
CuVSMatrix.DataType dataType
32+
) throws IOException;
33+
2534
/** Returns a Dataset over an input slice */
2635
CuVSMatrix fromSlice(MemorySegmentAccessInput input, long pos, long len, int numVectors, int dims, CuVSMatrix.DataType dataType)
2736
throws IOException;

x-pack/plugin/gpu/src/main/java/org/elasticsearch/xpack/gpu/codec/DatasetUtilsImpl.java

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ public class DatasetUtilsImpl implements DatasetUtils {
2121
private static final DatasetUtils INSTANCE = new DatasetUtilsImpl();
2222

2323
private static final MethodHandle createDataset$mh = CuVSProvider.provider().newNativeMatrixBuilder();
24+
private static final MethodHandle createDatasetWithStrides$mh = CuVSProvider.provider().newNativeMatrixBuilderWithStrides();
2425

2526
static DatasetUtils getInstance() {
2627
return INSTANCE;
@@ -40,6 +41,27 @@ static CuVSMatrix fromMemorySegment(MemorySegment memorySegment, int size, int d
4041
}
4142
}
4243

44+
static CuVSMatrix fromMemorySegment(
45+
MemorySegment memorySegment,
46+
int size,
47+
int dimensions,
48+
int rowStride,
49+
int columnStride,
50+
CuVSMatrix.DataType dataType
51+
) {
52+
try {
53+
return (CuVSMatrix) createDatasetWithStrides$mh.invokeExact(memorySegment, size, dimensions, rowStride, columnStride, dataType);
54+
} catch (Throwable e) {
55+
if (e instanceof Error err) {
56+
throw err;
57+
} else if (e instanceof RuntimeException re) {
58+
throw re;
59+
} else {
60+
throw new RuntimeException(e);
61+
}
62+
}
63+
}
64+
4365
private DatasetUtilsImpl() {}
4466

4567
@Override
@@ -50,6 +72,21 @@ public CuVSMatrix fromInput(MemorySegmentAccessInput input, int numVectors, int
5072
return createCuVSMatrix(input, 0L, input.length(), numVectors, dims, dataType);
5173
}
5274

75+
@Override
76+
public CuVSMatrix fromInput(
77+
MemorySegmentAccessInput input,
78+
int numVectors,
79+
int dims,
80+
int rowStride,
81+
int columnStride,
82+
CuVSMatrix.DataType dataType
83+
) throws IOException {
84+
if (numVectors < 0 || dims < 0) {
85+
throwIllegalArgumentException(numVectors, dims);
86+
}
87+
return createCuVSMatrix(input, 0L, input.length(), numVectors, dims, rowStride, columnStride, dataType);
88+
}
89+
5390
@Override
5491
public CuVSMatrix fromSlice(MemorySegmentAccessInput input, long pos, long len, int numVectors, int dims, CuVSMatrix.DataType dataType)
5592
throws IOException {
@@ -76,6 +113,24 @@ private static CuVSMatrix createCuVSMatrix(
76113
return fromMemorySegment(ms, numVectors, dims, dataType);
77114
}
78115

116+
private static CuVSMatrix createCuVSMatrix(
117+
MemorySegmentAccessInput input,
118+
long pos,
119+
long len,
120+
int numVectors,
121+
int dims,
122+
int rowStride,
123+
int columnStride,
124+
CuVSMatrix.DataType dataType
125+
) throws IOException {
126+
MemorySegment ms = input.segmentSliceOrNull(pos, len);
127+
assert ms != null;
128+
if (((long) numVectors * rowStride) > ms.byteSize()) {
129+
throwIllegalArgumentException(ms, numVectors, dims);
130+
}
131+
return fromMemorySegment(ms, numVectors, dims, rowStride, columnStride, dataType);
132+
}
133+
79134
static void throwIllegalArgumentException(MemorySegment ms, int numVectors, int dims) {
80135
var s = "segment of size [" + ms.byteSize() + "] too small for expected " + numVectors + " float vectors of " + dims + " dims";
81136
throw new IllegalArgumentException(s);

0 commit comments

Comments
 (0)