Skip to content

Commit 6b68466

Browse files
dpctl.tensor.from_numpy should not try creating USM ndarray for 64-bit fp on HW that has no such support
1 parent 8e1132c commit 6b68466

File tree

1 file changed

+11
-10
lines changed

1 file changed

+11
-10
lines changed

dpctl/tensor/_copy_utils.py

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -59,18 +59,19 @@ def _copy_to_numpy(ary):
5959
def _copy_from_numpy(np_ary, usm_type="device", sycl_queue=None):
6060
"Copies numpy array `np_ary` into a new usm_ndarray"
6161
# This may peform a copy to meet stated requirements
62-
Xnp = np.require(np_ary, requirements=["A", "O", "C", "E"])
63-
if sycl_queue:
64-
ctor_kwargs = {"queue": sycl_queue}
62+
Xnp = np.require(np_ary, requirements=["A", "E"])
63+
alloc_q = normalize_queue_device(sycl_queue=sycl_queue, device=None)
64+
dt = Xnp.dtype
65+
if dt.char in "dD" and alloc_q.sycl_device.has_aspect_fp64 is False:
66+
Xusm_dtype = (
67+
np.dtype("float32") if dt.char == "d" else np.dtype("complex64")
68+
)
6569
else:
66-
ctor_kwargs = dict()
67-
Xusm = dpt.usm_ndarray(
68-
Xnp.shape,
69-
dtype=Xnp.dtype,
70-
buffer=usm_type,
71-
buffer_ctor_kwargs=ctor_kwargs,
70+
Xusm_dtype = dt
71+
Xusm = dpt.empty(
72+
Xnp.shape, dtype=Xusm_dtype, usm_type=usm_type, sycl_queue=sycl_queue
7273
)
73-
Xusm.usm_data.copy_from_host(Xnp.reshape((-1)).view("u1"))
74+
_copy_from_numpy_into(Xusm, Xnp)
7475
return Xusm
7576

7677

0 commit comments

Comments
 (0)