@@ -314,7 +314,7 @@ def version_1(cls, ctx, node, **kwargs):
         k_h, k_w, k_input_channels, k_channel_multiplier = kernel_shape
         if k_input_channels < 1:
             raise ValueError("input channel must be positive")
-        k_output_channels = k_input_channels * k_channel_multiplier
+        k_output_channels = k_input_channels * k_channel_multiplier
 
         node.set_attr("kernel_shape", [k_h, k_w])
         strides = conv_dims_attr(node, "strides")
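Note on the kernel unpacking above: TensorFlow's DepthwiseConv2dNative stores its filter as [filter_height, filter_width, in_channels, channel_multiplier], so the converter derives the output-channel count as in_channels * channel_multiplier. A minimal NumPy sketch of that bookkeeping (illustrative shapes, not converter code):

```python
import numpy as np

# Hypothetical depthwise kernel: 3x3 filters, 8 input channels, channel multiplier 2.
kernel = np.zeros((3, 3, 8, 2), dtype=np.float32)
k_h, k_w, k_input_channels, k_channel_multiplier = kernel.shape
k_output_channels = k_input_channels * k_channel_multiplier
assert (k_h, k_w) == (3, 3)
assert k_output_channels == 16  # each input channel yields `channel_multiplier` outputs
```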
@@ -476,13 +476,16 @@ def version_11(cls, ctx, node, **kwargs):
         if mode not in [None, "constant", "reflect"]:
             raise ValueError(mode + " pad mode is not supported")
 
-        pads = node.inputs[1].get_tensor_value()
-        pads = np.array(pads).transpose().flatten().astype(np.int64)
-        node.inputs[1].set_tensor_value(pads)
+        # pads must be int64.
+        if ctx.get_dtype(node.input[1]) != onnx_pb.TensorProto.INT64:
+            ctx.insert_new_node_on_input(node, "Cast", node.input[1], to=onnx_pb.TensorProto.INT64)
+        ctx.insert_new_node_on_input(node, "Transpose", node.input[1])
+        shape_const = ctx.make_const(utils.make_name(node.name), np.array([-1]).astype(np.int64))
+        ctx.insert_new_node_on_input(node, "Reshape", [node.input[1], shape_const.name])
 
         origin_dtype = ctx.get_dtype(node.output[0])
-        if origin_dtype not in [TensorProto.FLOAT16, TensorProto.FLOAT,
-                                TensorProto.DOUBLE]:
+        if origin_dtype not in [TensorProto.FLOAT, TensorProto.DOUBLE,
+                                TensorProto.INT32, TensorProto.INT64]:
             cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[0])
             cast_node.set_attr("to", TensorProto.FLOAT)
             ctx.set_dtype(cast_node.output[0], TensorProto.FLOAT)
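The inserted Transpose and Reshape replace the old compile-time rewrite of the pads tensor: instead of requiring a constant pads input, the same layout change now happens in the graph at runtime. TensorFlow's Pad takes paddings as an [n, 2] tensor of [begin, end] pairs per axis, while ONNX Pad expects a flat [x1_begin, x2_begin, ..., x1_end, x2_end, ...] vector. A NumPy sketch of the equivalent transformation (values are made up):

```python
import numpy as np

# TF-style paddings: one [begin, end] pair per axis of a rank-2 input.
tf_pads = np.array([[1, 2],    # axis 0: pad 1 before, 2 after
                    [3, 4]])   # axis 1: pad 3 before, 4 after

# What the inserted Transpose + Reshape([-1]) compute at runtime:
onnx_pads = tf_pads.transpose().reshape(-1).astype(np.int64)
print(onnx_pads)  # [1 3 2 4]: all "begin" values first, then all "end" values
```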
@@ -739,8 +742,12 @@ def version_7(cls, ctx, node, **kwargs):
         # 2: "loop" to generate mask matrix: generate col or row of matrix one by one
         g = ctx.create_new_graph_with_same_config()
         node_name = utils.make_name("const_zero_bool")
-        const_zero_bool = ctx.make_const(name=node_name, np_val=np.array([[0]]).astype(np.bool))
-        ctx.set_dtype(const_zero_bool.output[0], onnx_pb.TensorProto.BOOL)
+        const_zero_bool = g.make_const(name=node_name, np_val=np.array([[0]]).astype(np.bool))
+        g.set_dtype(const_zero_bool.output[0], onnx_pb.TensorProto.BOOL)
+
+        g.add_graph_input("trip", onnx_pb.TensorProto.INT64, [])
+        g.add_graph_input("cond", onnx_pb.TensorProto.BOOL, [])
+        g.add_graph_input("line", onnx_pb.TensorProto.BOOL, [-1, -1])
 
         # shift right the line and add zero at the left.
         new_line = g.make_node(op_type="Concat", inputs=[const_zero_bool.output[0], "line"],
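The "trip", "cond", and "line" graph inputs (declared up front here and removed from their later position in the next hunk, presumably so they exist before the body nodes that reference "line" are created) follow the standard Loop-body signature: iteration count, loop condition, and the loop-carried boolean row. Under the reading that the Concat above plus the later Slice shift the row right by one, a NumPy sketch of a single iteration (the axis choice is an assumption; this is not the converter code):

```python
import numpy as np

line = np.array([[True, True, False, False]])            # loop-carried row of the mask
zero = np.array([[False]])                               # const_zero_bool
shifted = np.concatenate([zero, line], axis=1)[:, :-1]   # Concat then Slice: shift right, drop last
print(shifted)  # [[False  True  True False]]
```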
@@ -754,10 +761,6 @@ def version_7(cls, ctx, node, **kwargs):
         g.make_node("Identity", ["line"], outputs=["res"])
         g.make_node("Identity", [slice_node], outputs=["line_out"])
 
-        g.add_graph_input("trip", onnx_pb.TensorProto.INT64, [])
-        g.add_graph_input("cond", onnx_pb.TensorProto.BOOL, [])
-        g.add_graph_input("line", onnx_pb.TensorProto.BOOL, [-1, -1])
-
         g.add_graph_output("cond_out", onnx_pb.TensorProto.BOOL, [])
         g.add_graph_output("line_out", onnx_pb.TensorProto.BOOL, [-1, -1])
         g.add_graph_output("res", onnx_pb.TensorProto.BOOL, [-1, -1])