Commit 5ced2e6

Added output_frozen_graph flag and marked --fold_const as deprecated
Signed-off-by: Tom Wildenhain <[email protected]>
1 parent 9e8e42d commit 5ced2e6

File tree

- README.md
- tests/test_convert.py
- tf2onnx/convert.py

3 files changed: +34 -6 lines changed

README.md

Lines changed: 7 additions & 1 deletion
@@ -145,8 +145,10 @@ python -m tf2onnx.convert
     [--target TARGET]
     [--custom-ops list-of-custom-ops]
     [--fold_const]
+    [--large_model]
     [--continue_on_error]
     [--verbose]
+    [--output_frozen_graph]
 ```

 ### Parameters
@@ -199,13 +201,17 @@ Only valid with parameter `--saved_model`. If a model contains a list of concret

 Only valid with parameter `--saved_model`. When set, creates a zip file containing the ONNX protobuf model and large tensor values stored externally. This allows for converting models that exceed the 2 GB protobuf limit.

+#### --output_frozen_graph
+
+Saves the frozen tensorflow graph to file.
+
 #### --target

 Some models require special handling to run on some runtimes. In particular, the model may use unsupported data types. Workarounds are activated with ```--target TARGET```. Currently supported values are listed on this [wiki](https://github.com/onnx/tensorflow-onnx/wiki/target). If your model will be run on Windows ML, you should specify the appropriate target value.

 #### --fold_const

-When set, TensorFlow fold_constants transformation is applied before conversion. This benefits features including Transpose optimization (e.g. Transpose operations introduced during tf-graph-to-onnx-graph conversion will be removed), and RNN unit conversion (for example LSTM). Older TensorFlow version might run into issues with this option depending on the model.
+Deprecated. Constant folding is always enabled.

 ### <a name="summarize_graph"></a>Tool to get Graph Inputs & Outputs
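
For reference, the file written by `--output_frozen_graph` is the frozen GraphDef that `convert.py` serializes via `utils.save_protobuf` (see the `tf2onnx/convert.py` diff below), so it can be read back with standard TensorFlow protobuf calls. A minimal sketch, assuming TensorFlow 2.x and reusing the `frozen_graph.pb` filename from the new test as a placeholder path:

```python
# Sketch: read back a frozen graph written by --output_frozen_graph.
# Assumes TensorFlow 2.x; "frozen_graph.pb" is only the placeholder
# filename used by the new test in this commit.
import tensorflow as tf

graph_def = tf.compat.v1.GraphDef()
with open("frozen_graph.pb", "rb") as f:
    graph_def.ParseFromString(f.read())

# List the first few frozen nodes.
for node in graph_def.node[:10]:
    print(node.name, node.op)

# Importing into a fresh graph verifies the GraphDef is well formed.
with tf.Graph().as_default():
    tf.compat.v1.import_graph_def(graph_def, name="")
```

This is the same GraphDef that the converter itself imports just before calling `process_tf_graph`, as shown in the `tf2onnx/convert.py` hunk below.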

tests/test_convert.py

Lines changed: 23 additions & 4 deletions
@@ -10,13 +10,18 @@
 from tf2onnx import convert
 from common import check_tf_min_version

-def run_test_case(args):
+def run_test_case(args, paths_to_check=None):
     """ run case and clean up """
+    if paths_to_check is None:
+        paths_to_check = [args[-1]]
     sys.argv = args
     convert.main()
-    ret = os.path.exists(args[-1])
-    if ret:
-        os.remove(args[-1])
+    ret = True
+    for p in paths_to_check:
+        if os.path.exists(p):
+            os.remove(p)
+        else:
+            ret = False
     return ret


@@ -33,6 +38,20 @@ def test_convert_saved_model(self):
                                        '--output',
                                        'converted_saved_model.onnx']))

+    def test_convert_output_frozen_graph(self):
+        """ convert saved model """
+        self.assertTrue(run_test_case(['',
+                                       '--saved-model',
+                                       'tests/models/regression/saved_model',
+                                       '--tag',
+                                       'serve',
+                                       '--output',
+                                       'converted_saved_model.onnx',
+                                       '--output_frozen_graph',
+                                       'frozen_graph.pb'
+                                       ],
+                                      paths_to_check=['converted_saved_model.onnx', 'frozen_graph.pb']))
+
     @check_tf_min_version("2.2")
     def test_convert_large_model(self):
         """ convert saved model to onnx large model format """

tf2onnx/convert.py

Lines changed: 4 additions & 1 deletion
@@ -68,7 +68,8 @@ def get_args():
     parser.add_argument("--continue_on_error", help="continue_on_error", action="store_true")
     parser.add_argument("--verbose", "-v", help="verbose output, option is additive", action="count")
     parser.add_argument("--debug", help="debug mode", action="store_true")
-    parser.add_argument("--fold_const", help="enable tf constant_folding transformation before conversion",
+    parser.add_argument("--output_frozen_graph", help="output frozen tf graph to file")
+    parser.add_argument("--fold_const", help="Deprecated. Constant folding is always enabled.",
                         action="store_true")
     # experimental
     parser.add_argument("--inputs-as-nchw", help="transpose inputs as from nhwc to nchw")
@@ -148,6 +149,8 @@ def main():
         const_node_values = None
         if args.large_model:
             const_node_values = compress_graph_def(graph_def)
+        if args.output_frozen_graph:
+            utils.save_protobuf(args.output_frozen_graph, graph_def)
         tf.import_graph_def(graph_def, name='')
     with tf_loader.tf_session(graph=tf_graph):
         g = process_tf_graph(tf_graph,
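
The new branch delegates the write to tf2onnx's `utils.save_protobuf` helper. As a rough illustration of what that amounts to (a sketch of the equivalent protobuf call, not the project's actual implementation), a GraphDef can be written as a binary protobuf like this:

```python
# Sketch: serialize a GraphDef to a binary .pb file, roughly what the
# --output_frozen_graph path ends up containing. This is not tf2onnx's
# utils.save_protobuf, just the equivalent protobuf call for illustration.
import tensorflow as tf


def save_graph_def(path, graph_def):
    """Write a GraphDef message to a binary protobuf file."""
    with open(path, "wb") as f:
        f.write(graph_def.SerializeToString())


# Hypothetical usage with a trivially small graph:
with tf.Graph().as_default() as graph:
    tf.constant(1.0, name="const_example")
save_graph_def("frozen_graph.pb", graph.as_graph_def())
```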
