Skip to content

Commit 6f8dee0

Browse files
Use single kernel reshaping for order="F" call
1 parent 7bc3124 commit 6f8dee0

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

dpctl/tensor/_reshape.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import numpy as np
1919

2020
import dpctl.tensor as dpt
21-
from dpctl.tensor._copy_utils import _copy_from_usm_ndarray_to_usm_ndarray
2221
from dpctl.tensor._tensor_impl import (
2322
_copy_usm_ndarray_for_reshape,
2423
_ravel_multi_index,
@@ -155,24 +154,25 @@ def reshape(X, /, shape, *, order="C", copy=None):
155154
"Reshaping the array requires a copy, but no copying was "
156155
"requested by using copy=False"
157156
)
157+
copy_q = X.sycl_queue
158158
if copy_required or (copy is True):
159159
# must perform a copy
160160
flat_res = dpt.usm_ndarray(
161161
(X.size,),
162162
dtype=X.dtype,
163163
buffer=X.usm_type,
164-
buffer_ctor_kwargs={"queue": X.sycl_queue},
164+
buffer_ctor_kwargs={"queue": copy_q},
165165
)
166166
if order == "C":
167167
hev, _ = _copy_usm_ndarray_for_reshape(
168-
src=X, dst=flat_res, sycl_queue=X.sycl_queue
168+
src=X, dst=flat_res, sycl_queue=copy_q
169169
)
170-
hev.wait()
171170
else:
172-
for i in range(X.size):
173-
_copy_from_usm_ndarray_to_usm_ndarray(
174-
flat_res[i], X[np.unravel_index(i, X.shape, order=order)]
175-
)
171+
X_t = dpt.permute_dims(X, range(X.ndim - 1, -1, -1))
172+
hev, _ = _copy_usm_ndarray_for_reshape(
173+
src=X_t, dst=flat_res, sycl_queue=copy_q
174+
)
175+
hev.wait()
176176
return dpt.usm_ndarray(
177177
tuple(shape), dtype=X.dtype, buffer=flat_res, order=order
178178
)
@@ -182,5 +182,5 @@ def reshape(X, /, shape, *, order="C", copy=None):
182182
dtype=X.dtype,
183183
buffer=X,
184184
strides=tuple(newsts),
185-
offset=X.__sycl_usm_array_interface__.get("offset", 0),
185+
offset=X._element_offset,
186186
)

0 commit comments

Comments
 (0)