@@ -311,7 +311,7 @@ def flip(X, /, *, axis=None):
311311 return X [indexer ]
312312
313313
314- def roll (X , / , shift , * , axis = None ):
314+ def roll (x , / , shift , * , axis = None ):
315315 """
316316 roll(x, shift, axis)
317317
@@ -343,41 +343,45 @@ def roll(X, /, shift, *, axis=None):
343343 `device` attributes as `x` and whose elements are shifted relative
344344 to `x`.
345345 """
346- if not isinstance (X , dpt .usm_ndarray ):
347- raise TypeError (f"Expected usm_ndarray type, got { type (X )} ." )
348- exec_q = X .sycl_queue
346+ if not isinstance (x , dpt .usm_ndarray ):
347+ raise TypeError (f"Expected usm_ndarray type, got { type (x )} ." )
348+ exec_q = x .sycl_queue
349349 _manager = dputils .SequentialOrderManager [exec_q ]
350350 if axis is None :
351351 shift = operator .index (shift )
352- dep_evs = _manager .submitted_events
353352 res = dpt .empty (
354- X .shape , dtype = X .dtype , usm_type = X .usm_type , sycl_queue = exec_q
353+ x .shape , dtype = x .dtype , usm_type = x .usm_type , sycl_queue = exec_q
355354 )
355+ sz = operator .index (x .size )
356+ shift = (shift % sz ) if sz > 0 else 0
357+ dep_evs = _manager .submitted_events
356358 hev , roll_ev = ti ._copy_usm_ndarray_for_roll_1d (
357- src = X ,
359+ src = x ,
358360 dst = res ,
359361 shift = shift ,
360362 sycl_queue = exec_q ,
361363 depends = dep_evs ,
362364 )
363365 _manager .add_event_pair (hev , roll_ev )
364366 return res
365- axis = normalize_axis_tuple (axis , X .ndim , allow_duplicate = True )
367+ axis = normalize_axis_tuple (axis , x .ndim , allow_duplicate = True )
366368 broadcasted = np .broadcast (shift , axis )
367369 if broadcasted .ndim > 1 :
368370 raise ValueError ("'shift' and 'axis' should be scalars or 1D sequences" )
369371 shifts = [
370372 0 ,
371- ] * X .ndim
373+ ] * x .ndim
374+ shape = x .shape
372375 for sh , ax in broadcasted :
373- shifts [ax ] += sh
374-
376+ n_i = operator .index (shape [ax ])
377+ shifted = shifts [ax ] + operator .index (sh )
378+ shifts [ax ] = (shifted % n_i ) if n_i > 0 else 0
375379 res = dpt .empty (
376- X .shape , dtype = X .dtype , usm_type = X .usm_type , sycl_queue = exec_q
380+ x .shape , dtype = x .dtype , usm_type = x .usm_type , sycl_queue = exec_q
377381 )
378382 dep_evs = _manager .submitted_events
379383 ht_e , roll_ev = ti ._copy_usm_ndarray_for_roll_nd (
380- src = X , dst = res , shifts = shifts , sycl_queue = exec_q , depends = dep_evs
384+ src = x , dst = res , shifts = shifts , sycl_queue = exec_q , depends = dep_evs
381385 )
382386 _manager .add_event_pair (ht_e , roll_ev )
383387 return res
0 commit comments