Skip to content

Commit 4971d64

Browse files
Use operator.index instead of int
1 parent f1be356 commit 4971d64

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

dpctl/tensor/_manipulation_functions.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,7 @@ def roll(x, /, shift, *, axis=None):
352352
res = dpt.empty(
353353
x.shape, dtype=x.dtype, usm_type=x.usm_type, sycl_queue=exec_q
354354
)
355-
sz = x.size
355+
sz = operator.index(x.size)
356356
shift = (shift % sz) if sz > 0 else 0
357357
dep_evs = _manager.submitted_events
358358
hev, roll_ev = ti._copy_usm_ndarray_for_roll_1d(
@@ -373,8 +373,9 @@ def roll(x, /, shift, *, axis=None):
373373
] * x.ndim
374374
shape = x.shape
375375
for sh, ax in broadcasted:
376-
n_i = shape[ax]
377-
shifts[ax] = (int(shifts[ax] + sh) % int(n_i)) if n_i > 0 else 0
376+
n_i = operator.index(shape[ax])
377+
shifted = operator.index(shifts[ax]) + operator.index(sh)
378+
shifts[ax] = (shifted % n_i) if n_i > 0 else 0
378379
res = dpt.empty(
379380
x.shape, dtype=x.dtype, usm_type=x.usm_type, sycl_queue=exec_q
380381
)

0 commit comments

Comments
 (0)