Skip to content

Commit 8b8d5ea

Browse files
authored
Merge pull request #909 from jignparm/jignparm/fix_reversev2
ReverseV2 - fix shape computations
2 parents a1c8f8b + bc2e0a5 commit 8b8d5ea

File tree

1 file changed

+24
-41
lines changed

1 file changed

+24
-41
lines changed

tf2onnx/onnx_opset/tensor.py

Lines changed: 24 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1624,18 +1624,17 @@ def version_10(cls, ctx, node, **kwargs):
16241624
rv2_in_names = [node.input[0]]
16251625

16261626
input_shape = ctx.get_shape(node.input[0])
1627+
input_rank = len(input_shape)
1628+
input_shape_node = ctx.make_node("Shape", [node.input[0]], op_name_scope=node.name)
1629+
16271630
# Make sure input shape is not None
16281631
utils.make_sure(input_shape is not None, "shape of {} is None".format(node.input[0]))
16291632

1630-
input_rank = len(input_shape)
1631-
16321633
rv2_node_name = node.name
16331634
# ReverseV2 has a single output.
16341635
rv2_output_dtypes = node.output_dtypes
16351636
rv2_output_shapes = node.output_shapes
16361637

1637-
const_name_root = rv2_node_name + '_Const'
1638-
16391638
# Remove ReverseV2 node from graph.
16401639
ctx.remove_node(rv2_node_name)
16411640

@@ -1689,36 +1688,20 @@ def version_10(cls, ctx, node, **kwargs):
16891688

16901689
inputs = [new_node.output[0]]
16911690

1691+
const_one_name = utils.make_name(f'const_one')
1692+
const_one = ctx.make_const(name=const_one_name, np_val=np.array([1], dtype=np.int64))
1693+
const_axis_name = utils.make_name(f'const_{axis}')
1694+
const_axis = ctx.make_const(name=const_axis_name, np_val=np.array([axis], dtype=np.int64))
1695+
16921696
# Add a Constant node (seq_len) for ReverseSequence.
1693-
if ctx.opset >= 11:
1694-
batch_shape = ctx.make_node("Shape", [inputs[-1]])
1695-
const_one = ctx.make_const(utils.make_name(node.name + "_const_one"), np.array([1], dtype=np.int64))
1696-
const_two = ctx.make_const(utils.make_name(node.name + "_const_two"), np.array([2], dtype=np.int64))
1697-
batch_size = ctx.make_node("Slice",
1698-
[batch_shape.output[0], const_one.output[0], const_two.output[0]])
1699-
input_shape = ctx.make_node("Shape", [node.input[0]])
1700-
const_axis = ctx.make_const(utils.make_name(node.name + "_const_axis"),
1701-
np.array([axis], dtype=np.int64))
1702-
const_axis_next = ctx.make_const(utils.make_name(node.name + "_const_axis_next"),
1703-
np.array([axis + 1], dtype=np.int64))
1704-
input_axis = ctx.make_node("Slice",
1705-
[input_shape.output[0], const_axis.output[0], const_axis_next.output[0]])
1706-
seq_array = ctx.make_node("Expand", [input_axis.output[0], batch_size.output[0]])
1707-
inputs.append(seq_array.output[0])
1708-
else:
1709-
# Index 1 for the shape should not return 0
1710-
# since the input must have rank >= 2.
1711-
rs_batch_size = ctx.get_shape(inputs[-1])[1]
1712-
# Make sure rs_batch_size and input_shape[axis] are not -1 each
1713-
utils.make_sure(input_shape[axis] is not -1 \
1714-
, "shape of axis {} is unknown".format(axis))
1715-
utils.make_sure(rs_batch_size is not -1 \
1716-
, "ReverseSequence batch size for axis {} is unknown".format(axis))
1717-
seq_list = [input_shape[axis]] * rs_batch_size
1718-
seq_array = np.asarray(seq_list, dtype=np.int64) # dtype should be int64
1719-
const_seq_name = utils.make_name(const_name_root)
1720-
new_node = ctx.make_const(name=const_seq_name, np_val=seq_array)
1721-
inputs.append(new_node.output[0])
1697+
# Index 1 for the shape should not return 0, since rank(input) >=2
1698+
input_shape = ctx.make_node("Shape", [inputs[-1]], op_name_scope=rv2_node_name)
1699+
batch_size = ctx.make_node("Gather", [input_shape.output[0], const_one.output[0]],
1700+
op_name_scope=rv2_node_name)
1701+
axis_dim = ctx.make_node("Gather", [input_shape_node.output[0], const_axis.output[0]],
1702+
op_name_scope=rv2_node_name)
1703+
seq_array = ctx.make_node("Expand", [axis_dim.output[0], batch_size.output[0]])
1704+
inputs.append(seq_array.output[0])
17221705

17231706
# Add a ReverseSequence node.
17241707

@@ -1942,21 +1925,21 @@ def version_11(cls, ctx, node, **kwargs):
19421925
gap_pos_k = gap_pos_k_graph.make_node('Concat', [const_zero.output[0],
19431926
processed_gap.output[0]],
19441927
attr={'axis': 0}) \
1945-
if align.startswith('LEFT') \
1946-
else gap_pos_k_graph.make_node('Concat', [processed_gap.output[0],
1947-
const_zero.output[0]],
1948-
attr={'axis': 0})
1928+
if align.startswith('LEFT') \
1929+
else gap_pos_k_graph.make_node('Concat', [processed_gap.output[0],
1930+
const_zero.output[0]],
1931+
attr={'axis': 0})
19491932
gap_pos_k_graph.add_graph_output(gap_pos_k.output[0], TensorProto.INT64, [-1])
19501933
# gap_neg_k_graph
19511934
gap_neg_k_graph = body_graph.create_new_graph_with_same_config()
19521935
gap_neg_k_graph.parent_graph = body_graph
19531936
gap_neg_k = gap_neg_k_graph.make_node('Concat', [const_zero.output[0],
19541937
processed_gap.output[0]],
19551938
attr={'axis': 0}) \
1956-
if align.endswith('LEFT') \
1957-
else gap_neg_k_graph.make_node('Concat', [processed_gap.output[0],
1958-
const_zero.output[0]],
1959-
attr={'axis': 0})
1939+
if align.endswith('LEFT') \
1940+
else gap_neg_k_graph.make_node('Concat', [processed_gap.output[0],
1941+
const_zero.output[0]],
1942+
attr={'axis': 0})
19601943
gap_neg_k_graph.add_graph_output(gap_neg_k.output[0], TensorProto.INT64, [-1])
19611944
# pad output with gap
19621945
gap_k = body_graph.make_node('If', [is_k_noneg.output[0]])

0 commit comments

Comments
 (0)