-
Notifications
You must be signed in to change notification settings - Fork 25.5k
Cuvs snapshot update #136057
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Cuvs snapshot update #136057
Changes from all commits
2b65b21
ee109f4
fb4f37f
0108d23
5d2e48e
59c2032
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,6 +21,7 @@ public class DatasetUtilsImpl implements DatasetUtils { | |
private static final DatasetUtils INSTANCE = new DatasetUtilsImpl(); | ||
|
||
private static final MethodHandle createDataset$mh = CuVSProvider.provider().newNativeMatrixBuilder(); | ||
private static final MethodHandle createDatasetWithStrides$mh = CuVSProvider.provider().newNativeMatrixBuilderWithStrides(); | ||
|
||
static DatasetUtils getInstance() { | ||
return INSTANCE; | ||
|
@@ -40,6 +41,27 @@ static CuVSMatrix fromMemorySegment(MemorySegment memorySegment, int size, int d | |
} | ||
} | ||
|
||
static CuVSMatrix fromMemorySegment( | ||
MemorySegment memorySegment, | ||
int size, | ||
int dimensions, | ||
int rowStride, | ||
int columnStride, | ||
CuVSMatrix.DataType dataType | ||
) { | ||
try { | ||
return (CuVSMatrix) createDatasetWithStrides$mh.invokeExact(memorySegment, size, dimensions, rowStride, columnStride, dataType); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. As we don't need columnStride and it's not really supported (yet) in cuvs-java, I think you might want to pass -1 here like you did in 14088bd. So you can remove all the |
||
} catch (Throwable e) { | ||
if (e instanceof Error err) { | ||
throw err; | ||
} else if (e instanceof RuntimeException re) { | ||
throw re; | ||
} else { | ||
throw new RuntimeException(e); | ||
} | ||
} | ||
} | ||
|
||
private DatasetUtilsImpl() {} | ||
|
||
@Override | ||
|
@@ -50,6 +72,21 @@ public CuVSMatrix fromInput(MemorySegmentAccessInput input, int numVectors, int | |
return createCuVSMatrix(input, 0L, input.length(), numVectors, dims, dataType); | ||
} | ||
|
||
@Override | ||
public CuVSMatrix fromInput( | ||
MemorySegmentAccessInput input, | ||
int numVectors, | ||
int dims, | ||
int rowStride, | ||
int columnStride, | ||
CuVSMatrix.DataType dataType | ||
) throws IOException { | ||
if (numVectors < 0 || dims < 0) { | ||
throwIllegalArgumentException(numVectors, dims); | ||
} | ||
return createCuVSMatrix(input, 0L, input.length(), numVectors, dims, rowStride, columnStride, dataType); | ||
} | ||
|
||
@Override | ||
public CuVSMatrix fromSlice(MemorySegmentAccessInput input, long pos, long len, int numVectors, int dims, CuVSMatrix.DataType dataType) | ||
throws IOException { | ||
|
@@ -76,6 +113,25 @@ private static CuVSMatrix createCuVSMatrix( | |
return fromMemorySegment(ms, numVectors, dims, dataType); | ||
} | ||
|
||
private static CuVSMatrix createCuVSMatrix( | ||
MemorySegmentAccessInput input, | ||
long pos, | ||
long len, | ||
int numVectors, | ||
int dims, | ||
int rowStride, | ||
int columnStride, | ||
CuVSMatrix.DataType dataType | ||
) throws IOException { | ||
MemorySegment ms = input.segmentSliceOrNull(pos, len); | ||
assert ms != null; | ||
final int byteSize = dataType == CuVSMatrix.DataType.FLOAT ? Float.BYTES : Byte.BYTES; | ||
if (((long) numVectors * rowStride * byteSize) > ms.byteSize()) { | ||
throwIllegalArgumentException(ms, numVectors, dims); | ||
} | ||
return fromMemorySegment(ms, numVectors, dims, rowStride, columnStride, dataType); | ||
} | ||
|
||
static void throwIllegalArgumentException(MemorySegment ms, int numVectors, int dims) { | ||
var s = "segment of size [" + ms.byteSize() + "] too small for expected " + numVectors + " float vectors of " + dims + " dims"; | ||
throw new IllegalArgumentException(s); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We probably don't need this anymore as we have the
-hashsuffix
now, right @brianseeders ?