Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion numba_cuda/numba/cuda/np/numpy_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,31 @@ def _from_datetime_dtype(dtype):
raise errors.NumbaNotImplementedError(dtype)


def _dtype_cache_key(dtype):
"""
Create a cache key for dtype that includes the isalignedstruct attribute.

NumPy's dtype hashing and comparison mechanisms do not consider the
isalignedstruct field when determining dtype equality. This means that
two dtypes with the same structure but different alignment settings would
be treated as identical by NumPy's default comparison, leading to incorrect
caching behavior. To ensure correct caching of from_dtype results, we
extend the cache key to explicitly include the isalignedstruct attribute.
"""
return (
dtype,
getattr(dtype, "isalignedstruct", None),
)


@functools.lru_cache
def from_dtype(dtype):
def _from_dtype_impl(dtype_cached):
"""
Return a Numba Type instance corresponding to the given Numpy *dtype*.
NumbaNotImplementedError is raised on unsupported Numpy dtypes.
"""
dtype = dtype_cached[0]

if type(dtype) is type and issubclass(dtype, np.generic):
dtype = np.dtype(dtype)
elif getattr(dtype, "fields", None) is not None:
Expand All @@ -148,6 +167,10 @@ def from_dtype(dtype):
raise errors.NumbaNotImplementedError(dtype)


def from_dtype(dtype):
return _from_dtype_impl(_dtype_cache_key(dtype))


_as_dtype_letters = {
types.NPDatetime: "M8",
types.NPTimedelta: "m8",
Expand Down
23 changes: 23 additions & 0 deletions numba_cuda/numba/cuda/tests/cudapy/test_cuda_array_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,29 @@ def f(x, y):
# Ensure that synchronize was not called
mock_sync.assert_not_called()

def test_from_dtype_caching_distiguish_custom_dtypes(self):
f16x2 = np.dtype([("x", np.float16), ("y", np.float16)])
f16x2_aligned = np.dtype(
[("x", np.float16), ("y", np.float16)], align=True
)

@cuda.jit
def f1(input, output):
pass

@cuda.jit
def f2(input, output):
pass

arr0 = cuda.to_device(np.zeros((1,), dtype=f16x2))
arr1 = cuda.to_device(np.zeros((1,), dtype=f16x2_aligned))
output = cuda.to_device(np.zeros((2,), dtype=np.int64))

f1[1, 1](arr0, output)
f2[1, 1](arr1, output)

assert f1.signatures[0] != f2.signatures[0]


if __name__ == "__main__":
unittest.main()
Loading