@@ -21,6 +21,7 @@ public class DatasetUtilsImpl implements DatasetUtils {
21
21
private static final DatasetUtils INSTANCE = new DatasetUtilsImpl ();
22
22
23
23
private static final MethodHandle createDataset$mh = CuVSProvider .provider ().newNativeMatrixBuilder ();
24
+ private static final MethodHandle createDatasetWithStrides$mh = CuVSProvider .provider ().newNativeMatrixBuilderWithStrides ();
24
25
25
26
static DatasetUtils getInstance () {
26
27
return INSTANCE ;
@@ -40,6 +41,27 @@ static CuVSMatrix fromMemorySegment(MemorySegment memorySegment, int size, int d
40
41
}
41
42
}
42
43
44
+ static CuVSMatrix fromMemorySegment (
45
+ MemorySegment memorySegment ,
46
+ int size ,
47
+ int dimensions ,
48
+ int rowStride ,
49
+ int columnStride ,
50
+ CuVSMatrix .DataType dataType
51
+ ) {
52
+ try {
53
+ return (CuVSMatrix ) createDatasetWithStrides$mh .invokeExact (memorySegment , size , dimensions , rowStride , columnStride , dataType );
54
+ } catch (Throwable e ) {
55
+ if (e instanceof Error err ) {
56
+ throw err ;
57
+ } else if (e instanceof RuntimeException re ) {
58
+ throw re ;
59
+ } else {
60
+ throw new RuntimeException (e );
61
+ }
62
+ }
63
+ }
64
+
43
65
private DatasetUtilsImpl () {}
44
66
45
67
@ Override
@@ -50,6 +72,21 @@ public CuVSMatrix fromInput(MemorySegmentAccessInput input, int numVectors, int
50
72
return createCuVSMatrix (input , 0L , input .length (), numVectors , dims , dataType );
51
73
}
52
74
75
+ @ Override
76
+ public CuVSMatrix fromInput (
77
+ MemorySegmentAccessInput input ,
78
+ int numVectors ,
79
+ int dims ,
80
+ int rowStride ,
81
+ int columnStride ,
82
+ CuVSMatrix .DataType dataType
83
+ ) throws IOException {
84
+ if (numVectors < 0 || dims < 0 ) {
85
+ throwIllegalArgumentException (numVectors , dims );
86
+ }
87
+ return createCuVSMatrix (input , 0L , input .length (), numVectors , dims , rowStride , columnStride , dataType );
88
+ }
89
+
53
90
@ Override
54
91
public CuVSMatrix fromSlice (MemorySegmentAccessInput input , long pos , long len , int numVectors , int dims , CuVSMatrix .DataType dataType )
55
92
throws IOException {
@@ -76,6 +113,24 @@ private static CuVSMatrix createCuVSMatrix(
76
113
return fromMemorySegment (ms , numVectors , dims , dataType );
77
114
}
78
115
116
+ private static CuVSMatrix createCuVSMatrix (
117
+ MemorySegmentAccessInput input ,
118
+ long pos ,
119
+ long len ,
120
+ int numVectors ,
121
+ int dims ,
122
+ int rowStride ,
123
+ int columnStride ,
124
+ CuVSMatrix .DataType dataType
125
+ ) throws IOException {
126
+ MemorySegment ms = input .segmentSliceOrNull (pos , len );
127
+ assert ms != null ;
128
+ if (((long ) numVectors * rowStride ) > ms .byteSize ()) {
129
+ throwIllegalArgumentException (ms , numVectors , dims );
130
+ }
131
+ return fromMemorySegment (ms , numVectors , dims , rowStride , columnStride , dataType );
132
+ }
133
+
79
134
static void throwIllegalArgumentException (MemorySegment ms , int numVectors , int dims ) {
80
135
var s = "segment of size [" + ms .byteSize () + "] too small for expected " + numVectors + " float vectors of " + dims + " dims" ;
81
136
throw new IllegalArgumentException (s );
0 commit comments