Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions dpctl/tensor/_manipulation_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ def roll(X, /, shift, *, axis=None):
hev, roll_ev = ti._copy_usm_ndarray_for_roll_1d(
src=X,
dst=res,
shift=shift,
shift=(shift % X.size),
sycl_queue=exec_q,
depends=dep_evs,
)
Expand All @@ -369,9 +369,10 @@ def roll(X, /, shift, *, axis=None):
shifts = [
0,
] * X.ndim
shape = X.shape
for sh, ax in broadcasted:
shifts[ax] += sh

n_i = shape[ax]
shifts[ax] = int(shifts[ax] + sh) % int(n_i) if n_i > 0 else 0
res = dpt.empty(
X.shape, dtype=X.dtype, usm_type=X.usm_type, sycl_queue=exec_q
)
Expand Down
10 changes: 10 additions & 0 deletions dpctl/tests/test_usm_ndarray_manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,6 +657,16 @@ def test_roll_2d(data):
assert_array_equal(Ynp, dpt.asnumpy(Y))


def test_roll_out_bounds_shifts():
"See gh-1857"
get_queue_or_skip()

x = dpt.arange(4)
y = dpt.roll(x, np.uint64(2**63 + 2))
expected = dpt.roll(x, 2)
assert dpt.all(y == expected)


def test_roll_validation():
get_queue_or_skip()

Expand Down
Loading