Skip to content

Commit bedebea

Browse files
committed
deprecate verbose argument for process_tf_graph
1 parent a20e73a commit bedebea

File tree

5 files changed

+24
-18
lines changed

5 files changed

+24
-18
lines changed

tests/run_pretrained_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def run_tensorflow(self, sess, inputs):
153153

154154
def to_onnx(self, tf_graph, opset=None, extra_opset=None, shape_override=None, input_names=None):
155155
"""Convert graph to tensorflow."""
156-
return process_tf_graph(tf_graph, continue_on_error=False, verbose=True, opset=opset,
156+
return process_tf_graph(tf_graph, continue_on_error=False, opset=opset,
157157
extra_opset=extra_opset, target=Test.target, shape_override=shape_override,
158158
input_names=input_names, output_names=self.output_names)
159159

tf2onnx/convert.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,6 @@ def main():
112112
with tf.Session(graph=tf_graph):
113113
g = process_tf_graph(tf_graph,
114114
continue_on_error=args.continue_on_error,
115-
verbose=args.verbose,
116115
target=args.target,
117116
opset=args.opset,
118117
custom_op_handlers=custom_ops,

tf2onnx/rewriter/custom_rnn_rewriter.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,21 @@
77

88
from __future__ import division
99
from __future__ import print_function
10+
1011
import logging
1112
import sys
1213
import traceback
14+
1315
from onnx import onnx_pb
1416
import numpy as np
17+
1518
from tf2onnx.rewriter.loop_rewriter_base import LoopRewriterBase, Context
1619
from tf2onnx.rewriter.rnn_utils import REWRITER_RESULT, get_rnn_scope_name, parse_rnn_loop
17-
from tf2onnx.tfonnx import utils
18-
19-
20+
from tf2onnx import utils
2021

2122
logger = logging.getLogger(__name__)
2223

24+
2325
# pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test,broad-except,protected-access
2426

2527

@@ -159,7 +161,6 @@ def _connect_scan_with_output(self, context, scan_node):
159161
self.g.replace_all_inputs(self.g.get_nodes(), out_tensor_value_info.id, scan_node.output[index])
160162
index += 1
161163

162-
163164
def _adapt_scan_sequence_input_or_output(self, target_name, input_id, handle_output=False):
164165
nodes_to_add = []
165166
shape_node = self.g.make_node("Shape", [input_id])
@@ -198,7 +199,7 @@ def _adapt_scan_sequence_input_or_output(self, target_name, input_id, handle_out
198199
else:
199200
# add a fake batch size : 1
200201
fake_batch_size_node = self.g.make_const(utils.make_name(target_name + "_target_shape"),
201-
np.array([1,], dtype=np.int64))
202+
np.array([1], dtype=np.int64))
202203
nodes_to_add.append(fake_batch_size_node)
203204
new_shape_node = self.g.make_node("Concat",
204205
[fake_batch_size_node.output[0], shape_node.output[0]],

tf2onnx/rewriter/loop_rewriter.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,17 @@
77

88
from __future__ import division
99
from __future__ import print_function
10+
1011
import logging
1112
import sys
1213
import traceback
14+
1315
from onnx import TensorProto
1416
import numpy as np
17+
1518
from tf2onnx.rewriter.loop_rewriter_base import LoopRewriterBase, Context
1619
from tf2onnx.rewriter.rnn_utils import REWRITER_RESULT
17-
from tf2onnx.tfonnx import utils
18-
19-
20+
from tf2onnx import utils
2021

2122
logger = logging.getLogger(__name__)
2223

tf2onnx/tfonnx.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from __future__ import unicode_literals
1111

1212
import collections
13-
import logging
1413
import sys
1514
import traceback
1615

@@ -23,12 +22,12 @@
2322
import tf2onnx
2423
import tf2onnx.onnx_opset # pylint: disable=unused-import
2524
import tf2onnx.custom_opsets # pylint: disable=unused-import
26-
from tf2onnx import constants, schemas, utils, handler
2725
from tf2onnx.graph import Graph
2826
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
2927
from tf2onnx.rewriter import * # pylint: disable=wildcard-import
3028
from tf2onnx.shape_inference import infer_shape_for_graph
3129
from tf2onnx.utils import port_name
30+
from . import constants, logging, schemas, utils, handler
3231

3332
logger = logging.getLogger(__name__)
3433

@@ -669,7 +668,7 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
669668
Args:
670669
tf_graph: tensorflow graph
671670
continue_on_error: if an op can't be processed (aka there is no mapping), continue
672-
verbose: print summary stats
671+
verbose: print summary stats (deprecated)
673672
target: list of workarounds applied to help certain platforms
674673
opset: the opset to be used (int, default is latest)
675674
custom_op_handlers: dictionary of custom ops handlers
@@ -682,6 +681,11 @@ def process_tf_graph(tf_graph, continue_on_error=False, verbose=False, target=No
682681
Return:
683682
onnx graph
684683
"""
684+
# TODO: remove verbose argument in future release
685+
if verbose:
686+
logger.warning("Argument verbose for process_tf_graph is deprecated. Please use --verbose option instead.")
687+
del verbose
688+
685689
opset = utils.find_opset(opset)
686690
print("using tensorflow={}, onnx={}, opset={}, tfonnx={}/{}".format(
687691
tf.__version__, utils.get_onnx_version(), opset,
@@ -795,10 +799,11 @@ def compat_handler(ctx, node, **kwargs):
795799

796800
g.update_proto()
797801

798-
if verbose:
799-
print("tensorflow ops: {}".format(op_cnt))
800-
print("tensorflow attr: {}".format(attr_cnt))
801-
print("onnx mapped: {}".format(mapped_op))
802-
print("onnx unmapped: {}".format(unmapped_op))
802+
logger.verbose(
803+
"Summay Stats:\n"
804+
"\ttensorflow ops: {}\n"
805+
"\ttensorflow attr: {}\n"
806+
"\tonnx mapped: {}\n"
807+
"\tonnx unmapped: {}".format(op_cnt, attr_cnt, mapped_op, unmapped_op))
803808

804809
return g

0 commit comments

Comments
 (0)