Skip to content

Commit 4dc8704

Browse files
authored
Math funcs to desc (#769)
1 parent 87d8486 commit 4dc8704

File tree

5 files changed

+80
-87
lines changed

5 files changed

+80
-87
lines changed

dpnp/dpnp_algo/dpnp_algo_mathematical.pyx

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ ctypedef void(*ftpr_custom_trapz_2in_1out_with_2size_t)(void *, void * , void *
8080
ctypedef void(*ftpr_custom_around_1in_1out_t)(const void * , void * , const size_t, const int)
8181

8282

83-
cpdef dparray dpnp_absolute(dparray input):
83+
cpdef dparray dpnp_absolute(utils.dpnp_descriptor input):
8484
cdef dparray_shape_type input_shape = input.shape
8585
cdef size_t input_shape_size = input.ndim
8686

@@ -109,7 +109,7 @@ cpdef dparray dpnp_arctan2(object x1_obj, object x2_obj, object dtype=None, dpar
109109
return call_fptr_2in_1out(DPNP_FN_ARCTAN2, x1_obj, x2_obj, dtype=dtype, out=out, where=where)
110110

111111

112-
cpdef dpnp_around(dparray x1, int decimals):
112+
cpdef dpnp_around(utils.dpnp_descriptor x1, int decimals):
113113

114114
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(x1.dtype)
115115

@@ -166,7 +166,7 @@ cpdef dparray dpnp_cumsum(utils.dpnp_descriptor x1):
166166
return call_fptr_1in_1out(DPNP_FN_CUMSUM, x1, (x1.size,))
167167

168168

169-
cpdef dparray dpnp_diff(dparray input, int n):
169+
cpdef dparray dpnp_diff(object input, int n):
170170
if n == 0:
171171
return input
172172
if n < input.shape[-1]:
@@ -221,7 +221,7 @@ cpdef dparray dpnp_fmod(object x1_obj, object x2_obj, object dtype=None, dparray
221221
return call_fptr_2in_1out(DPNP_FN_FMOD, x1_obj, x2_obj, dtype=dtype, out=out, where=where)
222222

223223

224-
cpdef dparray dpnp_gradient(dparray y1, int dx=1):
224+
cpdef dparray dpnp_gradient(object y1, int dx=1):
225225

226226
size = y1.size
227227

@@ -277,7 +277,7 @@ cpdef dparray dpnp_multiply(object x1_obj, object x2_obj, object dtype=None, dpa
277277
return call_fptr_2in_1out(DPNP_FN_MULTIPLY, x1_obj, x2_obj, dtype=dtype, out=out, where=where)
278278

279279

280-
cpdef dparray dpnp_nancumprod(dparray x1):
280+
cpdef dparray dpnp_nancumprod(utils.dpnp_descriptor x1):
281281

282282
cur_x1 = dpnp.copy(x1)
283283

@@ -289,7 +289,7 @@ cpdef dparray dpnp_nancumprod(dparray x1):
289289
return dpnp_cumprod(x1_desc)
290290

291291

292-
cpdef dparray dpnp_nancumsum(dparray x1):
292+
cpdef dparray dpnp_nancumsum(utils.dpnp_descriptor x1):
293293

294294
cur_x1 = dpnp.copy(x1)
295295

@@ -301,7 +301,7 @@ cpdef dparray dpnp_nancumsum(dparray x1):
301301
return dpnp_cumsum(x1_desc)
302302

303303

304-
cpdef dpnp_nanprod(dparray x1):
304+
cpdef dpnp_nanprod(object x1):
305305
cdef dparray result = dparray(x1.shape, dtype=x1.dtype)
306306

307307
for i in range(result.size):
@@ -312,10 +312,11 @@ cpdef dpnp_nanprod(dparray x1):
312312
else:
313313
result._setitem_scalar(i, input_elem)
314314

315-
return dpnp_prod(result)
315+
result_desc = dpnp.get_dpnp_descriptor(result) # TODO remove it later
316+
return dpnp_prod(result_desc)
316317

317318

318-
cpdef dpnp_nansum(dparray x1):
319+
cpdef dpnp_nansum(object x1):
319320
cdef dparray result = dparray(x1.shape, dtype=x1.dtype)
320321

321322
for i in range(result.size):
@@ -328,7 +329,9 @@ cpdef dpnp_nansum(dparray x1):
328329

329330
# due to bug in dpnp_sum need this workaround
330331
# return dpnp_sum(result)
331-
sum_result = dpnp_sum(result)
332+
333+
result_desc = dpnp.get_dpnp_descriptor(result) # TODO remove it later
334+
sum_result = dpnp_sum(result_desc)
332335
return x1.dtype.type(sum_result[0])
333336

334337

@@ -340,7 +343,7 @@ cpdef dparray dpnp_power(object x1_obj, object x2_obj, object dtype=None, dparra
340343
return call_fptr_2in_1out(DPNP_FN_POWER, x1_obj, x2_obj, dtype=dtype, out=out, where=where)
341344

342345

343-
cpdef dparray dpnp_prod(dparray input, object axis=None, object dtype=None, dparray out=None, cpp_bool keepdims=False, object initial=None, object where=True):
346+
cpdef dparray dpnp_prod(utils.dpnp_descriptor input, object axis=None, object dtype=None, dparray out=None, cpp_bool keepdims=False, object initial=None, object where=True):
344347
"""
345348
input:float64 : outout:float64 : name:prod
346349
input:float32 : outout:float32 : name:prod
@@ -384,7 +387,7 @@ cpdef dparray dpnp_subtract(object x1_obj, object x2_obj, object dtype=None, dpa
384387
return call_fptr_2in_1out(DPNP_FN_SUBTRACT, x1_obj, x2_obj, dtype=dtype, out=out, where=where)
385388

386389

387-
cpdef dparray dpnp_sum(dparray input, object axis=None, object dtype=None, dparray out=None, cpp_bool keepdims=False, object initial=None, object where=True):
390+
cpdef dparray dpnp_sum(utils.dpnp_descriptor input, object axis=None, object dtype=None, dparray out=None, cpp_bool keepdims=False, object initial=None, object where=True):
388391

389392
cdef dparray_shape_type input_shape = input.shape
390393
cdef DPNPFuncType input_c_type = dpnp_dtype_to_DPNPFuncType(input.dtype)
@@ -407,7 +410,7 @@ cpdef dparray dpnp_sum(dparray input, object axis=None, object dtype=None, dparr
407410
return result
408411

409412

410-
cpdef dpnp_trapz(dparray y1, dparray x1, double dx):
413+
cpdef dpnp_trapz(utils.dpnp_descriptor y1, dparray x1, double dx):
411414

412415
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(y1.dtype)
413416
cdef DPNPFuncType param2_type = dpnp_dtype_to_DPNPFuncType(x1.dtype)

dpnp/dpnp_algo/dpnp_algo_statistics.pyx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ cdef dparray call_fptr_custom_std_var_1in_1out(DPNPFuncName fptr_name, dparray a
8282
return result
8383

8484

85-
cpdef dpnp_average(dparray x1):
85+
cpdef dpnp_average(utils.dpnp_descriptor x1):
8686
array_sum = dpnp_sum(x1)
8787

8888
""" Numpy interface inconsistency """

dpnp/dpnp_iface_indexing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,12 +42,12 @@
4242

4343
import collections
4444

45-
import numpy
46-
4745
from dpnp.dpnp_algo import *
4846
from dpnp.dparray import dparray
4947
from dpnp.dpnp_utils import *
48+
5049
import dpnp
50+
import numpy
5151

5252

5353
__all__ = [

dpnp/dpnp_iface_mathematical.py

Lines changed: 54 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -142,12 +142,14 @@ def absolute(x1, **kwargs):
142142
143143
"""
144144

145-
is_input_dparray = isinstance(x1, dparray)
146-
147-
if not use_origin_backend(x1) and is_input_dparray and x1.ndim != 0 and not kwargs:
148-
result = dpnp_absolute(x1)
145+
x1_desc = dpnp.get_dpnp_descriptor(x1)
146+
if x1_desc and not kwargs:
147+
if not x1_desc.ndim:
148+
pass
149+
else:
150+
result = dpnp_absolute(x1_desc)
149151

150-
return result
152+
return result
151153

152154
return call_origin(numpy.absolute, x1, **kwargs)
153155

@@ -236,15 +238,14 @@ def around(x1, decimals=0, out=None):
236238
237239
"""
238240

239-
if not use_origin_backend(x1):
240-
if not isinstance(x1, dparray):
241-
pass
242-
elif out is not None:
241+
x1_desc = dpnp.get_dpnp_descriptor(x1)
242+
if x1_desc:
243+
if out is not None:
243244
pass
244245
elif decimals != 0:
245246
pass
246247
else:
247-
return dpnp_around(x1, decimals)
248+
return dpnp_around(x1_desc, decimals)
248249

249250
return call_origin(numpy.around, x1, decimals=decimals, out=out)
250251

@@ -483,7 +484,7 @@ def cumsum(x1, **kwargs):
483484
return call_origin(numpy.cumsum, x1, **kwargs)
484485

485486

486-
def diff(input, n=1, axis=-1, prepend=None, append=None):
487+
def diff(x1, n=1, axis=-1, prepend=None, append=None):
487488
"""
488489
Calculate the n-th discrete difference along the given axis.
489490
@@ -496,10 +497,9 @@ def diff(input, n=1, axis=-1, prepend=None, append=None):
496497
Otherwise the function will be executed sequentially on CPU.
497498
"""
498499

499-
if not use_origin_backend(input):
500-
if not isinstance(input, dparray):
501-
pass
502-
elif not isinstance(n, int):
500+
x1_desc = dpnp.get_dpnp_descriptor(x1)
501+
if x1_desc:
502+
if not isinstance(n, int):
503503
pass
504504
elif n < 1:
505505
pass
@@ -510,9 +510,9 @@ def diff(input, n=1, axis=-1, prepend=None, append=None):
510510
elif append is not None:
511511
pass
512512
else:
513-
return dpnp_diff(input, n)
513+
return dpnp_diff(x1, n)
514514

515-
return call_origin(numpy.diff, input, n, axis, prepend, append)
515+
return call_origin(numpy.diff, x1, n, axis, prepend, append)
516516

517517

518518
def divide(x1, x2, dtype=None, out=None, where=True, **kwargs):
@@ -848,7 +848,7 @@ def fmod(x1, x2, dtype=None, out=None, where=True, **kwargs):
848848
return call_origin(numpy.fmod, x1, x2, dtype=dtype, out=out, where=where, **kwargs)
849849

850850

851-
def gradient(y1, *varargs, **kwargs):
851+
def gradient(x1, *varargs, **kwargs):
852852
"""
853853
Return the gradient of an array.
854854
@@ -874,20 +874,20 @@ def gradient(y1, *varargs, **kwargs):
874874
[0.5, 0.75, 1.25, 1.75, 2.25, 2.5]
875875
876876
"""
877-
if not use_origin_backend(y1) and not kwargs:
878-
if not isinstance(y1, dparray):
879-
pass
880-
elif len(varargs) > 1:
877+
878+
x1_desc = dpnp.get_dpnp_descriptor(x1)
879+
if x1_desc and not kwargs:
880+
if len(varargs) > 1:
881881
pass
882882
elif len(varargs) == 1 and not isinstance(varargs[0], int):
883883
pass
884884
else:
885885
if len(varargs) == 0:
886-
return dpnp_gradient(y1)
886+
return dpnp_gradient(x1)
887887

888-
return dpnp_gradient(y1, varargs[0])
888+
return dpnp_gradient(x1, varargs[0])
889889

890-
return call_origin(numpy.gradient, y1, *varargs, **kwargs)
890+
return call_origin(numpy.gradient, x1, *varargs, **kwargs)
891891

892892

893893
def maximum(x1, x2, dtype=None, out=None, where=True, **kwargs):
@@ -1136,11 +1136,9 @@ def nancumprod(x1, **kwargs):
11361136
11371137
"""
11381138

1139-
if not use_origin_backend(x1) and not kwargs:
1140-
if not isinstance(x1, dparray):
1141-
pass
1142-
else:
1143-
return dpnp_nancumprod(x1)
1139+
x1_desc = dpnp.get_dpnp_descriptor(x1)
1140+
if x1_desc and not kwargs:
1141+
return dpnp_nancumprod(x1_desc)
11441142

11451143
return call_origin(numpy.nancumprod, x1, **kwargs)
11461144

@@ -1174,11 +1172,9 @@ def nancumsum(x1, **kwargs):
11741172
11751173
"""
11761174

1177-
if not use_origin_backend(x1) and not kwargs:
1178-
if not isinstance(x1, dparray):
1179-
pass
1180-
else:
1181-
return dpnp_nancumsum(x1)
1175+
x1_desc = dpnp.get_dpnp_descriptor(x1)
1176+
if x1_desc and not kwargs:
1177+
return dpnp_nancumsum(x1_desc)
11821178

11831179
return call_origin(numpy.nancumsum, x1, **kwargs)
11841180

@@ -1206,9 +1202,8 @@ def nanprod(x1, **kwargs):
12061202
12071203
"""
12081204

1209-
is_x1_dparray = isinstance(x1, dparray)
1210-
1211-
if (not use_origin_backend(x1) and is_x1_dparray and not kwargs):
1205+
x1_desc = dpnp.get_dpnp_descriptor(x1)
1206+
if x1_desc and not kwargs:
12121207
return dpnp_nanprod(x1)
12131208

12141209
return call_origin(numpy.nanprod, x1, **kwargs)
@@ -1237,9 +1232,8 @@ def nansum(x1, **kwargs):
12371232
12381233
"""
12391234

1240-
is_x1_dparray = isinstance(x1, dparray)
1241-
1242-
if (not use_origin_backend(x1) and is_x1_dparray and not kwargs):
1235+
x1_desc = dpnp.get_dpnp_descriptor(x1)
1236+
if x1_desc and not kwargs:
12431237
return dpnp_nansum(x1)
12441238

12451239
return call_origin(numpy.nansum, x1, **kwargs)
@@ -1360,16 +1354,15 @@ def prod(x1, axis=None, dtype=None, out=None, keepdims=False, initial=None, wher
13601354
13611355
"""
13621356

1363-
if not use_origin_backend(x1):
1364-
if not isinstance(x1, dparray):
1365-
pass
1366-
elif out is not None and not isinstance(out, dparray):
1357+
x1_desc = dpnp.get_dpnp_descriptor(x1)
1358+
if x1_desc:
1359+
if out is not None and not isinstance(out, dparray):
13671360
pass
13681361
elif where is not True:
13691362
pass
13701363
else:
1371-
result_obj = dpnp_prod(x1, axis, dtype, out, keepdims, initial, where)
1372-
result = dpnp.convert_single_elem_array_to_scalar(result_obj, keepdims)
1364+
result_obj = dpnp_prod(x1_desc, axis, dtype, out, keepdims, initial, where)
1365+
result = dpnp.convert_single_elem_array_to_scalar(result_obj, keepdims)
13731366

13741367
return result
13751368

@@ -1540,23 +1533,22 @@ def sum(x1, axis=None, dtype=None, out=None, keepdims=False, initial=None, where
15401533
15411534
"""
15421535

1543-
if not use_origin_backend(x1):
1544-
if not isinstance(x1, dparray):
1545-
pass
1546-
elif out is not None and not isinstance(out, dparray):
1536+
x1_desc = dpnp.get_dpnp_descriptor(x1)
1537+
if x1_desc:
1538+
if out is not None and not isinstance(out, dparray):
15471539
pass
15481540
elif where is not True:
15491541
pass
15501542
else:
1551-
result_obj = dpnp_sum(x1, axis, dtype, out, keepdims, initial, where)
1543+
result_obj = dpnp_sum(x1_desc, axis, dtype, out, keepdims, initial, where)
15521544
result = dpnp.convert_single_elem_array_to_scalar(result_obj, keepdims)
15531545

15541546
return result
15551547

15561548
return call_origin(numpy.sum, x1, axis=axis, dtype=dtype, out=out, keepdims=keepdims, initial=initial, where=where)
15571549

15581550

1559-
def trapz(y, x=None, dx=1.0, **kwargs):
1551+
def trapz(y, x=None, dx=1.0, axis=-1):
15601552
"""
15611553
Integrate along the given axis using the composite trapezoidal rule.
15621554
@@ -1583,25 +1575,23 @@ def trapz(y, x=None, dx=1.0, **kwargs):
15831575
15841576
"""
15851577

1586-
if not use_origin_backend(y):
1587-
1588-
if not isinstance(y, dparray):
1589-
pass
1590-
elif not isinstance(x, dparray) and x is not None:
1578+
y_desc = dpnp.get_dpnp_descriptor(y)
1579+
if y_desc:
1580+
if not isinstance(x, dparray) and x is not None:
15911581
pass
1592-
elif x is not None and y.size != x.size:
1582+
elif x is not None and y_desc.size != x.size:
15931583
pass
1594-
elif x is not None and y.shape != x.shape:
1584+
elif x is not None and y_desc.shape != x.shape:
15951585
pass
1596-
elif y.ndim > 1:
1586+
elif y_desc.ndim > 1:
15971587
pass
15981588
else:
15991589
if x is None:
1600-
x = dpnp.empty(0, dtype=y.dtype)
1590+
x = dpnp.empty(0, dtype=y_desc.dtype)
16011591

1602-
return dpnp_trapz(y, x, dx)
1592+
return dpnp_trapz(y_desc, x, dx)
16031593

1604-
return call_origin(numpy.trapz, y, x=x, dx=dx, **kwargs)
1594+
return call_origin(numpy.trapz, y, x, dx, axis)
16051595

16061596

16071597
def true_divide(*args, **kwargs):

0 commit comments

Comments
 (0)