@@ -199,19 +199,25 @@ def __call__(
199199 if dtype is not None :
200200 x_usm = dpt .astype (x_usm , dtype , copy = False )
201201
202- if isinstance (out , tuple ):
203- if len (out ) != self .nout :
204- raise ValueError (
205- "'out' tuple must have exactly one entry per ufunc output"
206- )
207- out = out [0 ]
202+ out = self ._unpack_out_kw (out )
208203 out_usm = None if out is None else dpnp .get_usm_ndarray (out )
209204
210205 res_usm = super ().__call__ (x_usm , out = out_usm , order = order )
211206 if out is not None and isinstance (out , dpnp_array ):
212207 return out
213208 return dpnp_array ._create_from_usm_ndarray (res_usm )
214209
210+ def _unpack_out_kw (self , out ):
211+ """Unpack `out` keyword if passed as a tuple."""
212+
213+ if isinstance (out , tuple ):
214+ if len (out ) != self .nout :
215+ raise ValueError (
216+ "'out' tuple must have exactly one entry per ufunc output"
217+ )
218+ return out [0 ]
219+ return out
220+
215221
216222class DPNPUnaryTwoOutputsFunc (UnaryElementwiseFunc ):
217223 """
@@ -819,15 +825,22 @@ def __call__(self, x, /, out=None, *, order="K"):
819825 pass # pass to raise error in main implementation
820826 elif dpnp .issubdtype (x .dtype , dpnp .inexact ):
821827 pass # for inexact types, pass to calculate in the backend
822- elif out is not None and not dpnp .is_supported_array_type (out ):
828+ elif not (
829+ out is None
830+ or isinstance (out , tuple )
831+ or dpnp .is_supported_array_type (out )
832+ ):
823833 pass # pass to raise error in main implementation
824- elif out is not None and out .dtype != x .dtype :
834+ elif not (
835+ out is None or isinstance (out , tuple ) or out .dtype == x .dtype
836+ ):
825837 # passing will raise an error but with incorrect needed dtype
826838 raise ValueError (
827839 f"Output array of type { x .dtype } is needed, got { out .dtype } "
828840 )
829841 else :
830842 # for exact types, return the input
843+ out = self ._unpack_out_kw (out )
831844 if out is None :
832845 return dpnp .copy (x , order = order )
833846
@@ -932,6 +945,7 @@ def __init__(
932945 def __call__ (self , x , / , decimals = 0 , out = None , * , dtype = None ):
933946 if decimals != 0 :
934947 x_usm = dpnp .get_usm_ndarray (x )
948+ out = self ._unpack_out_kw (out )
935949 out_usm = None if out is None else dpnp .get_usm_ndarray (out )
936950
937951 if dpnp .issubdtype (x_usm .dtype , dpnp .integer ):
0 commit comments