@@ -1823,13 +1823,13 @@ def version_11(cls, ctx, node, **kwargs):
1823
1823
class MatrixDiagPartV2V3 :
1824
1824
@classmethod
1825
1825
def version_11 (cls , ctx , node , ** kwargs ):
1826
- Input = node .input [0 ]
1827
- K = ctx .make_node ('Cast' , [node .input [1 ]], attr = {'to' : TensorProto .INT64 }).output [0 ]
1828
- Padding = node .input [2 ]
1829
- Align = 'LEFT_LEFT'
1826
+ input_tensor = node .input [0 ]
1827
+ k = ctx .make_node ('Cast' , [node .input [1 ]], attr = {'to' : TensorProto .INT64 }).output [0 ]
1828
+ padding = node .input [2 ]
1829
+ align = 'LEFT_LEFT'
1830
1830
if node .op .op_type == 'MatrixDiagPartV3' :
1831
- Align = node .get_attr_str ('align' ) if 'align' in node .attr else 'LEFT_RIGHT'
1832
- input_rank = len (ctx .get_shape (Input ))
1831
+ align = node .get_attr_str ('align' ) if 'align' in node .attr else 'LEFT_RIGHT'
1832
+ input_rank = len (ctx .get_shape (input_tensor ))
1833
1833
raw_input_shape = [- 1 ] * input_rank
1834
1834
per_loop_shape = raw_input_shape [:- 1 ]
1835
1835
raw_output_shape = raw_input_shape [:- 2 ] + [- 1 ]
@@ -1845,23 +1845,23 @@ def version_11(cls, ctx, node, **kwargs):
1845
1845
const_neg_one = ctx .make_const (utils .make_name (node .name ) + 'const_neg_one' , np .array ([- 1 ]).astype (np .int64 ))
1846
1846
const_neg_two = ctx .make_const (utils .make_name (node .name ) + 'const_neg_two' , np .array ([- 2 ]).astype (np .int64 ))
1847
1847
# prepare new_shape of input
1848
- input_shape = ctx .make_node ('Shape' , [Input ])
1848
+ input_shape = ctx .make_node ('Shape' , [input_tensor ])
1849
1849
shape_input_shape = ctx .make_node ('Shape' , [input_shape .output [0 ]])
1850
1850
matrix_shape = ctx .make_node ('Slice' ,
1851
1851
[input_shape .output [0 ], const_neg_two .output [0 ], shape_input_shape .output [0 ]])
1852
1852
min_dim = ctx .make_node ('ReduceMin' , [matrix_shape .output [0 ]])
1853
1853
input_depth = ctx .make_node ('Slice' , [matrix_shape .output [0 ], const_neg_two .output [0 ], const_neg_one .output [0 ]])
1854
1854
input_width = ctx .make_node ('Slice' , [matrix_shape .output [0 ], const_neg_one .output [0 ], const_two .output [0 ]])
1855
1855
temp_shape = ctx .make_node ('Concat' , [const_neg_one .output [0 ], matrix_shape .output [0 ]], attr = {'axis' : 0 })
1856
- temp_input = ctx .make_node ('Reshape' , [Input , temp_shape .output [0 ]])
1856
+ temp_input = ctx .make_node ('Reshape' , [input_tensor , temp_shape .output [0 ]])
1857
1857
temp_transposed = ctx .make_node ('Transpose' , [temp_input .output [0 ]], attr = {'perm' : [0 , 2 , 1 ]})
1858
1858
half_shape = ctx .make_node ('Slice' , [input_shape .output [0 ], const_zero .output [0 ], const_neg_two .output [0 ]])
1859
1859
new_shape = ctx .make_node ('Concat' , [half_shape .output [0 ], input_width .output [0 ], input_depth .output [0 ]],
1860
1860
attr = {'axis' : 0 })
1861
1861
# define body graph for main loop
1862
- k_shape = ctx .make_node ('Shape' , [K ])
1863
- k_start = ctx .make_node ('Slice' , [K , const_zero .output [0 ], const_one .output [0 ]])
1864
- k_end = ctx .make_node ('Slice' , [K , const_neg_one .output [0 ], k_shape .output [0 ]])
1862
+ k_shape = ctx .make_node ('Shape' , [k ])
1863
+ k_start = ctx .make_node ('Slice' , [k , const_zero .output [0 ], const_one .output [0 ]])
1864
+ k_end = ctx .make_node ('Slice' , [k , const_neg_one .output [0 ], k_shape .output [0 ]])
1865
1865
raw_total_k = ctx .make_node ('Sub' , [k_end .output [0 ], k_start .output [0 ]])
1866
1866
total_k = ctx .make_node ('Add' , [raw_total_k .output [0 ], const_one .output [0 ]])
1867
1867
trip_name = utils .make_name (node .name + "_i" )
@@ -1873,7 +1873,7 @@ def version_11(cls, ctx, node, **kwargs):
1873
1873
# identity of input
1874
1874
identity_input_graph = body_graph .create_new_graph_with_same_config ()
1875
1875
identity_input_graph .parent_graph = body_graph
1876
- identity_input = identity_input_graph .make_node ('Identity' , [Input ])
1876
+ identity_input = identity_input_graph .make_node ('Identity' , [input_tensor ])
1877
1877
identity_input_graph .add_graph_output (identity_input .output [0 ], ctx .get_dtype (node .input [0 ]), raw_input_shape )
1878
1878
# transposed input
1879
1879
transposed_input_graph = body_graph .create_new_graph_with_same_config ()
@@ -1927,14 +1927,14 @@ def version_11(cls, ctx, node, **kwargs):
1927
1927
gap_pos_k_graph = body_graph .create_new_graph_with_same_config ()
1928
1928
gap_pos_k_graph .parent_graph = body_graph
1929
1929
gap_pos_k = gap_pos_k_graph .make_node ('Concat' , [const_zero .output [0 ], processed_gap .output [0 ]], attr = {'axis' : 0 }) \
1930
- if Align .startswith ('LEFT' ) \
1930
+ if align .startswith ('LEFT' ) \
1931
1931
else gap_pos_k_graph .make_node ('Concat' , [processed_gap .output [0 ], const_zero .output [0 ]], attr = {'axis' : 0 })
1932
1932
gap_pos_k_graph .add_graph_output (gap_pos_k .output [0 ], TensorProto .INT64 , [- 1 ])
1933
1933
# gap_neg_k_graph
1934
1934
gap_neg_k_graph = body_graph .create_new_graph_with_same_config ()
1935
1935
gap_neg_k_graph .parent_graph = body_graph
1936
1936
gap_neg_k = gap_neg_k_graph .make_node ('Concat' , [const_zero .output [0 ], processed_gap .output [0 ]], attr = {'axis' : 0 }) \
1937
- if Align .endswith ('LEFT' ) \
1937
+ if align .endswith ('LEFT' ) \
1938
1938
else gap_neg_k_graph .make_node ('Concat' , [processed_gap .output [0 ], const_zero .output [0 ]], attr = {'axis' : 0 })
1939
1939
gap_neg_k_graph .add_graph_output (gap_neg_k .output [0 ], TensorProto .INT64 , [- 1 ])
1940
1940
# pad output with gap
@@ -1945,7 +1945,7 @@ def version_11(cls, ctx, node, **kwargs):
1945
1945
gap_right = body_graph .make_node ('Slice' , [gap_k .output [0 ], const_one .output [0 ], const_two .output [0 ]])
1946
1946
gap_all = body_graph .make_node ('Concat' , [sliced_zero .output [0 ], gap_left .output [0 ], sliced_zero .output [0 ],
1947
1947
gap_right .output [0 ]], attr = {'axis' : 0 })
1948
- padded_output = body_graph .make_node ('Pad' , [raw_output .output [0 ], gap_all .output [0 ], Padding ])
1948
+ padded_output = body_graph .make_node ('Pad' , [raw_output .output [0 ], gap_all .output [0 ], padding ])
1949
1949
cond_output = body_graph .make_node ('Identity' , [cond_name ])
1950
1950
body_graph .add_graph_output (cond_output .output [0 ], TensorProto .BOOL , [])
1951
1951
body_graph .add_graph_output (padded_output .output [0 ], ctx .get_dtype (node .input [0 ]), per_loop_shape )
0 commit comments