Skip to content

Commit de6e7b3

Browse files
authored
call_origin() nested lists support and inplace arguments (initial) (#873)
1 parent 00289ce commit de6e7b3

File tree

3 files changed

+60
-18
lines changed

3 files changed

+60
-18
lines changed

dpnp/dpnp_utils/dpnp_algo_utils.pyx

Lines changed: 47 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -82,30 +82,66 @@ __all__ = [
8282

8383
cdef ERROR_PREFIX = "DPNP error:"
8484

85+
def convert_item(item):
86+
if getattr(item, "__sycl_usm_array_interface__", False):
87+
item_converted = dpnp.asnumpy(item)
88+
elif getattr(item, "__array_interface__", False): # detect if it is a container (TODO any better way?)
89+
mod_name = getattr(item, "__module__", 'none')
90+
if (mod_name != 'numpy'):
91+
item_converted = dpnp.asnumpy(item)
92+
else:
93+
item_converted = item
94+
elif isinstance(item, list):
95+
item_converted = convert_list_args(item)
96+
elif isinstance(item, tuple):
97+
item_converted = tuple(convert_list_args(item))
98+
else:
99+
item_converted = item
100+
101+
return item_converted
102+
103+
def convert_list_args(input_list):
104+
result_list = []
105+
for item in input_list:
106+
item_converted = convert_item(item)
107+
result_list.append(item_converted)
85108

109+
return result_list
110+
86111
def call_origin(function, *args, **kwargs):
87112
"""
88113
Call fallback function for unsupported cases
89114
"""
90115

91-
# print(f"DPNP call_origin(): Fallback called. \n\t function={function}, \n\t args={args}, \n\t kwargs={kwargs}")
116+
dpnp_inplace = kwargs.pop("dpnp_inplace", False)
117+
# print(f"DPNP call_origin(): Fallback called. \n\t function={function}, \n\t args={args}, \n\t kwargs={kwargs}, \n\t dpnp_inplace={dpnp_inplace}")
92118

93119
kwargs_out = kwargs.get("out", None)
94120
if (kwargs_out is not None):
95-
kwargs["out"] = dpnp.asnumpy(kwargs_out) if isinstance(kwargs_out, dparray) else kwargs_out
121+
if isinstance(kwargs_out, numpy.ndarray):
122+
kwargs["out"] = kwargs_out
123+
else:
124+
kwargs["out"] = dpnp.asnumpy(kwargs_out)
96125

97-
args_new = []
98-
for arg in args:
99-
argx = dpnp.asnumpy(arg) if isinstance(arg, dparray) else arg
100-
args_new.append(argx)
126+
if dpnp_inplace:
127+
# TODO replacement of foreign containers is still needed
128+
args_new = args
129+
else:
130+
args_new_list = []
131+
for arg in args:
132+
argx = convert_item(arg)
133+
args_new_list.append(argx)
134+
args_new = tuple(args_new_list)
101135

102136
kwargs_new = {}
103137
for key, kwarg in kwargs.items():
104-
kwargx = dpnp.asnumpy(kwarg) if isinstance(kwarg, dparray) else kwarg
138+
kwargx = convert_item(kwarg)
105139
kwargs_new[key] = kwargx
106140

107-
# TODO need to put dparray memory into NumPy call
141+
# print(f"DPNP call_origin(): bakend called. \n\t function={function}, \n\t args_new={args_new}, \n\t kwargs_new={kwargs_new}, \n\t dpnp_inplace={dpnp_inplace}")
142+
# TODO need to put array memory into NumPy call
108143
result_origin = function(*args_new, **kwargs_new)
144+
# print(f"DPNP call_origin(): result from backend. \n\t result_origin={result_origin}, \n\t args_new={args_new}, \n\t kwargs_new={kwargs_new}, \n\t dpnp_inplace={dpnp_inplace}")
109145
result = result_origin
110146
if isinstance(result, numpy.ndarray):
111147
if (kwargs_out is None):
@@ -119,7 +155,8 @@ def call_origin(function, *args, **kwargs):
119155
result = kwargs_out
120156

121157
for i in range(result.size):
122-
result._setitem_scalar(i, result_origin.item(i))
158+
result.flat[i] = result_origin.item(i)
159+
123160
elif isinstance(result, tuple):
124161
# convert tuple(ndarray) to tuple(dparray)
125162
result_list = []
@@ -128,7 +165,7 @@ def call_origin(function, *args, **kwargs):
128165
if isinstance(res_origin, numpy.ndarray):
129166
res = dparray(res_origin.shape, dtype=res_origin.dtype)
130167
for i in range(res.size):
131-
res._setitem_scalar(i, res_origin.item(i))
168+
res.flat[i] = res_origin.item(i)
132169
result_list.append(res)
133170

134171
result = tuple(result_list)

dpnp/random/dpnp_iface_random.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1307,10 +1307,11 @@ def shuffle(x1):
13071307
if not dpnp.is_type_supported(x1_desc.dtype):
13081308
pass
13091309
else:
1310-
result = dpnp_rng_shuffle(x1_desc).get_pyobj()
1311-
return result
1310+
dpnp_rng_shuffle(x1_desc).get_pyobj()
1311+
return
13121312

1313-
return call_origin(numpy.random.shuffle, x1)
1313+
call_origin(numpy.random.shuffle, x1, dpnp_inplace=True)
1314+
return
13141315

13151316

13161317
def seed(seed=None):

tests/test_random.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -945,9 +945,11 @@ def test_shuffle(self, dtype):
945945
input_x_int64 = dpnp.asarray([1, 2, 3, 4, 5, 6, 7, 8, 9, 0], dtype=dpnp.int64)
946946
input_x = dpnp.asarray([1, 2, 3, 4, 5, 6, 7, 8, 9, 0], dtype=dtype)
947947
dpnp.random.seed(seed)
948-
desired_x = dpnp.random.shuffle(input_x_int64).astype(dtype)
948+
dpnp.random.shuffle(input_x_int64) # inplace
949+
desired_x = input_x_int64.astype(dtype)
949950
dpnp.random.seed(seed)
950-
actual_x = dpnp.random.shuffle(input_x)
951+
dpnp.random.shuffle(input_x) # inplace
952+
actual_x = input_x
951953
assert_array_equal(actual_x, desired_x)
952954

953955
@pytest.mark.parametrize("dtype", [dpnp.float32, dpnp.float64, dpnp.int32, dpnp.int64],
@@ -957,7 +959,8 @@ def test_no_miss_numbers(self, dtype):
957959
input_x = dpnp.asarray([5, 4, 0, 7, 6, 1, 8, 3, 2, 9], dtype=dtype)
958960
desired_x = dpnp.sort(input_x)
959961
dpnp.random.seed(seed)
960-
output_x = dpnp.random.shuffle(input_x)
962+
dpnp.random.shuffle(input_x) # inplace
963+
output_x = input_x
961964
actual_x = dpnp.sort(output_x)
962965
assert_array_equal(actual_x, desired_x)
963966

@@ -1002,12 +1005,13 @@ def test_shuffle1(self, conv):
10021005
dpnp.random.seed(seed)
10031006
list_1d = [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]
10041007
dpnp_1d = dpnp.array(list_1d)
1005-
dpnp_desired_1d = dpnp.random.shuffle(dpnp_1d)
1008+
dpnp.random.shuffle(dpnp_1d) # inplace
1009+
dpnp_desired_1d = dpnp_1d
10061010
desired_1d = [i for i in dpnp_desired_1d]
10071011

10081012
dpnp.random.seed(seed)
10091013
alist = conv(list_1d)
1010-
dpnp.random.shuffle(alist)
1014+
dpnp.random.shuffle(alist) # inplace
10111015
actual = alist
10121016
desired = conv(desired_1d)
10131017
assert_array_equal(actual, desired)

0 commit comments

Comments
 (0)