Skip to content

Commit bb55ff9

Browse files
authored
dpnp.transpose() to desc (#849)
1 parent 185c9a8 commit bb55ff9

File tree

3 files changed

+10
-11
lines changed

3 files changed

+10
-11
lines changed

dpnp/dpnp_algo/dpnp_algo.pxd

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -339,8 +339,8 @@ cpdef dpnp_descriptor dpnp_subtract(dpnp_descriptor x1_obj, dpnp_descriptor x2_o
339339
"""
340340
Array manipulation routines
341341
"""
342-
cpdef dparray dpnp_repeat(dpnp_descriptor array1, repeats, axes=*)
343-
cpdef dparray dpnp_transpose(dpnp_descriptor array1, axes=*)
342+
cpdef dpnp_descriptor dpnp_repeat(dpnp_descriptor array1, repeats, axes=*)
343+
cpdef dpnp_descriptor dpnp_transpose(dpnp_descriptor array1, axes=*)
344344

345345

346346
"""

dpnp/dpnp_algo/dpnp_algo_manipulation.pyx

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -126,22 +126,22 @@ cpdef utils.dpnp_descriptor dpnp_expand_dims(utils.dpnp_descriptor in_array, axi
126126
return result
127127

128128

129-
cpdef dparray dpnp_repeat(utils.dpnp_descriptor array1, repeats, axes=None):
129+
cpdef utils.dpnp_descriptor dpnp_repeat(utils.dpnp_descriptor array1, repeats, axes=None):
130130
cdef DPNPFuncType param1_type = dpnp_dtype_to_DPNPFuncType(array1.dtype)
131131

132132
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_REPEAT, param1_type, param1_type)
133133

134-
result_type = dpnp_DPNPFuncType_to_dtype(< size_t > kernel_data.return_type)
135-
cdef long new_size = array1.size * repeats
136-
cdef dparray result = dparray((new_size, ), dtype=array1.dtype)
134+
# ceate result array with type given by FPTR data
135+
cdef shape_type_c result_shape = (array1.size * repeats, )
136+
cdef utils.dpnp_descriptor result = utils.create_output_descriptor(result_shape, kernel_data.return_type, None)
137137

138138
cdef fptr_dpnp_repeat_t func = <fptr_dpnp_repeat_t > kernel_data.ptr
139139
func(array1.get_data(), result.get_data(), repeats, array1.size)
140140

141141
return result
142142

143143

144-
cpdef dparray dpnp_transpose(utils.dpnp_descriptor array1, axes=None):
144+
cpdef utils.dpnp_descriptor dpnp_transpose(utils.dpnp_descriptor array1, axes=None):
145145
cdef shape_type_c input_shape = array1.shape
146146
cdef size_t input_shape_size = array1.ndim
147147
cdef shape_type_c result_shape = shape_type_c(input_shape_size, 1)
@@ -173,9 +173,8 @@ cpdef dparray dpnp_transpose(utils.dpnp_descriptor array1, axes=None):
173173
# get the FPTR data structure
174174
cdef DPNPFuncData kernel_data = get_dpnp_function_ptr(DPNP_FN_TRANSPOSE, param1_type, param1_type)
175175

176-
result_type = dpnp_DPNPFuncType_to_dtype(< size_t > kernel_data.return_type)
177176
# ceate result array with type given by FPTR data
178-
cdef dparray result = dparray(result_shape, dtype=result_type)
177+
cdef utils.dpnp_descriptor result = utils.create_output_descriptor(result_shape, kernel_data.return_type, None)
179178

180179
cdef fptr_custom_elemwise_transpose_1in_1out_t func = <fptr_custom_elemwise_transpose_1in_1out_t > kernel_data.ptr
181180
# call FPTR function

dpnp/dpnp_iface_manipulation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -460,7 +460,7 @@ def repeat(x1, repeats, axis=None):
460460
pass
461461
else:
462462
repeat_val = repeats if dpnp.isscalar(repeats) else repeats[0]
463-
return dpnp_repeat(x1_desc, repeat_val, axis)
463+
return dpnp_repeat(x1_desc, repeat_val, axis).get_pyobj()
464464

465465
return call_origin(numpy.repeat, x1, repeats, axis)
466466

@@ -651,7 +651,7 @@ def transpose(x1, axes=None):
651651
"""
652652
axes = None
653653

654-
result = dpnp_transpose(x1_desc, axes)
654+
result = dpnp_transpose(x1_desc, axes).get_pyobj()
655655

656656
return result
657657

0 commit comments

Comments
 (0)