|
20 | 20 | _TENSORFLOW_DOMAIN = "ai.onnx.converters.tensorflow"
|
21 | 21 |
|
22 | 22 |
|
23 |
| -def onnx_to_graphviz(g): |
| 23 | +def onnx_to_graphviz(g, include_attrs = False): |
24 | 24 | g2 = gv.Digraph()
|
25 | 25 | for node in g.get_nodes():
|
26 | 26 | kwarg = {}
|
27 | 27 | attr = node.attr
|
28 |
| - if "shape" in attr: |
29 |
| - kwarg["shape"] = str(attr["shape"].ints) |
30 |
| - if "broadcast" in attr: |
31 |
| - kwarg["broadcast"] = str(attr["broadcast"].i) |
| 28 | + if include_attrs: |
| 29 | + for a in attr: |
| 30 | + kwarg[a] = str(helper.get_attribute_value(attr[a])) |
| 31 | + else: |
| 32 | + if "shape" in attr: |
| 33 | + kwarg["shape"] = str(attr["shape"].ints) |
| 34 | + if "broadcast" in attr: |
| 35 | + kwarg["broadcast"] = str(attr["broadcast"].i) |
| 36 | + |
32 | 37 | g2.node(node.name, op_type=node.type, **kwarg)
|
33 | 38 | for node in g.get_nodes():
|
34 | 39 | for i in node.input:
|
@@ -287,6 +292,21 @@ def print_handler(ctx, node, name, args):
|
287 | 292 | 'digraph { Print [op_type=Identity] output [op_type=Identity] input1:0 -> Print Print:0 -> output }',
|
288 | 293 | onnx_to_graphviz(g))
|
289 | 294 |
|
| 295 | + def test_pad(self): |
| 296 | + with tf.Session() as sess: |
| 297 | + t = tf.constant([[1, 2, 3], [4, 5, 6]], name= "input1") |
| 298 | + paddings = tf.constant([[1, 1,], [2, 2]], name="paddings") |
| 299 | + a = tf.pad(t, paddings, "CONSTANT", "const_no_val") |
| 300 | + b = tf.pad(t, paddings, "CONSTANT", "const_with_val", 999) |
| 301 | + c= tf.pad(t, paddings, "REFLECT", "reflect") |
| 302 | + g = process_tf_graph(sess.graph) |
| 303 | + |
| 304 | + self.assertEqual('digraph { const_no_val [op_type=Pad pads="[1, 2, 1, 2]"]' |
| 305 | + ' const_with_val [op_type=Pad pads="[1, 2, 1, 2]" value=999]' |
| 306 | + ' reflect [mode="b\'reflect\'" op_type=Pad pads="[1, 2, 1, 2]"]' |
| 307 | + ' input1:0 -> const_no_val input1:0 -> const_with_val input1:0 -> reflect }', |
| 308 | + onnx_to_graphviz(g, True)) |
| 309 | + |
290 | 310 |
|
291 | 311 | if __name__ == '__main__':
|
292 | 312 | unittest.main()
|
0 commit comments