@@ -369,10 +369,11 @@ class Roll:
369
369
def any_version (cls , opset , ctx , node , ** kwargs ):
370
370
utils .make_sure (node .inputs [2 ].is_const (), "Can only convert Roll is axis is const" )
371
371
axes = node .inputs [2 ].get_tensor_value ()
372
- if axes == - 1 :
373
- axes = len (ctx .get_shape (node .input [0 ])) + axes
374
372
if not isinstance (axes , list ):
375
373
axes = [axes ]
374
+ rank = ctx .get_rank (node .input [0 ])
375
+ axes = [a if a >= 0 else a + rank for a in axes ]
376
+
376
377
shifts_dtype = ctx .get_dtype (node .input [1 ])
377
378
if shifts_dtype != TensorProto .INT64 :
378
379
shifts_casted = ctx .insert_new_node_on_input (node , "Cast" , node .input [1 ], to = TensorProto .INT64 ).output [0 ]
@@ -395,7 +396,8 @@ def any_version(cls, opset, ctx, node, **kwargs):
395
396
for axis , shift in zip (axes , shifts_split ):
396
397
len_along_axis = GraphBuilder (ctx ).make_slice (
397
398
{"data" : shape_node .output [0 ], "ends" : [axis + 1 ], "starts" : [axis ]})
398
- remaining_len = ctx .make_node ("Sub" , [len_along_axis , shift ], op_name_scope = node .name ).output [0 ]
399
+ shift_mod = ctx .make_node ("Mod" , [shift , len_along_axis ]).output [0 ]
400
+ remaining_len = ctx .make_node ("Sub" , [len_along_axis , shift_mod ], op_name_scope = node .name ).output [0 ]
399
401
axes_const = ctx .make_const (utils .make_name ("axes_const" ), np .array ([axis ], np .int64 )).output [0 ]
400
402
slice_one = ctx .make_node ("Slice" , [data , zero_const , remaining_len , axes_const ], op_name_scope = node .name )
401
403
slice_two = ctx .make_node ("Slice" , [data , remaining_len , len_along_axis , axes_const ],
0 commit comments