@@ -1624,18 +1624,17 @@ def version_10(cls, ctx, node, **kwargs):
         rv2_in_names = [node.input[0]]
 
         input_shape = ctx.get_shape(node.input[0])
+        input_rank = len(input_shape)
+        input_shape_node = ctx.make_node("Shape", [node.input[0]], op_name_scope=node.name)
+
         # Make sure input shape is not None
         utils.make_sure(input_shape is not None, "shape of {} is None".format(node.input[0]))
 
-        input_rank = len(input_shape)
-
         rv2_node_name = node.name
         # ReverseV2 has a single output.
         rv2_output_dtypes = node.output_dtypes
         rv2_output_shapes = node.output_shapes
 
-        const_name_root = rv2_node_name + '_Const'
-
         # Remove ReverseV2 node from graph.
         ctx.remove_node(rv2_node_name)
 
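In this first hunk, `input_rank` and a runtime `Shape` node (`input_shape_node`) are now captured while the original ReverseV2 input is still wired up, before `ctx.remove_node(rv2_node_name)` drops the node; the Shape node is reused later to read the size of the reversed axis at runtime. A minimal sketch of why the static shape alone is not always enough, using made-up values that are not taken from the converter:

    import numpy as np

    static_shape = [-1, 3, 5]        # ctx.get_shape(...) may report unknown dims as -1
    input_rank = len(static_shape)   # the rank is still known even when a dim is not

    x = np.zeros((4, 3, 5), dtype=np.float32)          # tensor actually seen at runtime
    runtime_shape = np.array(x.shape, dtype=np.int64)  # what an ONNX Shape op would yield
    print(input_rank, runtime_shape)                   # 3 [4 3 5]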
@@ -1689,36 +1688,20 @@ def version_10(cls, ctx, node, **kwargs):
 
             inputs = [new_node.output[0]]
 
+            const_one_name = utils.make_name(f'const_one')
+            const_one = ctx.make_const(name=const_one_name, np_val=np.array([1], dtype=np.int64))
+            const_axis_name = utils.make_name(f'const_{axis}')
+            const_axis = ctx.make_const(name=const_axis_name, np_val=np.array([axis], dtype=np.int64))
+
             # Add a Constant node (seq_len) for ReverseSequence.
-            if ctx.opset >= 11:
-                batch_shape = ctx.make_node("Shape", [inputs[-1]])
-                const_one = ctx.make_const(utils.make_name(node.name + "_const_one"), np.array([1], dtype=np.int64))
-                const_two = ctx.make_const(utils.make_name(node.name + "_const_two"), np.array([2], dtype=np.int64))
-                batch_size = ctx.make_node("Slice",
-                                           [batch_shape.output[0], const_one.output[0], const_two.output[0]])
-                input_shape = ctx.make_node("Shape", [node.input[0]])
-                const_axis = ctx.make_const(utils.make_name(node.name + "_const_axis"),
-                                            np.array([axis], dtype=np.int64))
-                const_axis_next = ctx.make_const(utils.make_name(node.name + "_const_axis_next"),
-                                                 np.array([axis + 1], dtype=np.int64))
-                input_axis = ctx.make_node("Slice",
-                                           [input_shape.output[0], const_axis.output[0], const_axis_next.output[0]])
-                seq_array = ctx.make_node("Expand", [input_axis.output[0], batch_size.output[0]])
-                inputs.append(seq_array.output[0])
-            else:
-                # Index 1 for the shape should not return 0
-                # since the input must have rank >= 2.
-                rs_batch_size = ctx.get_shape(inputs[-1])[1]
-                # Make sure rs_batch_size and input_shape[axis] are not -1 each
-                utils.make_sure(input_shape[axis] is not -1 \
-                                , "shape of axis {} is unknown".format(axis))
-                utils.make_sure(rs_batch_size is not -1 \
-                                , "ReverseSequence batch size for axis {} is unknown".format(axis))
-                seq_list = [input_shape[axis]] * rs_batch_size
-                seq_array = np.asarray(seq_list, dtype=np.int64)  # dtype should be int64
-                const_seq_name = utils.make_name(const_name_root)
-                new_node = ctx.make_const(name=const_seq_name, np_val=seq_array)
-                inputs.append(new_node.output[0])
+            # Index 1 for the shape should not return 0, since rank(input) >= 2
+            input_shape = ctx.make_node("Shape", [inputs[-1]], op_name_scope=rv2_node_name)
+            batch_size = ctx.make_node("Gather", [input_shape.output[0], const_one.output[0]],
+                                       op_name_scope=rv2_node_name)
+            axis_dim = ctx.make_node("Gather", [input_shape_node.output[0], const_axis.output[0]],
+                                     op_name_scope=rv2_node_name)
+            seq_array = ctx.make_node("Expand", [axis_dim.output[0], batch_size.output[0]])
+            inputs.append(seq_array.output[0])
 
             # Add a ReverseSequence node.
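In this second hunk the opset-gated branches are collapsed: ReverseSequence now always gets its seq_lens input from a small runtime subgraph (Shape, Gather, Expand) instead of a baked-in constant that required fully-known static shapes. A hedged numpy sketch of what that subgraph evaluates to; the function name and sample tensor below are illustrative and not part of the converter:

    import numpy as np

    def seq_lens_for_reverse(x, axis):
        # Roughly mirrors the Shape -> Gather -> Expand subgraph above:
        # one entry per batch element, each equal to the length of the reversed axis.
        batch_size = np.int64(x.shape[1])                # Gather(shape, [1]); rank(input) >= 2
        axis_dim = np.array([x.shape[axis]], np.int64)   # Gather(shape, [axis])
        return np.broadcast_to(axis_dim, (batch_size,))  # Expand to shape [batch_size]

    x = np.arange(24).reshape(2, 3, 4)
    print(seq_lens_for_reverse(x, 2))  # [4 4 4]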
@@ -1942,21 +1925,21 @@ def version_11(cls, ctx, node, **kwargs):
         gap_pos_k = gap_pos_k_graph.make_node('Concat', [const_zero.output[0],
                                                          processed_gap.output[0]],
                                               attr={'axis': 0}) \
-                    if align.startswith('LEFT') \
-                    else gap_pos_k_graph.make_node('Concat', [processed_gap.output[0],
-                                                              const_zero.output[0]],
-                                                   attr={'axis': 0})
+            if align.startswith('LEFT') \
+            else gap_pos_k_graph.make_node('Concat', [processed_gap.output[0],
+                                                      const_zero.output[0]],
+                                           attr={'axis': 0})
         gap_pos_k_graph.add_graph_output(gap_pos_k.output[0], TensorProto.INT64, [-1])
         # gap_neg_k_graph
         gap_neg_k_graph = body_graph.create_new_graph_with_same_config()
         gap_neg_k_graph.parent_graph = body_graph
         gap_neg_k = gap_neg_k_graph.make_node('Concat', [const_zero.output[0],
                                                          processed_gap.output[0]],
                                               attr={'axis': 0}) \
-                    if align.endswith('LEFT') \
-                    else gap_neg_k_graph.make_node('Concat', [processed_gap.output[0],
-                                                              const_zero.output[0]],
-                                                   attr={'axis': 0})
+            if align.endswith('LEFT') \
+            else gap_neg_k_graph.make_node('Concat', [processed_gap.output[0],
+                                                      const_zero.output[0]],
+                                           attr={'axis': 0})
         gap_neg_k_graph.add_graph_output(gap_neg_k.output[0], TensorProto.INT64, [-1])
         # pad output with gap
         gap_k = body_graph.make_node('If', [is_k_noneg.output[0]])
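The last hunk only re-indents the continuation lines of the two conditional Concat expressions; their behavior is unchanged. For readers skimming the diff, the pattern being re-wrapped is simply "prepend or append the zero gap depending on the align mode", as in this small numpy illustration with placeholder values:

    import numpy as np

    const_zero = np.array([0], dtype=np.int64)
    processed_gap = np.array([2], dtype=np.int64)   # placeholder value
    align = 'LEFT_RIGHT'

    gap_pos_k = (np.concatenate([const_zero, processed_gap], axis=0)
                 if align.startswith('LEFT')
                 else np.concatenate([processed_gap, const_zero], axis=0))
    print(gap_pos_k)  # [0 2]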