@@ -452,9 +452,8 @@ def version_11(cls, ctx, node, **kwargs):
452
452
if ctx .get_dtype (node .input [1 ]) != onnx_pb .TensorProto .INT64 :
453
453
ctx .insert_new_node_on_input (node , "Cast" , node .input [1 ], to = onnx_pb .TensorProto .INT64 )
454
454
ctx .insert_new_node_on_input (node , "Transpose" , node .input [1 ])
455
- reshape = ctx .insert_new_node_on_input (node , "Reshape" , node .input [1 ])
456
455
shape_const = ctx .make_const (utils .make_name (node .name ), np .array ([- 1 ]).astype (np .int64 ))
457
- reshape . input = [ reshape .input [0 ], shape_const .name ]
456
+ ctx . insert_new_node_on_input ( node , "Reshape" , [ node .input [1 ], shape_const .name ])
458
457
459
458
origin_dtype = ctx .get_dtype (node .output [0 ])
460
459
if origin_dtype not in [TensorProto .FLOAT , TensorProto .DOUBLE ,
@@ -715,8 +714,12 @@ def version_7(cls, ctx, node, **kwargs):
715
714
# 2: "loop" to generate mask matrix: generate col or row of matrix one by one
716
715
g = ctx .create_new_graph_with_same_config ()
717
716
node_name = utils .make_name ("const_zero_bool" )
718
- const_zero_bool = ctx .make_const (name = node_name , np_val = np .array ([[0 ]]).astype (np .bool ))
719
- ctx .set_dtype (const_zero_bool .output [0 ], onnx_pb .TensorProto .BOOL )
717
+ const_zero_bool = g .make_const (name = node_name , np_val = np .array ([[0 ]]).astype (np .bool ))
718
+ g .set_dtype (const_zero_bool .output [0 ], onnx_pb .TensorProto .BOOL )
719
+
720
+ g .add_graph_input ("trip" , onnx_pb .TensorProto .INT64 , [])
721
+ g .add_graph_input ("cond" , onnx_pb .TensorProto .BOOL , [])
722
+ g .add_graph_input ("line" , onnx_pb .TensorProto .BOOL , [- 1 , - 1 ])
720
723
721
724
# shift right the line and add zero at the left.
722
725
new_line = g .make_node (op_type = "Concat" , inputs = [const_zero_bool .output [0 ], "line" ],
@@ -730,10 +733,6 @@ def version_7(cls, ctx, node, **kwargs):
730
733
g .make_node ("Identity" , ["line" ], outputs = ["res" ])
731
734
g .make_node ("Identity" , [slice_node ], outputs = ["line_out" ])
732
735
733
- g .add_graph_input ("trip" , onnx_pb .TensorProto .INT64 , [])
734
- g .add_graph_input ("cond" , onnx_pb .TensorProto .BOOL , [])
735
- g .add_graph_input ("line" , onnx_pb .TensorProto .BOOL , [- 1 , - 1 ])
736
-
737
736
g .add_graph_output ("cond_out" , onnx_pb .TensorProto .BOOL , [])
738
737
g .add_graph_output ("line_out" , onnx_pb .TensorProto .BOOL , [- 1 , - 1 ])
739
738
g .add_graph_output ("res" , onnx_pb .TensorProto .BOOL , [- 1 , - 1 ])
0 commit comments