@@ -286,7 +286,7 @@ def version_1(cls, ctx, node, **kwargs):
286
286
k_h , k_w , k_input_channels , k_channel_multiplier = kernel_shape
287
287
if k_input_channels < 1 :
288
288
raise ValueError ("input channel must be positive" )
289
- k_output_channels = k_input_channels * k_channel_multiplier
289
+ k_output_channels = k_input_channels * k_channel_multiplier
290
290
291
291
node .set_attr ("kernel_shape" , [k_h , k_w ])
292
292
strides = conv_dims_attr (node , "strides" )
@@ -448,13 +448,17 @@ def version_11(cls, ctx, node, **kwargs):
448
448
if mode not in [None , "constant" , "reflect" ]:
449
449
raise ValueError (mode + " pad mode is not supported" )
450
450
451
- pads = node .inputs [1 ].get_tensor_value ()
452
- pads = np .array (pads ).transpose ().flatten ().astype (np .int64 )
453
- node .inputs [1 ].set_tensor_value (pads )
451
+ # pads must be int64
452
+ if ctx .get_dtype (node .input [1 ]) != onnx_pb .TensorProto .INT64 :
453
+ ctx .insert_new_node_on_input (node , "Cast" , node .input [1 ], to = onnx_pb .TensorProto .INT64 )
454
+ ctx .insert_new_node_on_input (node , "Transpose" , node .input [1 ])
455
+ reshape = ctx .insert_new_node_on_input (node , "Reshape" , node .input [1 ])
456
+ shape_const = ctx .make_const (utils .make_name (node .name ), np .array ([- 1 ]).astype (np .int64 ))
457
+ reshape .input = [reshape .input [0 ], shape_const .name ]
454
458
455
459
origin_dtype = ctx .get_dtype (node .output [0 ])
456
- if origin_dtype not in [TensorProto .FLOAT16 , TensorProto .FLOAT ,
457
- TensorProto .DOUBLE ]:
460
+ if origin_dtype not in [TensorProto .FLOAT , TensorProto .DOUBLE ,
461
+ TensorProto .INT32 , TensorProto . INT64 ]:
458
462
cast_node = ctx .insert_new_node_on_input (node , "Cast" , node .input [0 ])
459
463
cast_node .set_attr ("to" , TensorProto .FLOAT )
460
464
ctx .set_dtype (cast_node .output [0 ], TensorProto .FLOAT )
0 commit comments