Skip to content

Commit 47af401

Browse files
committed
Accept tuple value for out keyword passing to ufuncs
1 parent 4082593 commit 47af401

File tree

1 file changed

+17
-4
lines changed

1 file changed

+17
-4
lines changed

dpnp/dpnp_algo/dpnp_elementwise_common.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -199,9 +199,15 @@ def __call__(
199199
if dtype is not None:
200200
x_usm = dpt.astype(x_usm, dtype, copy=False)
201201

202+
if isinstance(out, tuple):
203+
if len(out) != self.nout:
204+
raise ValueError(
205+
"'out' tuple must have exactly one entry per ufunc output"
206+
)
207+
out = out[0]
202208
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
203-
res_usm = super().__call__(x_usm, out=out_usm, order=order)
204209

210+
res_usm = super().__call__(x_usm, out=out_usm, order=order)
205211
if out is not None and isinstance(out, dpnp_array):
206212
return out
207213
return dpnp_array._create_from_usm_ndarray(res_usm)
@@ -361,7 +367,7 @@ def __call__(
361367
orig_out, out = list(out), list(out)
362368
res_dts = [res1_dt, res2_dt]
363369

364-
for i in range(2):
370+
for i in range(self.nout):
365371
if out[i] is None:
366372
continue
367373

@@ -419,7 +425,7 @@ def __call__(
419425
dep_evs = copy_ev
420426

421427
# Allocate a buffer for the output arrays if needed
422-
for i in range(2):
428+
for i in range(self.nout):
423429
if out[i] is None:
424430
res_dt = res_dts[i]
425431
if order == "K":
@@ -438,7 +444,7 @@ def __call__(
438444
)
439445
_manager.add_event_pair(ht_unary_ev, unary_ev)
440446

441-
for i in range(2):
447+
for i in range(self.nout):
442448
orig_res, res = orig_out[i], out[i]
443449
if not (orig_res is None or orig_res is res):
444450
# Copy the out data from temporary buffer to original memory
@@ -606,6 +612,13 @@ def __call__(
606612

607613
x1_usm = dpnp.get_usm_ndarray_or_scalar(x1)
608614
x2_usm = dpnp.get_usm_ndarray_or_scalar(x2)
615+
616+
if isinstance(out, tuple):
617+
if len(out) != self.nout:
618+
raise ValueError(
619+
"'out' tuple must have exactly one entry per ufunc output"
620+
)
621+
out = out[0]
609622
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
610623

611624
if (

0 commit comments

Comments
 (0)