Skip to content

Commit fc18f6f

Browse files
Fix "Arrays are not equal" part 2 (#943)
* Fix "Arrays are not equal" part 2 * Change in place modification of only first argument in call_origin Co-authored-by: Alexander-Makaryev <[email protected]>
1 parent 03c1185 commit fc18f6f

File tree

2 files changed

+18
-12
lines changed

2 files changed

+18
-12
lines changed

dpnp/dpnp_iface_indexing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,7 @@ def place(x1, mask, vals):
353353
if x1_desc and mask_desc and vals_desc:
354354
return dpnp_place(x1_desc, mask, vals_desc)
355355

356-
return call_origin(numpy.place, x1, mask, vals)
356+
return call_origin(numpy.place, x1, mask, vals, dpnp_inplace=True)
357357

358358

359359
def put(x1, ind, v, mode='raise'):
@@ -406,7 +406,7 @@ def put_along_axis(x1, indices, values, axis):
406406
else:
407407
return dpnp_put_along_axis(x1_desc, indices_desc, values_desc, axis)
408408

409-
return call_origin(numpy.put_along_axis, x1, indices, values, axis)
409+
return call_origin(numpy.put_along_axis, x1, indices, values, axis, dpnp_inplace=True)
410410

411411

412412
def putmask(x1, mask, values):

dpnp/dpnp_utils/dpnp_algo_utils.pyx

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -119,15 +119,11 @@ def call_origin(function, *args, **kwargs):
119119
else:
120120
kwargs["out"] = dpnp.asnumpy(kwargs_out)
121121

122-
if dpnp_inplace:
123-
# TODO replacement of foreign containers is still needed
124-
args_new = args
125-
else:
126-
args_new_list = []
127-
for arg in args:
128-
argx = convert_item(arg)
129-
args_new_list.append(argx)
130-
args_new = tuple(args_new_list)
122+
args_new_list = []
123+
for arg in args:
124+
argx = convert_item(arg)
125+
args_new_list.append(argx)
126+
args_new = tuple(args_new_list)
131127

132128
kwargs_new = {}
133129
for key, kwarg in kwargs.items():
@@ -139,7 +135,17 @@ def call_origin(function, *args, **kwargs):
139135
result_origin = function(*args_new, **kwargs_new)
140136
# print(f"DPNP call_origin(): result from backend. \n\t result_origin={result_origin}, \n\t args_new={args_new}, \n\t kwargs_new={kwargs_new}, \n\t dpnp_inplace={dpnp_inplace}")
141137
result = result_origin
142-
if isinstance(result, numpy.ndarray):
138+
if dpnp_inplace:
139+
# enough to modify only first argument in place
140+
if args and args_new:
141+
arg, arg_new = args[0], args_new[0]
142+
if isinstance(arg_new, numpy.ndarray):
143+
copy_from_origin(arg, arg_new)
144+
elif isinstance(arg_new, list):
145+
for i, val in enumerate(arg_new):
146+
arg[i] = val
147+
148+
elif isinstance(result, numpy.ndarray):
143149
if (kwargs_out is None):
144150
result_dtype = result_origin.dtype
145151
kwargs_dtype = kwargs.get("dtype", None)

0 commit comments

Comments
 (0)