Skip to content

Commit 0b63a38

Browse files
authored
argmin and argmax to desc (#846)
1 parent 054d54c commit 0b63a38

File tree

3 files changed

+12
-10
lines changed

3 files changed

+12
-10
lines changed

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -360,8 +360,8 @@ cpdef dpnp_descriptor dpnp_sort(dpnp_descriptor array1)
360360
"""
361361
Searching functions
362362
"""
363-
cpdef dparray dpnp_argmax(dpnp_descriptor array1)
364-
cpdef dparray dpnp_argmin(dpnp_descriptor array1)
363+
cpdef dpnp_descriptor dpnp_argmax(dpnp_descriptor array1)
364+
cpdef dpnp_descriptor dpnp_argmin(dpnp_descriptor array1)
365365

366366
"""
367367
Trigonometric functions

dpnp/dpnp_algo/dpnp_algo_searching.pyx

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,15 @@ __all__ += [
4444
ctypedef void(*custom_search_1in_1out_func_ptr_t)(void * , void * , size_t)
4545

4646

47-
cpdef dparray dpnp_argmax(utils.dpnp_descriptor in_array1):
47+
cpdef utils.dpnp_descriptor dpnp_argmax(utils.dpnp_descriptor in_array1):
4848
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(in_array1.dtype)
4949
cdef DPNPFuncType output_type = dpnp_dtype_to_DPNPFuncType(dpnp.int64)
5050

5151
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_ARGMAX, param1_type, output_type)
5252

53-
result_type = dpnp_DPNPFuncType_to_dtype( < size_t > kernel_data.return_type)
54-
cdef dparray result = dparray((1,), dtype=result_type)
53+
# ceate result array with type given by FPTR data
54+
cdef shape_type_c result_shape = (1,)
55+
cdef utils.dpnp_descriptor result = utils.create_output_descriptor(result_shape, kernel_data.return_type, None)
5556

5657
cdef custom_search_1in_1out_func_ptr_t func = <custom_search_1in_1out_func_ptr_t > kernel_data.ptr
5758

@@ -60,14 +61,15 @@ cpdef dparray dpnp_argmax(utils.dpnp_descriptor in_array1):
6061
return result
6162

6263

63-
cpdef dparray dpnp_argmin(utils.dpnp_descriptor in_array1):
64+
cpdef utils.dpnp_descriptor dpnp_argmin(utils.dpnp_descriptor in_array1):
6465
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(in_array1.dtype)
6566
cdef DPNPFuncType output_type = dpnp_dtype_to_DPNPFuncType(dpnp.int64)
6667

6768
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_ARGMIN, param1_type, output_type)
6869

69-
result_type = dpnp_DPNPFuncType_to_dtype( < size_t > kernel_data.return_type)
70-
cdef dparray result = dparray((1,), dtype=result_type)
70+
# ceate result array with type given by FPTR data
71+
cdef shape_type_c result_shape = (1,)
72+
cdef utils.dpnp_descriptor result = utils.create_output_descriptor(result_shape, kernel_data.return_type, None)
7173

7274
cdef custom_search_1in_1out_func_ptr_t func = <custom_search_1in_1out_func_ptr_t > kernel_data.ptr
7375

dpnp/dpnp_iface_searching.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def argmax(x1, axis=None, out=None):
102102
elif out is not None:
103103
pass
104104
else:
105-
result_obj = dpnp_argmax(x1_desc)
105+
result_obj = dpnp_argmax(x1_desc).get_pyobj()
106106
result = dpnp.convert_single_elem_array_to_scalar(result_obj)
107107

108108
return result
@@ -157,7 +157,7 @@ def argmin(x1, axis=None, out=None):
157157
elif out is not None:
158158
pass
159159
else:
160-
result_obj = dpnp_argmin(x1_desc)
160+
result_obj = dpnp_argmin(x1_desc).get_pyobj()
161161
result = dpnp.convert_single_elem_array_to_scalar(result_obj)
162162

163163
return result

0 commit comments

Comments
 (0)