Skip to content

Commit f850cbe

Browse files
committed
enable graph opt at backend test
1 parent b7efca1 commit f850cbe

File tree

2 files changed

+3
-1
lines changed

2 files changed

+3
-1
lines changed

tests/backend_test_base.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from common import get_test_config
1919
from tf2onnx import utils
2020
from tf2onnx.tfonnx import process_tf_graph, tf_optimize
21+
from tf2onnx.graph import GraphUtil
2122

2223

2324
# pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test
@@ -140,6 +141,7 @@ def run_test_case(self, feed_dict, input_names_with_port, output_names_with_port
140141
with tf.Session() as sess:
141142
g = process_tf_graph(sess.graph, opset=self.config.opset, output_names=output_names_with_port,
142143
target=self.config.target, **process_args)
144+
g = GraphUtil.optimize_graph(g)
143145
actual = self._run_backend(g, output_names_with_port, onnx_feed_dict)
144146

145147
for expected_val, actual_val in zip(expected, actual):

tf2onnx/graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1138,7 +1138,7 @@ def create_graph_from_onnx_graph(graph_proto, opset_version=None):
11381138
for attr_name, attr_val in n.attr.items():
11391139
if attr_val.HasField('g'):
11401140
# it was assumed that the a.g has inferred shapes/dtypes.
1141-
sub_g = GraphUtil.create_graph_from_onnx_graph(attr_val.g)
1141+
sub_g = GraphUtil.create_graph_from_onnx_graph(attr_val.g, opset_version)
11421142
n.set_body_graph_as_attr(attr_name, sub_g)
11431143
return g
11441144

0 commit comments

Comments
 (0)