18
18
import numpy as np
19
19
20
20
import dpctl .tensor as dpt
21
- from dpctl .tensor ._copy_utils import _copy_from_usm_ndarray_to_usm_ndarray
22
21
from dpctl .tensor ._tensor_impl import (
23
22
_copy_usm_ndarray_for_reshape ,
24
23
_ravel_multi_index ,
@@ -155,24 +154,25 @@ def reshape(X, /, shape, *, order="C", copy=None):
155
154
"Reshaping the array requires a copy, but no copying was "
156
155
"requested by using copy=False"
157
156
)
157
+ copy_q = X .sycl_queue
158
158
if copy_required or (copy is True ):
159
159
# must perform a copy
160
160
flat_res = dpt .usm_ndarray (
161
161
(X .size ,),
162
162
dtype = X .dtype ,
163
163
buffer = X .usm_type ,
164
- buffer_ctor_kwargs = {"queue" : X . sycl_queue },
164
+ buffer_ctor_kwargs = {"queue" : copy_q },
165
165
)
166
166
if order == "C" :
167
167
hev , _ = _copy_usm_ndarray_for_reshape (
168
- src = X , dst = flat_res , sycl_queue = X . sycl_queue
168
+ src = X , dst = flat_res , sycl_queue = copy_q
169
169
)
170
- hev .wait ()
171
170
else :
172
- for i in range (X .size ):
173
- _copy_from_usm_ndarray_to_usm_ndarray (
174
- flat_res [i ], X [np .unravel_index (i , X .shape , order = order )]
175
- )
171
+ X_t = dpt .permute_dims (X , range (X .ndim - 1 , - 1 , - 1 ))
172
+ hev , _ = _copy_usm_ndarray_for_reshape (
173
+ src = X_t , dst = flat_res , sycl_queue = copy_q
174
+ )
175
+ hev .wait ()
176
176
return dpt .usm_ndarray (
177
177
tuple (shape ), dtype = X .dtype , buffer = flat_res , order = order
178
178
)
@@ -182,5 +182,5 @@ def reshape(X, /, shape, *, order="C", copy=None):
182
182
dtype = X .dtype ,
183
183
buffer = X ,
184
184
strides = tuple (newsts ),
185
- offset = X .__sycl_usm_array_interface__ . get ( "offset" , 0 ) ,
185
+ offset = X ._element_offset ,
186
186
)
0 commit comments