Skip to content

Commit f5cc9de

Browse files
authored
Handling 'out' parameter in math mod (#876)
1 parent b41271f commit f5cc9de

File tree

1 file changed

+8
-13
lines changed

1 file changed

+8
-13
lines changed

dpnp/dpnp_iface_mathematical.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@
4141

4242

4343
from dpnp.dpnp_algo import *
44-
from dpnp.dparray import dparray
4544
from dpnp.dpnp_utils import *
4645

4746
import dpnp
@@ -196,16 +195,15 @@ def add(x1, x2, dtype=None, out=None, where=True, **kwargs):
196195
pass
197196
elif x2_desc and x2_desc.ndim == 0:
198197
pass
199-
elif out is not None and not isinstance(out, dparray):
200-
pass
201198
elif dtype is not None:
202199
pass
203200
elif out is not None:
204201
pass
205202
elif not where:
206203
pass
207204
else:
208-
return dpnp_add(x1_desc, x2_desc, dtype=dtype, out=out, where=where).get_pyobj()
205+
out_desc = dpnp.get_dpnp_descriptor(out) if out is not None else None
206+
return dpnp_add(x1_desc, x2_desc, dtype, out_desc, where).get_pyobj()
209207

210208
return call_origin(numpy.add, x1, x2, dtype=dtype, out=out, where=where, **kwargs)
211209

@@ -750,8 +748,6 @@ def floor_divide(x1, x2, dtype=None, out=None, where=True, **kwargs):
750748
pass
751749
elif x1_desc and x2_desc and x1_desc.shape != x2_desc.shape:
752750
pass
753-
elif out is not None and not isinstance(out, dparray):
754-
pass
755751
elif dtype is not None:
756752
pass
757753
elif out is not None:
@@ -761,7 +757,8 @@ def floor_divide(x1, x2, dtype=None, out=None, where=True, **kwargs):
761757
elif x1_is_scalar and x2_desc.ndim > 1:
762758
pass
763759
else:
764-
return dpnp_floor_divide(x1_desc, x2_desc, out=out, where=where, dtype=dtype)
760+
out_desc = dpnp.get_dpnp_descriptor(out) if out is not None else None
761+
return dpnp_floor_divide(x1_desc, x2_desc, dtype, out_desc, where)
765762

766763
return call_origin(numpy.floor_divide, x1, x2, out=out, where=where, dtype=dtype, **kwargs)
767764

@@ -854,16 +851,15 @@ def fmod(x1, x2, dtype=None, out=None, where=True, **kwargs):
854851
pass
855852
elif x2_desc and x2.ndim == 0:
856853
pass
857-
elif out is not None and not isinstance(out, dparray):
858-
pass
859854
elif dtype is not None:
860855
pass
861856
elif out is not None:
862857
pass
863858
elif not where:
864859
pass
865860
else:
866-
return dpnp_fmod(x1_desc, x2_desc, dtype=dtype, out=out, where=where).get_pyobj()
861+
out_desc = dpnp.get_dpnp_descriptor(out) if out is not None else None
862+
return dpnp_fmod(x1_desc, x2_desc, dtype, out_desc, where).get_pyobj()
867863

868864
return call_origin(numpy.fmod, x1, x2, dtype=dtype, out=out, where=where, **kwargs)
869865

@@ -1438,8 +1434,6 @@ def remainder(x1, x2, out=None, where=True, dtype=None, **kwargs):
14381434
pass
14391435
elif x1_desc and x2_desc and x1_desc.shape != x2_desc.shape:
14401436
pass
1441-
elif out is not None and not isinstance(out, dparray):
1442-
pass
14431437
elif dtype is not None:
14441438
pass
14451439
elif out is not None:
@@ -1449,7 +1443,8 @@ def remainder(x1, x2, out=None, where=True, dtype=None, **kwargs):
14491443
elif x1_is_scalar and x2_desc.ndim > 1:
14501444
pass
14511445
else:
1452-
return dpnp_remainder(x1_desc, x2_desc, out=out, where=where, dtype=dtype)
1446+
out_desc = dpnp.get_dpnp_descriptor(out) if out is not None else None
1447+
return dpnp_remainder(x1_desc, x2_desc, dtype, out_desc, where)
14531448

14541449
return call_origin(numpy.remainder, x1, x2, out=out, where=where, dtype=dtype, **kwargs)
14551450

0 commit comments

Comments
 (0)