21
21
from dpctl .tensor ._copy_utils import _copy_from_usm_ndarray_to_usm_ndarray
22
22
from dpctl .tensor ._tensor_impl import _copy_usm_ndarray_for_reshape
23
23
24
+ __doc__ = "Implementation module for :func:`dpctl.tensor.reshape`."
25
+
24
26
25
27
def _make_unit_indexes (shape ):
26
28
"""
@@ -67,10 +69,8 @@ def reshaped_strides(old_sh, old_sts, new_sh, order="C"):
67
69
]
68
70
]
69
71
valid = all (
70
- [
71
- check_st == old_st or old_dim == 1
72
- for check_st , old_st , old_dim in zip (check_sts , old_sts , old_sh )
73
- ]
72
+ check_st == old_st or old_dim == 1
73
+ for check_st , old_st , old_dim in zip (check_sts , old_sts , old_sh )
74
74
)
75
75
return new_sts if valid else None
76
76
@@ -82,7 +82,7 @@ def reshape(X, newshape, order="C", copy=None):
82
82
Reshapes given usm_ndarray into new shape. Returns a view, if possible,
83
83
a copy otherwise. Memory layout of the copy is controlled by order keyword.
84
84
"""
85
- if type ( X ) is not dpt .usm_ndarray :
85
+ if not isinstance ( X , dpt .usm_ndarray ) :
86
86
raise TypeError
87
87
if not isinstance (newshape , (list , tuple )):
88
88
newshape = (newshape ,)
@@ -99,10 +99,10 @@ def reshape(X, newshape, order="C", copy=None):
99
99
)
100
100
newshape = [operator .index (d ) for d in newshape ]
101
101
negative_ones_count = 0
102
- for i in range ( len ( newshape )) :
103
- if newshape [ i ] == - 1 :
102
+ for nshi in newshape :
103
+ if nshi == - 1 :
104
104
negative_ones_count = negative_ones_count + 1
105
- if (newshape [ i ] < - 1 ) or negative_ones_count > 1 :
105
+ if (nshi < - 1 ) or negative_ones_count > 1 :
106
106
raise ValueError (
107
107
"Target shape should have at most 1 negative "
108
108
"value which can only be -1"
@@ -111,7 +111,7 @@ def reshape(X, newshape, order="C", copy=None):
111
111
v = X .size // (- np .prod (newshape ))
112
112
newshape = [v if d == - 1 else d for d in newshape ]
113
113
if X .size != np .prod (newshape ):
114
- raise ValueError ("Can not reshape into {}" . format ( newshape ) )
114
+ raise ValueError (f "Can not reshape into { newshape } " )
115
115
if X .size :
116
116
newsts = reshaped_strides (X .shape , X .strides , newshape , order = order )
117
117
else :
@@ -143,12 +143,11 @@ def reshape(X, newshape, order="C", copy=None):
143
143
return dpt .usm_ndarray (
144
144
tuple (newshape ), dtype = X .dtype , buffer = flat_res , order = order
145
145
)
146
- else :
147
- # can form a view
148
- return dpt .usm_ndarray (
149
- newshape ,
150
- dtype = X .dtype ,
151
- buffer = X ,
152
- strides = tuple (newsts ),
153
- offset = X .__sycl_usm_array_interface__ .get ("offset" , 0 ),
154
- )
146
+ # can form a view
147
+ return dpt .usm_ndarray (
148
+ newshape ,
149
+ dtype = X .dtype ,
150
+ buffer = X ,
151
+ strides = tuple (newsts ),
152
+ offset = X .__sycl_usm_array_interface__ .get ("offset" , 0 ),
153
+ )
0 commit comments