Skip to content

Commit 795e934

Browse files
authored
scalar support in normalize_axis() (#776)
1 parent 1c47acd commit 795e934

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

dpnp/dpnp_utils/dpnp_algo_utils.pxd

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ Return:
7777
cpdef tuple _object_to_tuple(object obj)
7878
cdef int _normalize_order(order, cpp_bool allow_k=*) except? 0
7979

80-
cpdef dparray_shape_type normalize_axis(dparray_shape_type axis, size_t shape_size)
80+
cpdef dparray_shape_type normalize_axis(object axis, size_t shape_size)
8181
"""
8282
Conversion of the transformation shape axis [-1, 0, 1] into [2, 0, 1] where numbers are `id`s of array shape axis
8383
"""

dpnp/dpnp_utils/dpnp_algo_utils.pyx

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -355,11 +355,12 @@ cpdef nd2dp_array(arr):
355355
return result
356356

357357

358-
cpdef dparray_shape_type normalize_axis(dparray_shape_type axis, size_t shape_size_inp):
358+
cpdef dparray_shape_type normalize_axis(object axis_obj, size_t shape_size_inp):
359359
"""
360360
Conversion of the transformation shape axis [-1, 0, 1] into [2, 0, 1] where numbers are `id`s of array shape axis
361361
"""
362362

363+
cdef dparray_shape_type axis = _object_to_tuple(axis_obj) # axis_obj might be a scalar
363364
cdef ssize_t shape_size = shape_size_inp # convert type for comparison with axis id
364365

365366
cdef size_t axis_size = axis.size()

0 commit comments

Comments
 (0)