Skip to content

Commit a390ca7

Browse files
authored
Add converting usm_array to ndarray in fallback (#921)
* Add converting usm_array to ndarray to fallback * Replace _to_numpy with copy_to_host * Add copy of result from host in call_origin
1 parent 66504da commit a390ca7

File tree

2 files changed

+16
-11
lines changed

2 files changed

+16
-11
lines changed

dpnp/dpnp_iface_arraycreation.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -196,11 +196,7 @@ def array(x1, dtype=None, copy=True, order='C', subok=False, ndmin=0, like=None)
196196
if not dpnp.is_type_supported(dtype) and dtype is not None:
197197
pass
198198
elif config.__DPNP_OUTPUT_DPCTL__:
199-
# TODO this is workaround becasue
200-
# usm_array has no element wise assignment (aka []) and
201-
# has no "flat" property and
202-
# "usm_data.copy_from_host" doesn't work with diffrent datatypes
203-
return numpy.array(x1, dtype=dtype, copy=copy, order=order, subok=subok, ndmin=ndmin)
199+
return call_origin(numpy.array, x1, dtype=dtype, copy=copy, order=order, subok=subok, ndmin=ndmin)
204200
elif subok is not False:
205201
pass
206202
elif copy is not True:

dpnp/dpnp_utils/dpnp_algo_utils.pyx

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ cdef ERROR_PREFIX = "DPNP error:"
6868

6969
def convert_item(item):
7070
if getattr(item, "__sycl_usm_array_interface__", False):
71-
item_converted = dpnp.asnumpy(item)
71+
item_converted = dpnp.asnumpy(item)
7272
elif getattr(item, "__array_interface__", False): # detect if it is a container (TODO any better way?)
7373
mod_name = getattr(item, "__module__", 'none')
7474
if (mod_name != 'numpy'):
@@ -91,7 +91,18 @@ def convert_list_args(input_list):
9191
result_list.append(item_converted)
9292

9393
return result_list
94-
94+
95+
96+
def copy_from_origin(dst, src):
97+
"""Copy origin result to output result."""
98+
if config.__DPNP_OUTPUT_DPCTL__ and hasattr(dst, "__sycl_usm_array_interface__"):
99+
if src.size:
100+
dst.usm_data.copy_from_host(src.reshape(-1).view("|u1"))
101+
else:
102+
for i in range(dst.size):
103+
dst.flat[i] = src.item(i)
104+
105+
95106
def call_origin(function, *args, **kwargs):
96107
"""
97108
Call fallback function for unsupported cases
@@ -138,8 +149,7 @@ def call_origin(function, *args, **kwargs):
138149
else:
139150
result = kwargs_out
140151

141-
for i in range(result.size):
142-
result.flat[i] = result_origin.item(i)
152+
copy_from_origin(result, result_origin)
143153

144154
elif isinstance(result, tuple):
145155
# convert tuple(fallback_array) to tuple(result_array)
@@ -148,8 +158,7 @@ def call_origin(function, *args, **kwargs):
148158
res = res_origin
149159
if isinstance(res_origin, numpy.ndarray):
150160
res = create_output_container(res_origin.shape, res_origin.dtype)
151-
for i in range(res.size):
152-
res.flat[i] = res_origin.item(i)
161+
copy_from_origin(res, res_origin)
153162
result_list.append(res)
154163

155164
result = tuple(result_list)

0 commit comments

Comments
 (0)