
Commit 26cf48d

Merge pull request #92 from pengwa/fix-mirror-pad
fix the MirrorPad conversion failure

2 parents: 94317a3 + e8f222d

2 files changed (+42, -6 lines)

tests/test_graph.py

Lines changed: 25 additions & 5 deletions
@@ -20,15 +20,20 @@
 _TENSORFLOW_DOMAIN = "ai.onnx.converters.tensorflow"
 
 
-def onnx_to_graphviz(g):
+def onnx_to_graphviz(g, include_attrs = False):
     g2 = gv.Digraph()
     for node in g.get_nodes():
         kwarg = {}
         attr = node.attr
-        if "shape" in attr:
-            kwarg["shape"] = str(attr["shape"].ints)
-        if "broadcast" in attr:
-            kwarg["broadcast"] = str(attr["broadcast"].i)
+        if include_attrs:
+            for a in attr:
+                kwarg[a] = str(helper.get_attribute_value(attr[a]))
+        else:
+            if "shape" in attr:
+                kwarg["shape"] = str(attr["shape"].ints)
+            if "broadcast" in attr:
+                kwarg["broadcast"] = str(attr["broadcast"].i)
+
         g2.node(node.name, op_type=node.type, **kwarg)
     for node in g.get_nodes():
         for i in node.input:
@@ -287,6 +292,21 @@ def print_handler(ctx, node, name, args):
             'digraph { Print [op_type=Identity] output [op_type=Identity] input1:0 -> Print Print:0 -> output }',
             onnx_to_graphviz(g))
 
+    def test_pad(self):
+        with tf.Session() as sess:
+            t = tf.constant([[1, 2, 3], [4, 5, 6]], name= "input1")
+            paddings = tf.constant([[1, 1,], [2, 2]], name="paddings")
+            a = tf.pad(t, paddings, "CONSTANT", "const_no_val")
+            b = tf.pad(t, paddings, "CONSTANT", "const_with_val", 999)
+            c= tf.pad(t, paddings, "REFLECT", "reflect")
+            g = process_tf_graph(sess.graph)
+
+            self.assertEqual('digraph { const_no_val [op_type=Pad pads="[1, 2, 1, 2]"]'
+                             ' const_with_val [op_type=Pad pads="[1, 2, 1, 2]" value=999]'
+                             ' reflect [mode="b\'reflect\'" op_type=Pad pads="[1, 2, 1, 2]"]'
+                             ' input1:0 -> const_no_val input1:0 -> const_with_val input1:0 -> reflect }',
+                             onnx_to_graphviz(g, True))
+
 
 if __name__ == '__main__':
     unittest.main()

tf2onnx/tfonnx.py

Lines changed: 17 additions & 1 deletion
@@ -710,9 +710,23 @@ def splitv_op(ctx, node, name, args):
 
 
 def pad_op(ctx, node, name, args):
-    # T output = Pad(T input, Tpaddings paddings, @type Tpaddings)
+    # T output = Pad(T input, int32 paddings, @type Tpaddings), CONSTANT mode using the default value
+    # or PadV2(T input, int32 paddings, T constant_value, @type Tpaddings), CONSTANT mode with a value specified
+    # or MirrorPad(T input, int32 paddings, @type Tpaddings, @STRING mode), other modes
     # T output = Pad(T data, @STRING mode, @INTS pads, @FLOAT value)
     paddings = np.array(node.inputs[1].get_tensor_value()).transpose().flatten()
+    mode = node.get_attr("mode")
+    if mode:
+        mode = mode.s.decode("utf-8").lower()
+        node.set_attr("mode", mode)
+    if mode not in [None, "constant", "reflect"]:
+        raise ValueError(mode + " pad mode is not supported")
+
+    if mode in [None, "constant"] and len(node.input) == 3:
+        const_val = node.inputs[2].get_tensor_value()[0]
+        node.set_attr("value", const_val)
+        ctx.remove_input(node, node.input[2])
+
     ctx.remove_input(node, node.input[1])
     node.set_attr("pads", paddings)
     return node
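
As an aside (not part of the commit), here is a minimal standalone sketch of what the transpose/flatten step in pad_op does: TensorFlow's paddings tensor holds one [begin, end] pair per dimension, while ONNX Pad expects all begins first and then all ends, which is why test_pad above expects pads="[1, 2, 1, 2]". The values below are made up for illustration.

import numpy as np

# TensorFlow-style paddings: one [begin, end] pair per dimension,
# matching the tf.constant([[1, 1], [2, 2]]) used in test_pad.
tf_paddings = [[1, 1], [2, 2]]

# ONNX Pad layout: [x1_begin, x2_begin, ..., x1_end, x2_end, ...]
onnx_pads = np.array(tf_paddings).transpose().flatten()
print(onnx_pads.tolist())  # [1, 2, 1, 2]
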
@@ -1063,11 +1077,13 @@ def fused_batchnorm_op7(ctx, node, name, args):
     "Mean": (reduce_op, ["ReduceMean"]),
     "Min": (reduce_op, ["ReduceMin"]),
     "Minimum": (minmax_op, ["Min"]),
+    "MirrorPad": (pad_op, ["Pad"]),
     "Mul": (broadcast_op, []),
     "Neg": (direct_op, []),
     "NoOp": (no_op, []),
     "NotEqual": (direct_op, ["Not"]),
     "Pad": (pad_op, []),
+    "PadV2": (pad_op, ["Pad"]),
     "Placeholder": (placeholder_op, []),
     "PlaceholderV2": (placeholder_op, []),
     "PlaceholderWithDefault": (placeholder_op, []),
