Skip to content

Commit 029b59a

Browse files
Diptorup Debmingjie-intel
authored andcommitted
Fix typo in the keyword argumnet when constructing a tensor.
- numba-dpex's usm_ndarray type relies on dpctl.tensor to determine the default dtype for an array. In the constructor for the usm_ndarray type, there was a typo that was causing an exception in dpjit when dpnp.empty or other constructor was called without a dtype. - Adds a unit test case.
1 parent c5c2962 commit 029b59a

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

numba_dpex/core/types/usm_ndarray_type.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def __init__(
7474

7575
if not dtype:
7676
dummy_tensor = dpctl.tensor.empty(
77-
sh=1, order=layout, usm_type=usm_type, sycl_queue=self.queue
77+
shape=1, order=layout, usm_type=usm_type, sycl_queue=self.queue
7878
)
7979
# convert dpnp type to numba/numpy type
8080
_dtype = dummy_tensor.dtype

numba_dpex/tests/dpjit_tests/dpnp/test_dpnp_empty.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,29 @@ def func1(shape):
4747
)
4848
else:
4949
c.sycl_device.filter_string == dpctl.SyclDevice().filter_string
50+
51+
52+
@pytest.mark.parametrize("shape", shapes)
53+
def test_dpnp_empty_default_dtype(shape):
54+
@dpjit
55+
def func1(shape):
56+
c = dpnp.empty(shape=shape)
57+
return c
58+
59+
try:
60+
c = func1(shape)
61+
except Exception:
62+
pytest.fail("Calling dpnp.empty inside dpjit failed")
63+
64+
if len(c.shape) == 1:
65+
assert c.shape[0] == shape
66+
else:
67+
assert c.shape == shape
68+
69+
dummy_tensor = dpctl.tensor.empty(shape=1)
70+
71+
assert c.dtype == dummy_tensor.dtype
72+
73+
dummy_tensor = dpctl.tensor.empty(shape)
74+
75+
assert c.dtype == dummy_tensor.dtype

0 commit comments

Comments
 (0)