Skip to content

Commit e3dbb0e

Browse files
kaeun97gmarkall
andauthored
fix: normalize numpy integer types to python int to prevent overflow errors (#774)
Addresses this bug: #623 --------- Co-authored-by: Graham Markall <535640+gmarkall@users.noreply.github.com>
1 parent a03472f commit e3dbb0e

File tree

3 files changed

+81
-2
lines changed

3 files changed

+81
-2
lines changed

numba_cuda/numba/cuda/cudadrv/devicearray.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,15 @@ def __init__(self, shape, strides, dtype, stream=0, gpu_data=None):
8282
if isinstance(shape, int):
8383
shape = (shape,)
8484
else:
85-
shape = tuple(shape)
85+
normalized_shape = []
86+
for s in shape:
87+
if not isinstance(s, (int, np.integer)):
88+
raise TypeError(
89+
f"shape elements must be integers, got {type(s).__name__}"
90+
)
91+
normalized_shape.append(int(s))
92+
shape = tuple(normalized_shape)
93+
8694
if isinstance(strides, int):
8795
strides = (strides,)
8896
else:

numba_cuda/numba/cuda/cudadrv/dummyarray.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,8 +249,17 @@ class Array:
249249
@functools.cache
250250
def from_desc(cls, offset, shape, strides, itemsize):
251251
dims = []
252+
shape = tuple(
253+
int(s) if isinstance(s, (int, np.integer)) else s for s in shape
254+
)
255+
strides = tuple(
256+
int(s) if isinstance(s, (int, np.integer)) else s for s in strides
257+
)
258+
offset = (
259+
int(offset) if isinstance(offset, (int, np.integer)) else offset
260+
)
252261
for ashape, astride in zip(shape, strides):
253-
if not isinstance(ashape, (int, np.integer)):
262+
if not isinstance(ashape, int):
254263
raise TypeError("all elements of shape must be ints")
255264
dim = Dim(
256265
offset, offset + ashape * astride, ashape, astride, single=False

numba_cuda/numba/cuda/tests/nocuda/test_dummyarray.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,5 +448,67 @@ def test_empty_array_typeof(self):
448448
)
449449

450450

451+
@skip_on_cudasim("Tests internals of the CUDA driver device array")
452+
class TestNumpyIntegerTypes(unittest.TestCase):
453+
def test_from_desc_with_numpy_integer_types(self):
454+
# test that various numpy integer types in shape/strides are normalised to Python int
455+
test_cases = [
456+
# (shape, strides, description)
457+
(
458+
(np.int32(10), np.int32(20)),
459+
(np.int32(80), np.int32(4)),
460+
"np.int32",
461+
),
462+
(
463+
(np.int64(15), np.int64(25)),
464+
(np.int64(100), np.int64(4)),
465+
"np.int64",
466+
),
467+
(
468+
(10, np.int32(20), np.int64(30)),
469+
(np.int32(2400), 120, np.int64(4)),
470+
"mixed types",
471+
),
472+
((np.intp(8), np.intp(12)), (np.intp(48), np.intp(4)), "np.intp"),
473+
]
474+
475+
itemsize = 4
476+
offset = 0
477+
478+
for shape, strides, description in test_cases:
479+
with self.subTest(description=description):
480+
arr = Array.from_desc(offset, shape, strides, itemsize)
481+
482+
expected_shape = tuple(int(s) for s in shape)
483+
expected_strides = tuple(int(s) for s in strides)
484+
self.assertEqual(arr.shape, expected_shape)
485+
self.assertEqual(arr.strides, expected_strides)
486+
487+
for s in arr.shape:
488+
self.assertIsInstance(s, int)
489+
self.assertNotIsInstance(s, np.integer)
490+
491+
for stride in arr.strides:
492+
self.assertIsInstance(stride, int)
493+
self.assertNotIsInstance(stride, np.integer)
494+
495+
def test_from_desc_tuple_from_numpy_array(self):
496+
# reference: https://github.com/NVIDIA/numba-cuda/issues/623
497+
shape_array = np.array([50, 100], dtype=np.int32)
498+
shape_tuple = tuple(shape_array) # Preserves np.int32!
499+
500+
self.assertIsInstance(shape_tuple[0], np.int32)
501+
502+
itemsize = 4
503+
strides_tuple = (itemsize * shape_tuple[1], itemsize)
504+
505+
arr = Array.from_desc(0, shape_tuple, strides_tuple, itemsize)
506+
507+
self.assertEqual(arr.shape, (50, 100))
508+
for s in arr.shape:
509+
self.assertIsInstance(s, int)
510+
self.assertNotIsInstance(s, np.integer)
511+
512+
451513
if __name__ == "__main__":
452514
unittest.main()

0 commit comments

Comments
 (0)