Skip to content

Commit dd2e942

Browse files
authored
Swallow DLPack conversion error (#825)
Close #824. The bug could have been caught by the existing datetime test, but 1. the test has a bug to cover the actual error 2. CuPy is not a test dependency As per #472 we also add CuPy to wheel test dependencies.
1 parent 9d039d7 commit dd2e942

File tree

3 files changed

+8
-5
lines changed

3 files changed

+8
-5
lines changed

numba_cuda/numba/cuda/tests/cudapy/test_datetime.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,12 @@ def assign(out, arr):
7474
for i in range(arr.shape[0]):
7575
out[i] = arr[i]
7676

77-
# TODO: cupy doesn't allow passing the datetime64[D] array directly
77+
# CuPy doesn't allow constructing the datetime64[D] array directly
7878
arr = cp.array(
7979
np.arange("2005-02", "2006-02", dtype="datetime64[D]").view("int64")
8080
).view("datetime64[D]")
8181

82-
out = cp.empty_like(arr)
82+
out = cp.empty(arr.size, dtype="float64").view("datetime64[D]")
8383
assign[1, 1](out, arr)
8484

8585
self.assertPreciseEqual(arr.get(), out.get())

numba_cuda/numba/cuda/typing/typeof.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -361,7 +361,10 @@ def _typeof_cuda_array_interface_cached(
361361
def _typeof_dlpack(val, c):
362362
obj = getattr(val, "__self__", None)
363363
if obj is not None:
364-
smv = StridedMemoryView.from_dlpack(obj, stream_ptr=-1)
364+
try:
365+
smv = StridedMemoryView.from_dlpack(obj, stream_ptr=-1)
366+
except BufferError:
367+
return
365368

366369
smv_layout = smv._layout
367370
layout = (

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ test = [
7070
"ml_dtypes",
7171
"statistics",
7272
]
73-
test-cu12 = ["cuda-toolkit[curand]==12.*", { include-group = "test" }]
74-
test-cu13 = ["cuda-toolkit[curand]==13.*", { include-group = "test" }]
73+
test-cu12 = ["cuda-toolkit[curand,cublas]==12.*", { include-group = "test" }, "cupy-cuda12x !=14.0.0"]
74+
test-cu13 = ["cuda-toolkit[curand,cublas]==13.*", { include-group = "test" }, "cupy-cuda13x !=14.0.0"]
7575

7676
[project.urls]
7777
Homepage = "https://nvidia.github.io/numba-cuda/"

0 commit comments

Comments
 (0)