Skip to content

Commit c24dc6f

Browse files
authored
fix tf.roll with axes=-1 (#1338)
Signed-off-by: Guenther Schmuelling <[email protected]>
1 parent a4aeb1a commit c24dc6f

File tree

2 files changed

+12
-0
lines changed

2 files changed

+12
-0
lines changed

tests/test_backend.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -911,6 +911,15 @@ def func(x, shift):
911911
return tf.identity(x_, name=_TFOUTPUT)
912912
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: shift_val})
913913

914+
@check_opset_min_version(10, "Slice")
915+
def test_roll_neg_axis(self):
916+
def func(input_ids):
917+
shifted_input_ids = tf.cast(input_ids, tf.int32)
918+
shifted_input_ids = tf.roll(shifted_input_ids, 1, axis=-1)
919+
return tf.identity(shifted_input_ids, name=_TFOUTPUT)
920+
x_val = np.array([[0, 1, 2, 3, 4, 5, 6, 7], [1, 2, 3, 4, 5, 6, 7, 8]], dtype=np.int64)
921+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
922+
914923
@check_tf_min_version("2.2")
915924
def test_large_model_format(self):
916925
x_val = np.array([2.0], dtype=np.float32)
@@ -4581,5 +4590,6 @@ def func(x):
45814590
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
45824591

45834592

4593+
45844594
if __name__ == '__main__':
45854595
unittest_main()

tf2onnx/onnx_opset/tensor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,8 @@ class Roll:
356356
def any_version(cls, opset, ctx, node, **kwargs):
357357
utils.make_sure(node.inputs[2].is_const(), "Can only convert Roll is axis is const")
358358
axes = node.inputs[2].get_tensor_value()
359+
if axes == -1:
360+
axes = len(ctx.get_shape(node.input[0])) + axes
359361
if not isinstance(axes, list):
360362
axes = [axes]
361363
shifts_dtype = ctx.get_dtype(node.input[1])

0 commit comments

Comments
 (0)