Skip to content

Commit 049d1a5

Browse files
Fixed typo in function dtype_to_typenum (#438)
This changes enables one to use NumPy's dtype in dtype keyword of the constructor. dpt.usm_ndarray(X.shape, dtype=X.dtype) Added a test case to cover this use-case
1 parent 3c6e3af commit 049d1a5

File tree

2 files changed

+4
-3
lines changed

2 files changed

+4
-3
lines changed

dpctl/tensor/_types.pxi

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ cdef int type_bytesize(int typenum):
8888
return type_to_bytesize[typenum]
8989

9090

91-
cdef int typenum_from_format(str s):
91+
cdef int typenum_from_format(str s) except *:
9292
"""
9393
Internal utility to convert string describing type format
9494
@@ -106,7 +106,7 @@ cdef int typenum_from_format(str s):
106106
return dt.num
107107

108108

109-
cdef int dtype_to_typenum(dtype):
109+
cdef int dtype_to_typenum(dtype) except *:
110110
if isinstance(dtype, str):
111111
return typenum_from_format(dtype)
112112
elif isinstance(dtype, bytes):
@@ -121,6 +121,6 @@ cdef int dtype_to_typenum(dtype):
121121
obj = obj[1]
122122
if not isinstance(obj, str):
123123
return -1
124-
return typenum_from_format(dtype)
124+
return typenum_from_format(obj)
125125
else:
126126
return -1

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def test_allocate_usm_ndarray(shape, usm_type):
7171
"f8",
7272
"c8",
7373
"c16",
74+
np.dtype("d"),
7475
],
7576
)
7677
def test_dtypes(dtype):

0 commit comments

Comments
 (0)