Skip to content

Commit 9c48cfa

Browse files
authored
Merge pull request #624 from zhijxu-MS/enhancement
make convert.py has same procedure with run_pretrained_models.py
2 parents bfb0644 + 8eadd4e commit 9c48cfa

File tree

2 files changed

+4
-7
lines changed

2 files changed

+4
-7
lines changed

tf2onnx/convert.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,8 @@
1414

1515
import tensorflow as tf
1616

17-
from tf2onnx.graph import GraphUtil
1817
from tf2onnx.tfonnx import process_tf_graph, tf_optimize
19-
from tf2onnx import constants, loader, logging, utils
18+
from tf2onnx import constants, loader, logging, utils, optimizer
2019

2120

2221
# pylint: disable=unused-argument
@@ -145,10 +144,8 @@ def main():
145144
output_names=outputs,
146145
inputs_as_nchw=args.inputs_as_nchw)
147146

148-
model_proto = g.make_model("converted from {}".format(model_path))
149-
150-
logger.info("")
151-
model_proto = GraphUtil.optimize_model_proto(model_proto)
147+
onnx_graph = optimizer.optimize_graph(g)
148+
model_proto = onnx_graph.make_model("converted from {}".format(model_path))
152149

153150
# write onnx graph
154151
logger.info("")

tf2onnx/optimizer/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def optimize_graph(graph):
4747
logger.verbose("Apply %s", name)
4848
current = copy.deepcopy(graph)
4949
opt = factory()
50-
graph = opt.optimize(current)
50+
graph = opt.optimize(current) or graph
5151
continue_flag = continue_flag or opt.graph_been_opt
5252

5353
except Exception: # pylint: disable=broad-except

0 commit comments

Comments
 (0)