Skip to content

Commit 486b00d

Browse files
authored
Accept tuple value for out keyword in ufuncs (#2664)
This PR adds support for the `out` keyword to accept a tuple passed to a ufunc. The documentation of every function is updated to reflect that change.
1 parent 7eed7da commit 486b00d

File tree

10 files changed

+395
-152
lines changed

10 files changed

+395
-152
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ Also, that release drops support for Python 3.9, making Python 3.10 the minimum
3333
* Added the missing positional-only and keyword-only parameter markers to bring the ufunc signatures into alignment with NumPy [#2660](https://github.com/IntelPython/dpnp/pull/2660)
3434
* Redesigned `dpnp.modf` function to be a part of `ufunc` and `vm` pybind11 extensions [#2654](https://github.com/IntelPython/dpnp/pull/2654)
3535
* Refactored `dpnp.fft` and `dpnp.random` submodules by removing wildcard imports and defining explicit public exports [#2649](https://github.com/IntelPython/dpnp/pull/2649)
36+
* Added support for the `out` keyword to accept a tuple, bringing ufunc signatures into alignment with those in NumPy [#2664](https://github.com/IntelPython/dpnp/pull/2664)
3637

3738
### Deprecated
3839

dpnp/dpnp_algo/dpnp_elementwise_common.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -199,13 +199,25 @@ def __call__(
199199
if dtype is not None:
200200
x_usm = dpt.astype(x_usm, dtype, copy=False)
201201

202+
out = self._unpack_out_kw(out)
202203
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
203-
res_usm = super().__call__(x_usm, out=out_usm, order=order)
204204

205+
res_usm = super().__call__(x_usm, out=out_usm, order=order)
205206
if out is not None and isinstance(out, dpnp_array):
206207
return out
207208
return dpnp_array._create_from_usm_ndarray(res_usm)
208209

210+
def _unpack_out_kw(self, out):
211+
"""Unpack `out` keyword if passed as a tuple."""
212+
213+
if isinstance(out, tuple):
214+
if len(out) != self.nout:
215+
raise ValueError(
216+
"'out' tuple must have exactly one entry per ufunc output"
217+
)
218+
return out[0]
219+
return out
220+
209221

210222
class DPNPUnaryTwoOutputsFunc(UnaryElementwiseFunc):
211223
"""
@@ -361,7 +373,7 @@ def __call__(
361373
orig_out, out = list(out), list(out)
362374
res_dts = [res1_dt, res2_dt]
363375

364-
for i in range(2):
376+
for i in range(self.nout):
365377
if out[i] is None:
366378
continue
367379

@@ -419,7 +431,7 @@ def __call__(
419431
dep_evs = copy_ev
420432

421433
# Allocate a buffer for the output arrays if needed
422-
for i in range(2):
434+
for i in range(self.nout):
423435
if out[i] is None:
424436
res_dt = res_dts[i]
425437
if order == "K":
@@ -438,7 +450,7 @@ def __call__(
438450
)
439451
_manager.add_event_pair(ht_unary_ev, unary_ev)
440452

441-
for i in range(2):
453+
for i in range(self.nout):
442454
orig_res, res = orig_out[i], out[i]
443455
if not (orig_res is None or orig_res is res):
444456
# Copy the out data from temporary buffer to original memory
@@ -606,6 +618,13 @@ def __call__(
606618

607619
x1_usm = dpnp.get_usm_ndarray_or_scalar(x1)
608620
x2_usm = dpnp.get_usm_ndarray_or_scalar(x2)
621+
622+
if isinstance(out, tuple):
623+
if len(out) != self.nout:
624+
raise ValueError(
625+
"'out' tuple must have exactly one entry per ufunc output"
626+
)
627+
out = out[0]
609628
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
610629

611630
if (
@@ -806,15 +825,22 @@ def __call__(self, x, /, out=None, *, order="K"):
806825
pass # pass to raise error in main implementation
807826
elif dpnp.issubdtype(x.dtype, dpnp.inexact):
808827
pass # for inexact types, pass to calculate in the backend
809-
elif out is not None and not dpnp.is_supported_array_type(out):
828+
elif not (
829+
out is None
830+
or isinstance(out, tuple)
831+
or dpnp.is_supported_array_type(out)
832+
):
810833
pass # pass to raise error in main implementation
811-
elif out is not None and out.dtype != x.dtype:
834+
elif not (
835+
out is None or isinstance(out, tuple) or out.dtype == x.dtype
836+
):
812837
# passing will raise an error but with incorrect needed dtype
813838
raise ValueError(
814839
f"Output array of type {x.dtype} is needed, got {out.dtype}"
815840
)
816841
else:
817842
# for exact types, return the input
843+
out = self._unpack_out_kw(out)
818844
if out is None:
819845
return dpnp.copy(x, order=order)
820846

@@ -919,6 +945,7 @@ def __init__(
919945
def __call__(self, x, /, decimals=0, out=None, *, dtype=None):
920946
if decimals != 0:
921947
x_usm = dpnp.get_usm_ndarray(x)
948+
out = self._unpack_out_kw(out)
922949
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
923950

924951
if dpnp.issubdtype(x_usm.dtype, dpnp.integer):

dpnp/dpnp_iface_bitwise.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -144,9 +144,11 @@ def binary_repr(num, width=None):
144144
First input array, expected to have an integer or boolean data type.
145145
x2 : {dpnp.ndarray, usm_ndarray, scalar}
146146
Second input array, also expected to have an integer or boolean data type.
147-
out : {None, dpnp.ndarray, usm_ndarray}, optional
147+
out : {None, dpnp.ndarray, usm_ndarray, tuple of ndarray}, optional
148148
Output array to populate.
149149
Array must have the correct shape and the expected data type.
150+
A tuple (possible only as a keyword argument) must have length equal to the
151+
number of outputs.
150152
151153
Default: ``None``.
152154
order : {None, "C", "F", "A", "K"}, optional
@@ -233,9 +235,11 @@ def binary_repr(num, width=None):
233235
----------
234236
x : {dpnp.ndarray, usm_ndarray}
235237
Input array, expected to have an integer data type.
236-
out : {None, dpnp.ndarray, usm_ndarray}, optional
238+
out : {None, dpnp.ndarray, usm_ndarray, tuple of ndarray}, optional
237239
Output array to populate.
238240
Array must have the correct shape and the expected data type.
241+
A tuple (possible only as a keyword argument) must have length equal to the
242+
number of outputs.
239243
240244
Default: ``None``.
241245
order : {None, "C", "F", "A", "K"}, optional
@@ -290,9 +294,11 @@ def binary_repr(num, width=None):
290294
First input array, expected to have an integer or boolean data type.
291295
x2 : {dpnp.ndarray, usm_ndarray, scalar}
292296
Second input array, also expected to have an integer or boolean data type.
293-
out : {None, dpnp.ndarray, usm_ndarray}, optional
297+
out : {None, dpnp.ndarray, usm_ndarray, tuple of ndarray}, optional
294298
Output array to populate.
295299
Array must have the correct shape and the expected data type.
300+
A tuple (possible only as a keyword argument) must have length equal to the
301+
number of outputs.
296302
297303
Default: ``None``.
298304
order : {None, "C", "F", "A", "K"}, optional
@@ -374,9 +380,11 @@ def binary_repr(num, width=None):
374380
First input array, expected to have an integer or boolean data type.
375381
x2 : {dpnp.ndarray, usm_ndarray, scalar}
376382
Second input array, also expected to have an integer or boolean data type.
377-
out : {None, dpnp.ndarray, usm_ndarray}, optional
383+
out : {None, dpnp.ndarray, usm_ndarray, tuple of ndarray}, optional
378384
Output array to populate.
379385
Array must have the correct shape and the expected data type.
386+
A tuple (possible only as a keyword argument) must have length equal to the
387+
number of outputs.
380388
381389
Default: ``None``.
382390
order : {None, "C", "F", "A", "K"}, optional
@@ -460,9 +468,11 @@ def binary_repr(num, width=None):
460468
----------
461469
x : {dpnp.ndarray, usm_ndarray}
462470
Input array, expected to have an integer or boolean data type.
463-
out : {None, dpnp.ndarray, usm_ndarray}, optional
471+
out : {None, dpnp.ndarray, usm_ndarray, tuple of ndarray}, optional
464472
Output array to populate.
465473
Array must have the correct shape and the expected data type.
474+
A tuple (possible only as a keyword argument) must have length equal to the
475+
number of outputs.
466476
467477
Default: ``None``.
468478
order : {None, "C", "F", "A", "K"}, optional
@@ -544,9 +554,11 @@ def binary_repr(num, width=None):
544554
x2 : {dpnp.ndarray, usm_ndarray, scalar}
545555
Second input array, also expected to have an integer data type.
546556
Each element must be greater than or equal to ``0``.
547-
out : {None, dpnp.ndarray, usm_ndarray}, optional
557+
out : {None, dpnp.ndarray, usm_ndarray, tuple of ndarray}, optional
548558
Output array to populate.
549559
Array must have the correct shape and the expected data type.
560+
A tuple (possible only as a keyword argument) must have length equal to the
561+
number of outputs.
550562
551563
Default: ``None``.
552564
order : {None, "C", "F", "A", "K"}, optional
@@ -627,9 +639,11 @@ def binary_repr(num, width=None):
627639
x2 : {dpnp.ndarray, usm_ndarray, scalar}
628640
Second input array, also expected to have an integer data type.
629641
Each element must be greater than or equal to ``0``.
630-
out : {None, dpnp.ndarray, usm_ndarray}, optional
642+
out : {None, dpnp.ndarray, usm_ndarray, tuple of ndarray}, optional
631643
Output array to populate.
632644
Array must have the correct shape and the expected data type.
645+
A tuple (possible only as a keyword argument) must have length equal to the
646+
number of outputs.
633647
634648
Default: ``None``.
635649
order : {None, "C", "F", "A", "K"}, optional

dpnp/dpnp_iface_logic.py

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -624,9 +624,11 @@ def array_equiv(a1, a2):
624624
First input array, may have any data type.
625625
x2 : {dpnp.ndarray, usm_ndarray, scalar}
626626
Second input array, also may have any data type.
627-
out : {None, dpnp.ndarray, usm_ndarray}, optional
627+
out : {None, dpnp.ndarray, usm_ndarray, tuple of ndarray}, optional
628628
Output array to populate.
629-
Array have the correct shape and the expected data type.
629+
Array must have the correct shape and the expected data type.
630+
A tuple (possible only as a keyword argument) must have length equal to the
631+
number of outputs.
630632
631633
Default: ``None``.
632634
order : {None, "C", "F", "A", "K"}, optional
@@ -704,9 +706,11 @@ def array_equiv(a1, a2):
704706
First input array, may have any data type.
705707
x2 : {dpnp.ndarray, usm_ndarray, scalar}
706708
Second input array, also may have any data type.
707-
out : {None, dpnp.ndarray, usm_ndarray}, optional
709+
out : {None, dpnp.ndarray, usm_ndarray, tuple of ndarray}, optional
708710
Output array to populate.
709711
Array must have the correct shape and the expected data type.
712+
A tuple (possible only as a keyword argument) must have length equal to the
713+
number of outputs.
710714
711715
Default: ``None``.
712716
order : {None, "C", "F", "A", "K"}, optional
@@ -778,9 +782,11 @@ def array_equiv(a1, a2):
778782
First input array, may have any data type.
779783
x2 : {dpnp.ndarray, usm_ndarray, scalar}
780784
Second input array, also may have any data type.
781-
out : {None, dpnp.ndarray, usm_ndarray}, optional
785+
out : {None, dpnp.ndarray, usm_ndarray, tuple of ndarray}, optional
782786
Output array to populate.
783787
Array must have the correct shape and the expected data type.
788+
A tuple (possible only as a keyword argument) must have length equal to the
789+
number of outputs.
784790
785791
Default: ``None``.
786792
order : {None, "C", "F", "A", "K"}, optional
@@ -1066,9 +1072,11 @@ def iscomplexobj(x):
10661072
----------
10671073
x : {dpnp.ndarray, usm_ndarray}
10681074
Input array, may have any data type.
1069-
out : {None, dpnp.ndarray, usm_ndarray}, optional
1075+
out : {None, dpnp.ndarray, usm_ndarray, tuple of ndarray}, optional
10701076
Output array to populate.
10711077
Array must have the correct shape and the expected data type.
1078+
A tuple (possible only as a keyword argument) must have length equal to the
1079+
number of outputs.
10721080
10731081
Default: ``None``.
10741082
order : {None, "C", "F", "A", "K"}, optional
@@ -1198,9 +1206,11 @@ def isfortran(a):
11981206
----------
11991207
x : {dpnp.ndarray, usm_ndarray}
12001208
Input array, may have any data type.
1201-
out : {None, dpnp.ndarray, usm_ndarray}, optional
1209+
out : {None, dpnp.ndarray, usm_ndarray, tuple of ndarray}, optional
12021210
Output array to populate.
12031211
Array must have the correct shape and the expected data type.
1212+
A tuple (possible only as a keyword argument) must have length equal to the
1213+
number of outputs.
12041214
12051215
Default: ``None``.
12061216
order : {None, "C", "F", "A", "K"}, optional
@@ -1256,9 +1266,11 @@ def isfortran(a):
12561266
----------
12571267
x : {dpnp.ndarray, usm_ndarray}
12581268
Input array, may have any data type.
1259-
out : {None, dpnp.ndarray, usm_ndarray}, optional
1269+
out : {None, dpnp.ndarray, usm_ndarray, tuple of ndarray}, optional
12601270
Output array to populate.
12611271
Array must have the correct shape and the expected data type.
1272+
A tuple (possible only as a keyword argument) must have length equal to the
1273+
number of outputs.
12621274
12631275
Default: ``None``.
12641276
order : {None, "C", "F", "A", "K"}, optional
@@ -1593,9 +1605,11 @@ def isscalar(element):
15931605
First input array, may have any data type.
15941606
x2 : {dpnp.ndarray, usm_ndarray, scalar}
15951607
Second input array, also may have any data type.
1596-
out : {None, dpnp.ndarray, usm_ndarray}, optional
1608+
out : {None, dpnp.ndarray, usm_ndarray, tuple of ndarray}, optional
15971609
Output array to populate.
15981610
Array must have the correct shape and the expected data type.
1611+
A tuple (possible only as a keyword argument) must have length equal to the
1612+
number of outputs.
15991613
16001614
Default: ``None``.
16011615
order : {None, "C", "F", "A", "K"}, optional
@@ -1667,9 +1681,11 @@ def isscalar(element):
16671681
First input array, may have any data type.
16681682
x2 : {dpnp.ndarray, usm_ndarray, scalar}
16691683
Second input array, also may have any data type.
1670-
out : {None, dpnp.ndarray, usm_ndarray}, optional
1684+
out : {None, dpnp.ndarray, usm_ndarray, tuple of ndarray}, optional
16711685
Output array to populate.
16721686
Array must have the correct shape and the expected data type.
1687+
A tuple (possible only as a keyword argument) must have length equal to the
1688+
number of outputs.
16731689
16741690
Default: ``None``.
16751691
order : {None, "C", "F", "A", "K"}, optional
@@ -1740,9 +1756,11 @@ def isscalar(element):
17401756
First input array, may have any data type.
17411757
x2 : {dpnp.ndarray, usm_ndarray, scalar}
17421758
Second input array, also may have any data type.
1743-
out : {None, dpnp.ndarray, usm_ndarray}, optional
1759+
out : {None, dpnp.ndarray, usm_ndarray, tuple of ndarray}, optional
17441760
Output array to populate.
17451761
Array must have the correct shape and the expected data type.
1762+
A tuple (possible only as a keyword argument) must have length equal to the
1763+
number of outputs.
17461764
17471765
Default: ``None``.
17481766
order : {None, "C", "F", "A", "K"}, optional
@@ -1813,9 +1831,11 @@ def isscalar(element):
18131831
----------
18141832
x : {dpnp.ndarray, usm_ndarray}
18151833
Input array, may have any data type.
1816-
out : {None, dpnp.ndarray, usm_ndarray}, optional
1834+
out : {None, dpnp.ndarray, usm_ndarray, tuple of ndarray}, optional
18171835
Output array to populate.
18181836
Array must have the correct shape and the expected data type.
1837+
A tuple (possible only as a keyword argument) must have length equal to the
1838+
number of outputs.
18191839
18201840
Default: ``None``.
18211841
order : {None, "C", "F", "A", "K"}, optional
@@ -1879,9 +1899,11 @@ def isscalar(element):
18791899
First input array, may have any data type.
18801900
x2 : {dpnp.ndarray, usm_ndarray, scalar}
18811901
Second input array, also may have any data type.
1882-
out : {None, dpnp.ndarray, usm_ndarray}, optional
1902+
out : {None, dpnp.ndarray, usm_ndarray, tuple of ndarray}, optional
18831903
Output array to populate.
18841904
Array must have the correct shape and the expected data type.
1905+
A tuple (possible only as a keyword argument) must have length equal to the
1906+
number of outputs.
18851907
18861908
Default: ``None``.
18871909
order : {None, "C", "F", "A", "K"}, optional
@@ -1955,9 +1977,11 @@ def isscalar(element):
19551977
First input array, may have any data type.
19561978
x2 : {dpnp.ndarray, usm_ndarray, scalar}
19571979
Second input array, also may have any data type.
1958-
out : {None, dpnp.ndarray, usm_ndarray}, optional
1980+
out : {None, dpnp.ndarray, usm_ndarray, tuple of ndarray}, optional
19591981
Output array to populate.
19601982
Array must have the correct shape and the expected data type.
1983+
A tuple (possible only as a keyword argument) must have length equal to the
1984+
number of outputs.
19611985
19621986
Default: ``None``.
19631987
order : {None, "C", "F", "A", "K"}, optional
@@ -2029,9 +2053,11 @@ def isscalar(element):
20292053
First input array, may have any data type.
20302054
x2 : {dpnp.ndarray, usm_ndarray, scalar}
20312055
Second input array, also may have any data type.
2032-
out : {None, dpnp.ndarray, usm_ndarray}, optional
2056+
out : {None, dpnp.ndarray, usm_ndarray, tuple of ndarray}, optional
20332057
Output array to populate.
20342058
Array must have the correct shape and the expected data type.
2059+
A tuple (possible only as a keyword argument) must have length equal to the
2060+
number of outputs.
20352061
20362062
Default: ``None``.
20372063
order : {None, "C", "F", "A", "K"}, optional

0 commit comments

Comments
 (0)