Skip to content

Commit 72ebdac

Browse files
Merge pull request #873 from IntelPython/type_dispatch_fix
Type dispatch fix
2 parents 5004aa1 + 1088074 commit 72ebdac

File tree

2 files changed

+38
-2
lines changed

2 files changed

+38
-2
lines changed

dpctl/tensor/libtensor/include/utils/type_dispatch.hpp

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,9 +228,22 @@ struct usm_ndarray_types
228228
else if (typenum == UAR_HALF_) {
229229
return static_cast<int>(typenum_t::HALF);
230230
}
231+
else if (typenum == UAR_INT || typenum == UAR_UINT) {
232+
switch (sizeof(int)) {
233+
case sizeof(std::int32_t):
234+
return ((typenum == UAR_INT)
235+
? static_cast<int>(typenum_t::INT32)
236+
: static_cast<int>(typenum_t::UINT32));
237+
case sizeof(std::int64_t):
238+
return ((typenum == UAR_INT)
239+
? static_cast<int>(typenum_t::INT64)
240+
: static_cast<int>(typenum_t::UINT64));
241+
default:
242+
throw_unrecognized_typenum_error(typenum);
243+
}
244+
}
231245
else {
232-
throw std::runtime_error("Unrecogized typenum " +
233-
std::to_string(typenum) + " encountered.");
246+
throw_unrecognized_typenum_error(typenum);
234247
}
235248
}
236249

@@ -286,6 +299,12 @@ struct usm_ndarray_types
286299

287300
return types;
288301
}
302+
303+
void throw_unrecognized_typenum_error(int typenum)
304+
{
305+
throw std::runtime_error("Unrecogized typenum " +
306+
std::to_string(typenum) + " encountered.");
307+
}
289308
};
290309

291310
} // namespace detail

dpctl/tests/test_usm_ndarray_ctor.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -708,6 +708,23 @@ def test_setitem_different_dtypes(src_dt, dst_dt):
708708
assert np.allclose(dpt.asnumpy(Z), np.tile(np.array([1, 0], Z.dtype), 10))
709709

710710

711+
def test_setitem_wingaps():
712+
try:
713+
q = dpctl.SyclQueue()
714+
except dpctl.SyclQueueCreationError:
715+
pytest.skip("Default queue could not be created")
716+
if np.dtype("intc").itemsize == np.dtype("int32").itemsize:
717+
dpt_dst = dpt.empty(4, dtype="int32", sycl_queue=q)
718+
np_src = np.arange(4, dtype="intc")
719+
dpt_dst[:] = np_src # should not raise exceptions
720+
assert np.array_equal(dpt.asnumpy(dpt_dst), np_src)
721+
if np.dtype("long").itemsize == np.dtype("longlong").itemsize:
722+
dpt_dst = dpt.empty(4, dtype="longlong", sycl_queue=q)
723+
np_src = np.arange(4, dtype="long")
724+
dpt_dst[:] = np_src # should not raise exceptions
725+
assert np.array_equal(dpt.asnumpy(dpt_dst), np_src)
726+
727+
711728
def test_shape_setter():
712729
def cc_strides(sh):
713730
return np.empty(sh, dtype="u1").strides

0 commit comments

Comments
 (0)