
Commit 9dbd588

Fix Pad op for dynamic input (opset 11).
1 parent ef5522d commit 9dbd588

File tree: 1 file changed, +10 -6 lines

tf2onnx/onnx_opset/nn.py

Lines changed: 10 additions & 6 deletions
@@ -286,7 +286,7 @@ def version_1(cls, ctx, node, **kwargs):
         k_h, k_w, k_input_channels, k_channel_multiplier = kernel_shape
         if k_input_channels < 1:
             raise ValueError("input channel must be positive")
-        k_output_channels = k_input_channels * k_channel_multiplier
+        k_output_channels = k_input_channels * k_channel_multiplier

         node.set_attr("kernel_shape", [k_h, k_w])
         strides = conv_dims_attr(node, "strides")
@@ -448,13 +448,17 @@ def version_11(cls, ctx, node, **kwargs):
         if mode not in [None, "constant", "reflect"]:
             raise ValueError(mode + " pad mode is not supported")

-        pads = node.inputs[1].get_tensor_value()
-        pads = np.array(pads).transpose().flatten().astype(np.int64)
-        node.inputs[1].set_tensor_value(pads)
+        # pads must be int64
+        if ctx.get_dtype(node.input[1]) != onnx_pb.TensorProto.INT64:
+            ctx.insert_new_node_on_input(node, "Cast", node.input[1], to=onnx_pb.TensorProto.INT64)
+        ctx.insert_new_node_on_input(node, "Transpose", node.input[1])
+        reshape = ctx.insert_new_node_on_input(node, "Reshape", node.input[1])
+        shape_const = ctx.make_const(utils.make_name(node.name), np.array([-1]).astype(np.int64))
+        reshape.input = [reshape.input[0], shape_const.name]

         origin_dtype = ctx.get_dtype(node.output[0])
-        if origin_dtype not in [TensorProto.FLOAT16, TensorProto.FLOAT,
-                                TensorProto.DOUBLE]:
+        if origin_dtype not in [TensorProto.FLOAT, TensorProto.DOUBLE,
+                                TensorProto.INT32, TensorProto.INT64]:
             cast_node = ctx.insert_new_node_on_input(node, "Cast", node.input[0])
             cast_node.set_attr("to", TensorProto.FLOAT)
             ctx.set_dtype(cast_node.output[0], TensorProto.FLOAT)
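For context, the lines removed above converted the paddings at conversion time, which only works when the pads input is a constant; the new Cast, Transpose, and Reshape nodes perform the same layout conversion inside the exported graph, so a dynamic pads input also works. Below is a minimal NumPy sketch of that conversion; the values are illustrative and are not part of the commit.

import numpy as np

# TensorFlow's Pad takes paddings shaped [rank, 2], where row i is
# [begin_i, end_i] for dimension i. ONNX Pad (opset 11) expects a flat
# int64 tensor with all begins first, then all ends. Transposing to
# [2, rank] and flattening to [-1] yields exactly that order, which is
# what the inserted Transpose + Reshape nodes compute at runtime.
tf_paddings = np.array([[1, 2],    # dim 0: pad 1 before, 2 after
                        [3, 4]])   # dim 1: pad 3 before, 4 after
onnx_pads = tf_paddings.T.reshape(-1).astype(np.int64)
print(onnx_pads)  # [1 3 2 4] -> begins (1, 3) followed by ends (2, 4)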
