Skip to content

Commit e8f222d

Browse files
committed
add ut for pad, and fix bug
1 parent d636d3d commit e8f222d

File tree

2 files changed

+28
-8
lines changed

2 files changed

+28
-8
lines changed

tests/test_graph.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,20 @@
2020
_TENSORFLOW_DOMAIN = "ai.onnx.converters.tensorflow"
2121

2222

23-
def onnx_to_graphviz(g):
23+
def onnx_to_graphviz(g, include_attrs = False):
2424
g2 = gv.Digraph()
2525
for node in g.get_nodes():
2626
kwarg = {}
2727
attr = node.attr
28-
if "shape" in attr:
29-
kwarg["shape"] = str(attr["shape"].ints)
30-
if "broadcast" in attr:
31-
kwarg["broadcast"] = str(attr["broadcast"].i)
28+
if include_attrs:
29+
for a in attr:
30+
kwarg[a] = str(helper.get_attribute_value(attr[a]))
31+
else:
32+
if "shape" in attr:
33+
kwarg["shape"] = str(attr["shape"].ints)
34+
if "broadcast" in attr:
35+
kwarg["broadcast"] = str(attr["broadcast"].i)
36+
3237
g2.node(node.name, op_type=node.type, **kwarg)
3338
for node in g.get_nodes():
3439
for i in node.input:
@@ -287,6 +292,21 @@ def print_handler(ctx, node, name, args):
287292
'digraph { Print [op_type=Identity] output [op_type=Identity] input1:0 -> Print Print:0 -> output }',
288293
onnx_to_graphviz(g))
289294

295+
def test_pad(self):
296+
with tf.Session() as sess:
297+
t = tf.constant([[1, 2, 3], [4, 5, 6]], name= "input1")
298+
paddings = tf.constant([[1, 1,], [2, 2]], name="paddings")
299+
a = tf.pad(t, paddings, "CONSTANT", "const_no_val")
300+
b = tf.pad(t, paddings, "CONSTANT", "const_with_val", 999)
301+
c= tf.pad(t, paddings, "REFLECT", "reflect")
302+
g = process_tf_graph(sess.graph)
303+
304+
self.assertEqual('digraph { const_no_val [op_type=Pad pads="[1, 2, 1, 2]"]'
305+
' const_with_val [op_type=Pad pads="[1, 2, 1, 2]" value=999]'
306+
' reflect [mode="b\'reflect\'" op_type=Pad pads="[1, 2, 1, 2]"]'
307+
' input1:0 -> const_no_val input1:0 -> const_with_val input1:0 -> reflect }',
308+
onnx_to_graphviz(g, True))
309+
290310

291311
if __name__ == '__main__':
292312
unittest.main()

tf2onnx/tfonnx.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -718,15 +718,15 @@ def pad_op(ctx, node, name, args):
718718
mode = node.get_attr("mode")
719719
if mode:
720720
mode = mode.s.decode("utf-8").lower()
721-
721+
node.set_attr("mode", mode)
722722
if mode not in [None, "constant", "reflect"]:
723723
raise ValueError(mode + " pad mode is not supported")
724724

725725
if mode in [None, "constant"] and len(node.input) == 3:
726-
const_val = node.input[2]
726+
const_val = node.inputs[2].get_tensor_value()[0]
727727
node.set_attr("value", const_val)
728728
ctx.remove_input(node, node.input[2])
729-
729+
730730
ctx.remove_input(node, node.input[1])
731731
node.set_attr("pads", paddings)
732732
return node

0 commit comments

Comments
 (0)