Skip to content

Commit 5d3d2f4

Browse files
authored
dpnp.choose partial fix for desc. (#860)
1 parent b05efb9 commit 5d3d2f4

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

dpnp/dpnp_algo/dpnp_algo_indexing.pyx

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,12 @@ ctypedef void(*custom_indexing_6in_func_ptr_t)(void *, void * , void * , const s
6565
ctypedef void(*fptr_dpnp_nonzero_t)(const void * , void * , const size_t * , const size_t , const size_t)
6666

6767

68-
cpdef dparray dpnp_choose(input, choices):
69-
res_array = dparray(len(input), dtype=choices[0].dtype)
68+
cpdef utils.dpnp_descriptor dpnp_choose(object input, list choices):
69+
cdef shape_type_c obj_shape = utils._object_to_tuple(len(input))
70+
cdef utils.dpnp_descriptor res_array = utils_py.create_output_descriptor_py(obj_shape, choices[0].dtype, None)
71+
7072
for i in range(len(input)):
71-
res_array[i] = (choices[input[i]])[i]
73+
res_array.get_pyobj()[i] = (choices[input[i]])[i]
7274
return res_array
7375

7476

dpnp/dpnp_iface_indexing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,9 +118,9 @@ def choose(x1, choices, out=None, mode='raise'):
118118
if not val:
119119
pass
120120
else:
121-
return dpnp_choose(x1, choices)
121+
return dpnp_choose(x1, choices).get_pyobj()
122122
else:
123-
return dpnp_choose(x1, choices)
123+
return dpnp_choose(x1, choices).get_pyobj()
124124

125125
return call_origin(numpy.choose, x1, choices, out, mode)
126126

0 commit comments

Comments
 (0)