Skip to content

Commit 01ec092

Browse files
Fix bug in reverseV2 for 1D tensors (#1691)
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent b05a993 commit 01ec092

File tree

2 files changed

+23
-5
lines changed

2 files changed

+23
-5
lines changed

tests/test_backend.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3273,6 +3273,11 @@ def test_reversev2_1D_tensor(self):
32733273
# Adds an identity block.
32743274
x_val_shape = [4]
32753275
x_val = np.random.randint(0, 100, x_val_shape).astype(np.float32)
3276+
def func(x):
3277+
x_ = reverse_v2(x, axis=[0])
3278+
return tf.identity(x_, name=_TFOUTPUT)
3279+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
3280+
32763281
def func(x):
32773282
x_ = reverse_v2(x, axis=[])
32783283
return tf.identity(x_, name=_TFOUTPUT)

tf2onnx/onnx_opset/tensor.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2112,20 +2112,29 @@ def version_10(cls, ctx, node, **kwargs):
21122112
axes = axes.tolist()
21132113
len_axes = len(axes)
21142114

2115+
input_rank = ctx.get_rank(node.input[0])
2116+
utils.make_sure(input_rank is not None, "rank of {} is unknown".format(node.input[0]))
2117+
needs_squeeze = False
2118+
if input_rank == 1 and len_axes != 0:
2119+
# ReverseSequence node requires rank >= 2
2120+
utils.make_sure(axes in [[-1], [0]], "Invalid value %s for axes of ReverseV2 of 1d tensor", axes)
2121+
axes = [1]
2122+
new_inp = GraphBuilder(ctx).make_unsqueeze({'data': node.input[0], 'axes': [0]})
2123+
ctx.replace_input(node, node.input[0], new_inp, 0)
2124+
input_rank = 2
2125+
needs_squeeze = True
2126+
21152127
# Store input and output parameters of the ReverseV2 node.
21162128
rv2_in_names = [node.input[0]]
21172129

2118-
input_shape = ctx.get_shape(node.input[0])
2119-
input_rank = len(input_shape)
21202130
input_shape_node = ctx.make_node("Shape", [node.input[0]], op_name_scope=node.name)
21212131

2122-
# Make sure input shape is not None
2123-
utils.make_sure(input_shape is not None, "shape of {} is None".format(node.input[0]))
2124-
21252132
rv2_node_name = node.name
21262133
# ReverseV2 has a single output.
21272134
rv2_output_dtypes = node.output_dtypes
21282135
rv2_output_shapes = node.output_shapes
2136+
if needs_squeeze and rv2_output_shapes is not None:
2137+
rv2_output_shapes[0] = [1] + rv2_output_shapes[0]
21292138

21302139
# Remove ReverseV2 node from graph.
21312140
ctx.remove_node(rv2_node_name)
@@ -2243,6 +2252,10 @@ def version_10(cls, ctx, node, **kwargs):
22432252
attr={"perm": curr_perm}
22442253
)
22452254

2255+
if needs_squeeze:
2256+
sq_node = GraphBuilder(ctx).make_squeeze({"data": node.output[0], "axes": [0]}, return_node=True)
2257+
ctx.insert_node_on_output(sq_node)
2258+
22462259

22472260
@tf_op("Unique", onnx_op="Unique")
22482261
class Unique:

0 commit comments

Comments
 (0)