Skip to content

Commit d636d3d

Browse files
committed
fix the MirrorPad conversion failure
(cherry picked from commit 6eb1940)
1 parent 94317a3 commit d636d3d

File tree

1 file changed

+17
-1
lines changed

1 file changed

+17
-1
lines changed

tf2onnx/tfonnx.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -710,9 +710,23 @@ def splitv_op(ctx, node, name, args):
710710

711711

712712
def pad_op(ctx, node, name, args):
713-
# T output = Pad(T input, Tpaddings paddings, @type Tpaddings)
713+
# T output = Pad(T input, int32 paddings, @type Tpaddings), CONST model using default value
714+
# or PadV2(T input, int32 paddings, T constant_value, @type Tpaddings), CONST mode - default value specified
715+
# or MirrorPad(T input, int32 paddings, @type Tpaddings, @STRING mode), other mode.
714716
# T output = Pad(T data, @STRING mode, @INTS pads, @FLOAT value)
715717
paddings = np.array(node.inputs[1].get_tensor_value()).transpose().flatten()
718+
mode = node.get_attr("mode")
719+
if mode:
720+
mode = mode.s.decode("utf-8").lower()
721+
722+
if mode not in [None, "constant", "reflect"]:
723+
raise ValueError(mode + " pad mode is not supported")
724+
725+
if mode in [None, "constant"] and len(node.input) == 3:
726+
const_val = node.input[2]
727+
node.set_attr("value", const_val)
728+
ctx.remove_input(node, node.input[2])
729+
716730
ctx.remove_input(node, node.input[1])
717731
node.set_attr("pads", paddings)
718732
return node
@@ -1063,11 +1077,13 @@ def fused_batchnorm_op7(ctx, node, name, args):
10631077
"Mean": (reduce_op, ["ReduceMean"]),
10641078
"Min": (reduce_op, ["ReduceMin"]),
10651079
"Minimum": (minmax_op, ["Min"]),
1080+
"MirrorPad": (pad_op, ["Pad"]),
10661081
"Mul": (broadcast_op, []),
10671082
"Neg": (direct_op, []),
10681083
"NoOp": (no_op, []),
10691084
"NotEqual": (direct_op, ["Not"]),
10701085
"Pad": (pad_op, []),
1086+
"PadV2": (pad_op, ["Pad"]),
10711087
"Placeholder": (placeholder_op, []),
10721088
"PlaceholderV2": (placeholder_op, []),
10731089
"PlaceholderWithDefault": (placeholder_op, []),

0 commit comments

Comments
 (0)