|
18 | 18 | import numpy as np
|
19 | 19 |
|
20 | 20 | import dpctl.tensor as dpt
|
21 |
| -from dpctl.tensor._copy_utils import _copy_from_usm_ndarray_to_usm_ndarray |
22 | 21 | from dpctl.tensor._tensor_impl import (
|
23 | 22 | _copy_usm_ndarray_for_reshape,
|
24 | 23 | _ravel_multi_index,
|
@@ -155,32 +154,37 @@ def reshape(X, /, shape, *, order="C", copy=None):
|
155 | 154 | "Reshaping the array requires a copy, but no copying was "
|
156 | 155 | "requested by using copy=False"
|
157 | 156 | )
|
| 157 | + copy_q = X.sycl_queue |
158 | 158 | if copy_required or (copy is True):
|
159 | 159 | # must perform a copy
|
160 | 160 | flat_res = dpt.usm_ndarray(
|
161 | 161 | (X.size,),
|
162 | 162 | dtype=X.dtype,
|
163 | 163 | buffer=X.usm_type,
|
164 |
| - buffer_ctor_kwargs={"queue": X.sycl_queue}, |
| 164 | + buffer_ctor_kwargs={"queue": copy_q}, |
165 | 165 | )
|
166 | 166 | if order == "C":
|
167 | 167 | hev, _ = _copy_usm_ndarray_for_reshape(
|
168 |
| - src=X, dst=flat_res, sycl_queue=X.sycl_queue |
| 168 | + src=X, dst=flat_res, sycl_queue=copy_q |
169 | 169 | )
|
170 |
| - hev.wait() |
171 | 170 | else:
|
172 |
| - for i in range(X.size): |
173 |
| - _copy_from_usm_ndarray_to_usm_ndarray( |
174 |
| - flat_res[i], X[np.unravel_index(i, X.shape, order=order)] |
175 |
| - ) |
| 171 | + X_t = dpt.permute_dims(X, range(X.ndim - 1, -1, -1)) |
| 172 | + hev, _ = _copy_usm_ndarray_for_reshape( |
| 173 | + src=X_t, dst=flat_res, sycl_queue=copy_q |
| 174 | + ) |
| 175 | + hev.wait() |
176 | 176 | return dpt.usm_ndarray(
|
177 | 177 | tuple(shape), dtype=X.dtype, buffer=flat_res, order=order
|
178 | 178 | )
|
179 | 179 | # can form a view
|
| 180 | + if (len(shape) == X.ndim) and all( |
| 181 | + s1 == s2 for s1, s2 in zip(shape, X.shape) |
| 182 | + ): |
| 183 | + return X |
180 | 184 | return dpt.usm_ndarray(
|
181 | 185 | shape,
|
182 | 186 | dtype=X.dtype,
|
183 | 187 | buffer=X,
|
184 | 188 | strides=tuple(newsts),
|
185 |
| - offset=X.__sycl_usm_array_interface__.get("offset", 0), |
| 189 | + offset=X._element_offset, |
186 | 190 | )
|
0 commit comments