Skip to content

Commit 02664ac

Browse files
committed
caching fix
1 parent 6525c5a commit 02664ac

File tree

2 files changed

+47
-1
lines changed

2 files changed

+47
-1
lines changed

numba_cuda/numba/cuda/np/numpy_support.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,12 +121,31 @@ def _from_datetime_dtype(dtype):
121121
raise errors.NumbaNotImplementedError(dtype)
122122

123123

124+
def _dtype_cache_key(dtype):
125+
"""
126+
Create a cache key for dtype that includes the isalignedstruct attribute.
127+
128+
NumPy's dtype hashing and comparison mechanisms do not consider the
129+
isalignedstruct field when determining dtype equality. This means that
130+
two dtypes with the same structure but different alignment settings would
131+
be treated as identical by NumPy's default comparison, leading to incorrect
132+
caching behavior. To ensure correct caching of from_dtype results, we
133+
extend the cache key to explicitly include the isalignedstruct attribute.
134+
"""
135+
return (
136+
dtype,
137+
getattr(dtype, "isalignedstruct", None),
138+
)
139+
140+
124141
@functools.lru_cache
125-
def from_dtype(dtype):
142+
def _from_dtype_impl(dtype_cached):
126143
"""
127144
Return a Numba Type instance corresponding to the given Numpy *dtype*.
128145
NumbaNotImplementedError is raised on unsupported Numpy dtypes.
129146
"""
147+
dtype = dtype_cached[0]
148+
130149
if type(dtype) is type and issubclass(dtype, np.generic):
131150
dtype = np.dtype(dtype)
132151
elif getattr(dtype, "fields", None) is not None:
@@ -148,6 +167,10 @@ def from_dtype(dtype):
148167
raise errors.NumbaNotImplementedError(dtype)
149168

150169

170+
def from_dtype(dtype):
171+
return _from_dtype_impl(_dtype_cache_key(dtype))
172+
173+
151174
_as_dtype_letters = {
152175
types.NPDatetime: "M8",
153176
types.NPTimedelta: "m8",

numba_cuda/numba/cuda/tests/cudapy/test_cuda_array_interface.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,29 @@ def f(x, y):
450450
# Ensure that synchronize was not called
451451
mock_sync.assert_not_called()
452452

453+
def test_from_dtype_caching_distiguish_custom_dtypes(self):
454+
f16x2 = np.dtype([("x", np.float16), ("y", np.float16)])
455+
f16x2_aligned = np.dtype(
456+
[("x", np.float16), ("y", np.float16)], align=True
457+
)
458+
459+
@cuda.jit
460+
def f1(input, output):
461+
pass
462+
463+
@cuda.jit
464+
def f2(input, output):
465+
pass
466+
467+
arr0 = cuda.to_device(np.zeros((1,), dtype=f16x2))
468+
arr1 = cuda.to_device(np.zeros((1,), dtype=f16x2_aligned))
469+
output = cuda.to_device(np.zeros((2,), dtype=np.int64))
470+
471+
f1[1, 1](arr0, output)
472+
f2[1, 1](arr1, output)
473+
474+
assert f1.signatures[0] != f2.signatures[0]
475+
453476

454477
if __name__ == "__main__":
455478
unittest.main()

0 commit comments

Comments
 (0)