Skip to content

Commit e93e0f6

Browse files
authored
move 2 argument functions to descriptor part2 (#747)
1 parent ed5ba69 commit e93e0f6

File tree

3 files changed

+46
-38
lines changed

3 files changed

+46
-38
lines changed

dpnp/dpnp_algo/dpnp_algo_mathematical.pyx

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ cpdef dparray dpnp_add(object x1_obj, object x2_obj, object dtype=None, dparray
106106

107107

108108
cpdef dparray dpnp_arctan2(object x1_obj, object x2_obj, object dtype=None, dparray out=None, object where=True):
109-
return call_fptr_2in_1out(DPNP_FN_ARCTAN2, x1_obj, x2_obj, dtype=dtype, out=out, where=where)
109+
return call_fptr_2in_1out_new(DPNP_FN_ARCTAN2, x1_obj, x2_obj, dtype=dtype, out=out, where=where)
110110

111111

112112
cpdef dpnp_around(dparray x1, int decimals):
@@ -139,7 +139,7 @@ cpdef dparray dpnp_copysign(object x1_obj, object x2_obj, object dtype=None, dpa
139139

140140

141141
cpdef dparray dpnp_cross(object x1_obj, object x2_obj, object dtype=None, dparray out=None, object where=True):
142-
return call_fptr_2in_1out(DPNP_FN_CROSS, x1_obj, x2_obj, dtype=dtype, out=out, where=where)
142+
return call_fptr_2in_1out_new(DPNP_FN_CROSS, x1_obj, x2_obj, dtype=dtype, out=out, where=where)
143143

144144

145145
cpdef dparray dpnp_cumprod(dparray x1):
@@ -243,7 +243,7 @@ cpdef dparray dpnp_gradient(dparray y1, int dx=1):
243243

244244

245245
cpdef dparray dpnp_hypot(object x1_obj, object x2_obj, object dtype=None, dparray out=None, object where=True):
246-
return call_fptr_2in_1out(DPNP_FN_HYPOT, x1_obj, x2_obj, dtype=dtype, out=out, where=where)
246+
return call_fptr_2in_1out_new(DPNP_FN_HYPOT, x1_obj, x2_obj, dtype=dtype, out=out, where=where)
247247

248248

249249
cpdef dparray dpnp_maximum(object x1_obj, object x2_obj, object dtype=None, dparray out=None, object where=True):
@@ -379,7 +379,7 @@ cpdef dparray dpnp_sign(dparray x1):
379379

380380

381381
cpdef dparray dpnp_subtract(object x1_obj, object x2_obj, object dtype=None, dparray out=None, object where=True):
382-
return call_fptr_2in_1out(DPNP_FN_SUBTRACT, x1_obj, x2_obj, dtype=dtype, out=out, where=where)
382+
return call_fptr_2in_1out_new(DPNP_FN_SUBTRACT, x1_obj, x2_obj, dtype=dtype, out=out, where=where)
383383

384384

385385
cpdef dparray dpnp_sum(dparray input, object axis=None, object dtype=None, dparray out=None, cpp_bool keepdims=False, object initial=None, object where=True):

dpnp/dpnp_iface_mathematical.py

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -404,14 +404,13 @@ def cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=None):
404404
405405
"""
406406

407-
if not use_origin_backend(x1):
408-
if not isinstance(x1, dparray):
409-
pass
410-
elif not isinstance(x2, dparray):
411-
pass
412-
elif x1.size != 3 or x2.size != 3:
407+
x1_desc = dpnp.get_dpnp_descriptor(x1)
408+
x2_desc = dpnp.get_dpnp_descriptor(x2)
409+
410+
if x1_desc and x2_desc:
411+
if x1_desc.size != 3 or x2_desc.size != 3:
413412
pass
414-
elif x1.shape != (3,) or x2.shape != (3,):
413+
elif x1_desc.shape != (3,) or x2_desc.shape != (3,):
415414
pass
416415
elif axisa != -1:
417416
pass
@@ -422,7 +421,7 @@ def cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=None):
422421
elif axis is not None:
423422
pass
424423
else:
425-
return dpnp_cross(x1, x2)
424+
return dpnp_cross(x1_desc, x2_desc)
426425

427426
return call_origin(numpy.cross, x1, x2, axisa, axisb, axisc, axis)
428427

@@ -1497,23 +1496,26 @@ def subtract(x1, x2, dtype=None, out=None, where=True, **kwargs):
14971496
[2, -4]
14981497
14991498
"""
1500-
x1_is_scalar, x2_is_scalar = dpnp.isscalar(x1), dpnp.isscalar(x2)
1501-
x1_is_dparray, x2_is_dparray = isinstance(x1, dparray), isinstance(x2, dparray)
15021499

1503-
if not use_origin_backend(x1) and not kwargs:
1504-
if not x1_is_dparray and not x1_is_scalar:
1500+
x1_is_scalar = dpnp.isscalar(x1)
1501+
x2_is_scalar = dpnp.isscalar(x2)
1502+
x1_desc = dpnp.get_dpnp_descriptor(x1)
1503+
x2_desc = dpnp.get_dpnp_descriptor(x2)
1504+
1505+
if x1_desc and x2_desc and not kwargs:
1506+
if not x1_desc and not x1_is_scalar:
15051507
pass
1506-
elif not x2_is_dparray and not x2_is_scalar:
1508+
elif not x2_desc and not x2_is_scalar:
15071509
pass
15081510
elif x1_is_scalar and x2_is_scalar:
15091511
pass
1510-
elif x1_is_dparray and x1.ndim == 0:
1512+
elif x1_desc and x1_desc.ndim == 0:
15111513
pass
1512-
elif x1_is_dparray and x1.dtype == numpy.bool:
1514+
elif x1_desc and x1_desc.dtype == numpy.bool:
15131515
pass
1514-
elif x2_is_dparray and x2.ndim == 0:
1516+
elif x2_desc and x2_desc.ndim == 0:
15151517
pass
1516-
elif x2_is_dparray and x2.dtype == numpy.bool:
1518+
elif x2_desc and x2_desc.dtype == numpy.bool:
15171519
pass
15181520
elif dtype is not None:
15191521
pass
@@ -1522,7 +1524,7 @@ def subtract(x1, x2, dtype=None, out=None, where=True, **kwargs):
15221524
elif not where:
15231525
pass
15241526
else:
1525-
return dpnp_subtract(x1, x2, dtype=dtype, out=out, where=where)
1527+
return dpnp_subtract(x1_desc, x2_desc, dtype=dtype, out=out, where=where)
15261528

15271529
return call_origin(numpy.subtract, x1, x2, dtype=dtype, out=out, where=where, **kwargs)
15281530

dpnp/dpnp_iface_trigonometric.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -354,19 +354,22 @@ def arctan2(x1, x2, dtype=None, out=None, where=True, **kwargs):
354354
[1.57079633, -1.57079633]
355355
356356
"""
357-
x1_is_scalar, x2_is_scalar = dpnp.isscalar(x1), dpnp.isscalar(x2)
358-
x1_is_dparray, x2_is_dparray = isinstance(x1, dparray), isinstance(x2, dparray)
359357

360-
if not use_origin_backend(x1) and not kwargs:
361-
if not x1_is_dparray and not x1_is_scalar:
358+
x1_is_scalar = dpnp.isscalar(x1)
359+
x2_is_scalar = dpnp.isscalar(x2)
360+
x1_desc = dpnp.get_dpnp_descriptor(x1)
361+
x2_desc = dpnp.get_dpnp_descriptor(x2)
362+
363+
if x1_desc and x2_desc and not kwargs:
364+
if not x1_desc and not x1_is_scalar:
362365
pass
363-
elif not x2_is_dparray and not x2_is_scalar:
366+
elif not x2_desc and not x2_is_scalar:
364367
pass
365368
elif x1_is_scalar and x2_is_scalar:
366369
pass
367-
elif x1_is_dparray and x1.ndim == 0:
370+
elif x1_desc and x1_desc.ndim == 0:
368371
pass
369-
elif x2_is_dparray and x2.ndim == 0:
372+
elif x2_desc and x2_desc.ndim == 0:
370373
pass
371374
elif out is not None and not isinstance(out, dparray):
372375
pass
@@ -377,7 +380,7 @@ def arctan2(x1, x2, dtype=None, out=None, where=True, **kwargs):
377380
elif not where:
378381
pass
379382
else:
380-
return dpnp_arctan2(x1, x2, dtype=dtype, out=out, where=where)
383+
return dpnp_arctan2(x1_desc, x2_desc, dtype=dtype, out=out, where=where)
381384

382385
return call_origin(numpy.arctan2, x1, x2, dtype=dtype, out=out, where=where, **kwargs)
383386

@@ -624,19 +627,22 @@ def hypot(x1, x2, dtype=None, out=None, where=True, **kwargs):
624627
[5.0, 5.0, 5.0]
625628
626629
"""
627-
x1_is_scalar, x2_is_scalar = dpnp.isscalar(x1), dpnp.isscalar(x2)
628-
x1_is_dparray, x2_is_dparray = isinstance(x1, dparray), isinstance(x2, dparray)
629630

630-
if not use_origin_backend(x1) and not kwargs:
631-
if not x1_is_dparray and not x1_is_scalar:
631+
x1_is_scalar = dpnp.isscalar(x1)
632+
x2_is_scalar = dpnp.isscalar(x2)
633+
x1_desc = dpnp.get_dpnp_descriptor(x1)
634+
x2_desc = dpnp.get_dpnp_descriptor(x2)
635+
636+
if x1_desc and x2_desc and not kwargs:
637+
if not x1_desc and not x1_is_scalar:
632638
pass
633-
elif not x2_is_dparray and not x2_is_scalar:
639+
elif not x2_desc and not x2_is_scalar:
634640
pass
635641
elif x1_is_scalar and x2_is_scalar:
636642
pass
637-
elif x1_is_dparray and x1.ndim == 0:
643+
elif x1_desc and x1_desc.ndim == 0:
638644
pass
639-
elif x2_is_dparray and x2.ndim == 0:
645+
elif x2_desc and x2_desc.ndim == 0:
640646
pass
641647
elif out is not None and not isinstance(out, dparray):
642648
pass
@@ -647,7 +653,7 @@ def hypot(x1, x2, dtype=None, out=None, where=True, **kwargs):
647653
elif not where:
648654
pass
649655
else:
650-
return dpnp_hypot(x1, x2, dtype=dtype, out=out, where=where)
656+
return dpnp_hypot(x1_desc, x2_desc, dtype=dtype, out=out, where=where)
651657

652658
return call_origin(numpy.hypot, x1, x2, dtype=dtype, out=out, where=where, **kwargs)
653659

0 commit comments

Comments
 (0)