Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions dpctl/tensor/_ctors.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,6 @@ def _asarray_from_usm_ndarray(
raise TypeError(
f"Expected dpctl.tensor.usm_ndarray, got {type(usm_ndary)}"
)
if dtype is None:
dtype = usm_ndary.dtype
if usm_type is None:
usm_type = usm_ndary.usm_type
if sycl_queue is not None:
Expand All @@ -122,6 +120,8 @@ def _asarray_from_usm_ndarray(
copy_q = normalize_queue_device(sycl_queue=sycl_queue, device=exec_q)
else:
copy_q = usm_ndary.sycl_queue
if dtype is None:
dtype = _map_to_device_dtype(usm_ndary.dtype, copy_q)
# Conditions for zero copy:
can_zero_copy = copy is not True
# dtype is unchanged
Expand Down
30 changes: 30 additions & 0 deletions dpctl/tests/test_tensor_asarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -623,3 +623,33 @@ def test_asarray_support_for_usm_ndarray_protocol(usm_type):
assert x.dtype == y3.dtype
assert y3.usm_data.reference_obj is None
assert dpt.all(x[dpt.newaxis, :] == y3)


@pytest.mark.parametrize("dt", [dpt.float16, dpt.float64, dpt.complex128])
def test_asarray_to_device_with_unsupported_dtype(dt):
aspect = "fp16" if dt == dpt.float16 else "fp64"
try:
d0 = dpctl.select_device_with_aspects(aspect)
except dpctl.SyclDeviceCreationError:
pytest.skip("No device with aspect for test")
d1 = None
try:
d1 = dpctl.select_device_with_aspects("cpu", excluded_aspects=[aspect])
except dpctl.SyclDeviceCreationError:
pass
try:
d1 = dpctl.select_device_with_aspects("gpu", excluded_aspects=[aspect])
except dpctl.SyclDeviceCreationError:
pass
try:
d1 = dpctl.select_device_with_aspects(
"accelerator", excluded_aspects=[aspect]
)
except dpctl.SyclDeviceCreationError:
pass
if d1 is None:
pytest.skip("No device with missing aspect for test")

x = dpt.ones(10, dtype=dt, device=d0)
y = dpt.asarray(x, device=d1)
assert y.sycl_device == d1
Loading