Skip to content

Commit fc91639

Browse files
Fix Roll op for negative shifts (#1616)
Signed-off-by: Tom Wildenhain <[email protected]> Co-authored-by: Guenther Schmuelling <[email protected]>
1 parent 80d6080 commit fc91639

File tree

2 files changed

+15
-3
lines changed

2 files changed

+15
-3
lines changed

tests/test_backend.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1084,6 +1084,16 @@ def func(input_ids):
10841084
x_val = np.array([[0, 1, 2, 3, 4, 5, 6, 7], [1, 2, 3, 4, 5, 6, 7, 8]], dtype=np.int64)
10851085
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
10861086

1087+
@check_opset_min_version(10, "Slice")
1088+
def test_roll_neg_shift(self):
1089+
x_val = np.arange(4 * 3 * 5 * 2, dtype=np.float32).reshape((4, 3, 5, 2))
1090+
shift_val = np.array([-2, 13, -3], dtype=np.int32)
1091+
axes_val = np.array([1, 2, -1], dtype=np.int32)
1092+
def func(x, shift):
1093+
x_ = tf.roll(x, shift, axes_val)
1094+
return tf.identity(x_, name=_TFOUTPUT)
1095+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: shift_val})
1096+
10871097
@check_tf_min_version("2.2")
10881098
def test_large_model_format(self):
10891099
x_val = np.array([2.0], dtype=np.float32)

tf2onnx/onnx_opset/tensor.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -369,10 +369,11 @@ class Roll:
369369
def any_version(cls, opset, ctx, node, **kwargs):
370370
utils.make_sure(node.inputs[2].is_const(), "Can only convert Roll is axis is const")
371371
axes = node.inputs[2].get_tensor_value()
372-
if axes == -1:
373-
axes = len(ctx.get_shape(node.input[0])) + axes
374372
if not isinstance(axes, list):
375373
axes = [axes]
374+
rank = ctx.get_rank(node.input[0])
375+
axes = [a if a >= 0 else a + rank for a in axes]
376+
376377
shifts_dtype = ctx.get_dtype(node.input[1])
377378
if shifts_dtype != TensorProto.INT64:
378379
shifts_casted = ctx.insert_new_node_on_input(node, "Cast", node.input[1], to=TensorProto.INT64).output[0]
@@ -395,7 +396,8 @@ def any_version(cls, opset, ctx, node, **kwargs):
395396
for axis, shift in zip(axes, shifts_split):
396397
len_along_axis = GraphBuilder(ctx).make_slice(
397398
{"data": shape_node.output[0], "ends": [axis + 1], "starts": [axis]})
398-
remaining_len = ctx.make_node("Sub", [len_along_axis, shift], op_name_scope=node.name).output[0]
399+
shift_mod = ctx.make_node("Mod", [shift, len_along_axis]).output[0]
400+
remaining_len = ctx.make_node("Sub", [len_along_axis, shift_mod], op_name_scope=node.name).output[0]
399401
axes_const = ctx.make_const(utils.make_name("axes_const"), np.array([axis], np.int64)).output[0]
400402
slice_one = ctx.make_node("Slice", [data, zero_const, remaining_len, axes_const], op_name_scope=node.name)
401403
slice_two = ctx.make_node("Slice", [data, remaining_len, len_along_axis, axes_const],

0 commit comments

Comments
 (0)