@@ -448,5 +448,67 @@ def test_empty_array_typeof(self):
448448 )
449449
450450
451+ @skip_on_cudasim ("Tests internals of the CUDA driver device array" )
452+ class TestNumpyIntegerTypes (unittest .TestCase ):
453+ def test_from_desc_with_numpy_integer_types (self ):
454+ # test that various numpy integer types in shape/strides are normalised to Python int
455+ test_cases = [
456+ # (shape, strides, description)
457+ (
458+ (np .int32 (10 ), np .int32 (20 )),
459+ (np .int32 (80 ), np .int32 (4 )),
460+ "np.int32" ,
461+ ),
462+ (
463+ (np .int64 (15 ), np .int64 (25 )),
464+ (np .int64 (100 ), np .int64 (4 )),
465+ "np.int64" ,
466+ ),
467+ (
468+ (10 , np .int32 (20 ), np .int64 (30 )),
469+ (np .int32 (2400 ), 120 , np .int64 (4 )),
470+ "mixed types" ,
471+ ),
472+ ((np .intp (8 ), np .intp (12 )), (np .intp (48 ), np .intp (4 )), "np.intp" ),
473+ ]
474+
475+ itemsize = 4
476+ offset = 0
477+
478+ for shape , strides , description in test_cases :
479+ with self .subTest (description = description ):
480+ arr = Array .from_desc (offset , shape , strides , itemsize )
481+
482+ expected_shape = tuple (int (s ) for s in shape )
483+ expected_strides = tuple (int (s ) for s in strides )
484+ self .assertEqual (arr .shape , expected_shape )
485+ self .assertEqual (arr .strides , expected_strides )
486+
487+ for s in arr .shape :
488+ self .assertIsInstance (s , int )
489+ self .assertNotIsInstance (s , np .integer )
490+
491+ for stride in arr .strides :
492+ self .assertIsInstance (stride , int )
493+ self .assertNotIsInstance (stride , np .integer )
494+
495+ def test_from_desc_tuple_from_numpy_array (self ):
496+ # reference: https://github.com/NVIDIA/numba-cuda/issues/623
497+ shape_array = np .array ([50 , 100 ], dtype = np .int32 )
498+ shape_tuple = tuple (shape_array ) # Preserves np.int32!
499+
500+ self .assertIsInstance (shape_tuple [0 ], np .int32 )
501+
502+ itemsize = 4
503+ strides_tuple = (itemsize * shape_tuple [1 ], itemsize )
504+
505+ arr = Array .from_desc (0 , shape_tuple , strides_tuple , itemsize )
506+
507+ self .assertEqual (arr .shape , (50 , 100 ))
508+ for s in arr .shape :
509+ self .assertIsInstance (s , int )
510+ self .assertNotIsInstance (s , np .integer )
511+
512+
451513if __name__ == "__main__" :
452514 unittest .main ()
0 commit comments