Skip to content

Commit 87d8486

Browse files
authored
argmin/argmax to desc (#767)
1 parent 457c79e commit 87d8486

File tree

3 files changed

+28
-30
lines changed

3 files changed

+28
-30
lines changed

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -323,8 +323,8 @@ cpdef dparray dpnp_sort(dpnp_descriptor array1)
323323
"""
324324
Searching functions
325325
"""
326-
cpdef dparray dpnp_argmax(dparray array1)
327-
cpdef dparray dpnp_argmin(dparray array1)
326+
cpdef dparray dpnp_argmax(dpnp_descriptor array1)
327+
cpdef dparray dpnp_argmin(dpnp_descriptor array1)
328328

329329
"""
330330
Trigonometric functions

dpnp/dpnp_algo/dpnp_algo_searching.pyx

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,9 @@ __all__ += [
4444
ctypedef void(*custom_search_1in_1out_func_ptr_t)(void * , void * , size_t)
4545

4646

47-
cpdef dparray dpnp_argmax(dparray in_array1):
47+
cpdef dparray dpnp_argmax(utils.dpnp_descriptor in_array1):
4848
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(in_array1.dtype)
49-
cdef DPNPFuncType output_type = dpnp_dtype_to_DPNPFuncType(numpy.int64)
49+
cdef DPNPFuncType output_type = dpnp_dtype_to_DPNPFuncType(dpnp.int64)
5050

5151
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_ARGMAX, param1_type, output_type)
5252

@@ -60,9 +60,9 @@ cpdef dparray dpnp_argmax(dparray in_array1):
6060
return result
6161

6262

63-
cpdef dparray dpnp_argmin(dparray in_array1):
63+
cpdef dparray dpnp_argmin(utils.dpnp_descriptor in_array1):
6464
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(in_array1.dtype)
65-
cdef DPNPFuncType output_type = dpnp_dtype_to_DPNPFuncType(numpy.int64)
65+
cdef DPNPFuncType output_type = dpnp_dtype_to_DPNPFuncType(dpnp.int64)
6666

6767
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_ARGMIN, param1_type, output_type)
6868

dpnp/dpnp_iface_searching.py

Lines changed: 22 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@
5454
]
5555

5656

57-
def argmax(in_array1, axis=None, out=None):
57+
def argmax(x1, axis=None, out=None):
5858
"""
5959
Returns the indices of the maximum values along an axis.
6060
@@ -94,23 +94,22 @@ def argmax(in_array1, axis=None, out=None):
9494
9595
"""
9696

97-
is_dparray1 = isinstance(in_array1, dparray)
98-
99-
if (not use_origin_backend(in_array1) and is_dparray1):
97+
x1_desc = dpnp.get_dpnp_descriptor(x1)
98+
if x1_desc:
10099
if axis is not None:
101-
checker_throw_value_error("argmax", "axis", type(axis), None)
102-
if out is not None:
103-
checker_throw_value_error("argmax", "out", type(out), None)
104-
105-
result_obj = dpnp_argmax(in_array1)
106-
result = dpnp.convert_single_elem_array_to_scalar(result_obj)
100+
pass
101+
elif out is not None:
102+
pass
103+
else:
104+
result_obj = dpnp_argmax(x1_desc)
105+
result = dpnp.convert_single_elem_array_to_scalar(result_obj)
107106

108-
return result
107+
return result
109108

110-
return numpy.argmax(in_array1, axis, out)
109+
return call_origin(numpy.argmax, x1, axis, out)
111110

112111

113-
def argmin(in_array1, axis=None, out=None):
112+
def argmin(x1, axis=None, out=None):
114113
"""
115114
Returns the indices of the minimum values along an axis.
116115
@@ -150,17 +149,16 @@ def argmin(in_array1, axis=None, out=None):
150149
151150
"""
152151

153-
is_dparray1 = isinstance(in_array1, dparray)
154-
155-
if (not use_origin_backend(in_array1) and is_dparray1):
152+
x1_desc = dpnp.get_dpnp_descriptor(x1)
153+
if x1_desc:
156154
if axis is not None:
157-
checker_throw_value_error("argmin", "axis", type(axis), None)
158-
if out is not None:
159-
checker_throw_value_error("argmin", "out", type(out), None)
160-
161-
result_obj = dpnp_argmin(in_array1)
162-
result = dpnp.convert_single_elem_array_to_scalar(result_obj)
155+
pass
156+
elif out is not None:
157+
pass
158+
else:
159+
result_obj = dpnp_argmin(x1_desc)
160+
result = dpnp.convert_single_elem_array_to_scalar(result_obj)
163161

164-
return result
162+
return result
165163

166-
return numpy.argmin(in_array1, axis, out)
164+
return call_origin(numpy.argmin, x1, axis, out)

0 commit comments

Comments
 (0)