Skip to content

Commit de7cf34

Browse files
authored
Use desc in sort (#764)
1 parent f549605 commit de7cf34

File tree

4 files changed

+41
-43
lines changed

4 files changed

+41
-43
lines changed

dpnp/dpnp_algo/dpnp_algo_sorting.pyx

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,9 @@ cpdef dparray dpnp_argsort(utils.dpnp_descriptor x1):
5050
return call_fptr_1in_1out(DPNP_FN_ARGSORT, x1, x1.shape)
5151

5252

53-
cpdef dparray dpnp_partition(dparray arr, int kth, axis=-1, kind='introselect', order=None):
53+
cpdef dparray dpnp_partition(utils.dpnp_descriptor arr, int kth, axis=-1, kind='introselect', order=None):
54+
cdef dparray_shape_type shape1 = arr.shape
55+
5456
cdef size_t kth_ = kth if kth >= 0 else (arr.ndim + kth)
5557
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(arr.dtype)
5658

@@ -62,12 +64,12 @@ cpdef dparray dpnp_partition(dparray arr, int kth, axis=-1, kind='introselect',
6264

6365
cdef fptr_dpnp_partition_t func = <fptr_dpnp_partition_t > kernel_data.ptr
6466

65-
func(arr.get_data(), arr2.get_data(), result.get_data(), kth_, < size_t * > arr._dparray_shape.data(), arr.ndim)
67+
func(arr.get_data(), arr2.get_data(), result.get_data(), kth_, < size_t * > shape1.data(), arr.ndim)
6668

6769
return result
6870

6971

70-
cpdef dparray dpnp_searchsorted(dparray arr, dparray v, side='left'):
72+
cpdef dparray dpnp_searchsorted(utils.dpnp_descriptor arr, utils.dpnp_descriptor v, side='left'):
7173
if side is 'left':
7274
side_ = True
7375
else:

dpnp/dpnp_iface_counting.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,9 @@
4141

4242

4343
from dpnp.dpnp_algo.dpnp_algo import * # TODO need to investigate why dpnp.dpnp_algo can not be used
44-
from dpnp.dparray import dparray
44+
45+
import dpnp
46+
import numpy
4547

4648
# full module name because dpnp_iface_counting loaded from cython too early
4749
from dpnp.dpnp_utils.dpnp_algo_utils import *
@@ -54,7 +56,7 @@
5456
]
5557

5658

57-
def count_nonzero(in_array1, axis=None, *, keepdims=False):
59+
def count_nonzero(x1, axis=None, *, keepdims=False):
5860
"""
5961
Counts the number of non-zero values in the array ``in_array1``.
6062
@@ -77,17 +79,16 @@ def count_nonzero(in_array1, axis=None, *, keepdims=False):
7779
7880
"""
7981

80-
is_dparray1 = isinstance(in_array1, dparray)
81-
82-
if (not use_origin_backend(in_array1) and is_dparray1):
82+
x1_desc = dpnp.get_dpnp_descriptor(x1)
83+
if x1_desc:
8384
if axis is not None:
84-
checker_throw_value_error("count_nonzero", "axis", type(axis), None)
85-
if keepdims is not False:
86-
checker_throw_value_error("count_nonzero", "keepdims", keepdims, False)
87-
88-
result_obj = dpnp_count_nonzero(in_array1)
89-
result = dpnp.convert_single_elem_array_to_scalar(result_obj)
85+
pass
86+
elif keepdims is not False:
87+
pass
88+
else:
89+
result_obj = dpnp_count_nonzero(x1)
90+
result = dpnp.convert_single_elem_array_to_scalar(result_obj)
9091

91-
return result
92+
return result
9293

93-
return numpy.count_nonzero(in_array1, axis, keepdims=keepdims)
94+
return numpy.count_nonzero(x1, axis, keepdims=keepdims)

dpnp/dpnp_iface_sorting.py

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def argsort(in_array1, axis=-1, kind=None, order=None):
104104
return numpy.argsort(in_array1, axis, kind, order)
105105

106106

107-
def partition(arr, kth, axis=-1, kind='introselect', order=None):
107+
def partition(x1, kth, axis=-1, kind='introselect', order=None):
108108
"""
109109
Return a partitioned copy of an array.
110110
For full documentation refer to :obj:`numpy.partition`.
@@ -115,12 +115,12 @@ def partition(arr, kth, axis=-1, kind='introselect', order=None):
115115
Input kth is supported as :obj:`int`.
116116
Parameters ``axis``, ``kind`` and ``order`` are supported only with default values.
117117
"""
118-
if not use_origin_backend():
119-
if not isinstance(arr, dparray):
120-
pass
121-
elif not isinstance(kth, int):
118+
119+
x1_desc = dpnp.get_dpnp_descriptor(x1)
120+
if x1_desc:
121+
if not isinstance(kth, int):
122122
pass
123-
elif kth >= arr.shape[arr.ndim - 1] or arr.ndim + kth < 0:
123+
elif kth >= x1_desc.shape[x1_desc.ndim - 1] or x1_desc.ndim + kth < 0:
124124
pass
125125
elif axis != -1:
126126
pass
@@ -129,12 +129,12 @@ def partition(arr, kth, axis=-1, kind='introselect', order=None):
129129
elif order is not None:
130130
pass
131131
else:
132-
return dpnp_partition(arr, kth, axis, kind, order)
132+
return dpnp_partition(x1_desc, kth, axis, kind, order)
133133

134-
return call_origin(numpy.partition, arr, kth, axis, kind, order)
134+
return call_origin(numpy.partition, x1, kth, axis, kind, order)
135135

136136

137-
def searchsorted(arr, v, side='left', sorter=None):
137+
def searchsorted(x1, x2, side='left', sorter=None):
138138
"""
139139
Find indices where elements should be inserted to maintain order.
140140
For full documentation refer to :obj:`numpy.searchsorted`.
@@ -146,27 +146,24 @@ def searchsorted(arr, v, side='left', sorter=None):
146146
Input side is supported only values ``left``, ``right``.
147147
Parameters ``sorter`` is supported only with default values.
148148
"""
149-
if not use_origin_backend():
150-
if not isinstance(arr, dparray):
151-
pass
152-
elif not isinstance(v, dparray):
153-
pass
154-
elif arr.ndim != 1:
149+
150+
x1_desc = dpnp.get_dpnp_descriptor(x1)
151+
x2_desc = dpnp.get_dpnp_descriptor(x2)
152+
if 0 and x1_desc and x2_desc:
153+
if x1_desc.ndim != 1:
155154
pass
156-
elif arr.dtype != v.dtype:
155+
elif x1_desc.dtype != x2_desc.dtype:
157156
pass
158157
elif side not in ['left', 'right']:
159158
pass
160159
elif sorter is not None:
161160
pass
162-
elif arr.size < 2:
163-
pass
164-
elif dpnp.sort(arr) != arr:
161+
elif x1_desc.size < 2:
165162
pass
166163
else:
167-
return dpnp_searchsorted(arr, v, side=side)
164+
return dpnp_searchsorted(x1_desc, x2_desc, side=side)
168165

169-
return call_origin(numpy.searchsorted, arr, v, side=side, sorter=sorter)
166+
return call_origin(numpy.searchsorted, x1, x2, side=side, sorter=sorter)
170167

171168

172169
def sort(x1, **kwargs):

dpnp/dpnp_iface_trigonometric.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1061,10 +1061,8 @@ def unwrap(x1):
10611061
10621062
"""
10631063

1064-
if (use_origin_backend(x1)):
1065-
return call_origin(numpy.unwrap, x1, **kwargs)
1066-
1067-
if not isinstance(x1, dparray):
1068-
raise TypeError(f"DPNP unwrap(): Unsupported x1={type(x1)}")
1064+
x1_desc = dpnp.get_dpnp_descriptor(x1)
1065+
if x1_desc:
1066+
return dpnp_unwrap(x1_desc)
10691067

1070-
return dpnp_unwrap(x1)
1068+
return call_origin(numpy.unwrap, x1, **kwargs)

0 commit comments

Comments
 (0)