Skip to content

Commit dcd40aa

Browse files
committed
Add a special handling for fix() and round() ufuncs
1 parent 8a27ab9 commit dcd40aa

File tree

2 files changed

+28
-11
lines changed

2 files changed

+28
-11
lines changed

dpnp/dpnp_algo/dpnp_elementwise_common.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -199,19 +199,25 @@ def __call__(
199199
if dtype is not None:
200200
x_usm = dpt.astype(x_usm, dtype, copy=False)
201201

202-
if isinstance(out, tuple):
203-
if len(out) != self.nout:
204-
raise ValueError(
205-
"'out' tuple must have exactly one entry per ufunc output"
206-
)
207-
out = out[0]
202+
out = self._unpack_out_kw(out)
208203
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
209204

210205
res_usm = super().__call__(x_usm, out=out_usm, order=order)
211206
if out is not None and isinstance(out, dpnp_array):
212207
return out
213208
return dpnp_array._create_from_usm_ndarray(res_usm)
214209

210+
def _unpack_out_kw(self, out):
211+
"""Unpack `out` keyword if passed as a tuple."""
212+
213+
if isinstance(out, tuple):
214+
if len(out) != self.nout:
215+
raise ValueError(
216+
"'out' tuple must have exactly one entry per ufunc output"
217+
)
218+
return out[0]
219+
return out
220+
215221

216222
class DPNPUnaryTwoOutputsFunc(UnaryElementwiseFunc):
217223
"""
@@ -819,15 +825,22 @@ def __call__(self, x, /, out=None, *, order="K"):
819825
pass # pass to raise error in main implementation
820826
elif dpnp.issubdtype(x.dtype, dpnp.inexact):
821827
pass # for inexact types, pass to calculate in the backend
822-
elif out is not None and not dpnp.is_supported_array_type(out):
828+
elif not (
829+
out is None
830+
or isinstance(out, tuple)
831+
or dpnp.is_supported_array_type(out)
832+
):
823833
pass # pass to raise error in main implementation
824-
elif out is not None and out.dtype != x.dtype:
834+
elif not (
835+
out is None or isinstance(out, tuple) or out.dtype == x.dtype
836+
):
825837
# passing will raise an error but with incorrect needed dtype
826838
raise ValueError(
827839
f"Output array of type {x.dtype} is needed, got {out.dtype}"
828840
)
829841
else:
830842
# for exact types, return the input
843+
out = self._unpack_out_kw(out)
831844
if out is None:
832845
return dpnp.copy(x, order=order)
833846

@@ -932,6 +945,7 @@ def __init__(
932945
def __call__(self, x, /, decimals=0, out=None, *, dtype=None):
933946
if decimals != 0:
934947
x_usm = dpnp.get_usm_ndarray(x)
948+
out = self._unpack_out_kw(out)
935949
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
936950

937951
if dpnp.issubdtype(x_usm.dtype, dpnp.integer):

dpnp/tests/test_mathematical.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2010,14 +2010,17 @@ def test_out_dtype(self, func):
20102010
fn(*args, out=out, dtype="f4")
20112011

20122012
@pytest.mark.parametrize("xp", [numpy, dpnp])
2013-
@pytest.mark.parametrize("func", ["abs", "add", "frexp"])
2013+
@pytest.mark.parametrize("func", ["abs", "fix", "round", "add", "frexp"])
20142014
def test_out_wrong_tuple_len(self, xp, func):
2015+
if func == "round" and xp is numpy:
2016+
pytest.skip("numpy.round(x, out=(...)) is not supported")
2017+
20152018
x = xp.array([1, 2, 3])
20162019

20172020
fn = getattr(xp, func)
2018-
args = [x] * fn.nin
2021+
args = [x] * getattr(fn, "nin", getattr(dpnp, func).nin)
20192022

2020-
nout = fn.nout
2023+
nout = getattr(fn, "nout", getattr(dpnp, func).nout)
20212024
outs = [(), tuple(range(nout + 1))]
20222025
if nout > 1:
20232026
outs.append(tuple(range(nout - 1)))

0 commit comments

Comments
 (0)