Skip to content

Commit 0f9933f

Browse files
resolve requests
1 parent 2808529 commit 0f9933f

File tree

8 files changed

+23
-25
lines changed

8 files changed

+23
-25
lines changed

tests/test_loops.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,13 +8,11 @@
88
from __future__ import print_function
99
from __future__ import unicode_literals
1010

11-
import unittest
12-
1311
import numpy as np
1412
import tensorflow as tf
1513

1614
from backend_test_base import Tf2OnnxBackendTestBase
17-
from common import unittest_main, check_tf_min_version
15+
from common import unittest_main, check_tf_min_version, check_onnxruntime_min_version
1816

1917

2018
# pylint: disable=missing-docstring,invalid-name,unused-argument,using-constant-test
@@ -162,7 +160,8 @@ def b(i, out_ta):
162160
output_names_with_port = ["i:0", "output_ta:0"]
163161
self.run_test_case(feed_dict, input_names_with_port, output_names_with_port, rtol=1e-06)
164162

165-
@unittest.skip(
163+
@check_onnxruntime_min_version(
164+
"0.5.0",
166165
"disable this case due to onnxruntime loop issue: https://github.com/microsoft/onnxruntime/issues/1272"
167166
)
168167
def test_while_loop_with_cond_init_false(self):

tf2onnx/constants.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,3 @@
3737

3838
# Environment variables
3939
ENV_TF2ONNX_DEBUG_MODE = "TF2ONNX_DEBUG_MODE"
40-
41-
# Logging level
42-
VERBOSE = 15

tf2onnx/graph.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -603,7 +603,7 @@ def update_node_shape_dtype(self, node, override=False):
603603
initializers = []
604604
for i, inp in enumerate(node.inputs):
605605
if inp is None:
606-
if logger.isEnabledFor(constants.VERBOSE):
606+
if logger.isEnabledFor(logging.INFO):
607607
logger.warning(
608608
"[%s] infer a inexistent node: [%s], please check the code",
609609
node.name, node.input[i]
@@ -1181,10 +1181,10 @@ def delete_unused_nodes(self, outputs_name):
11811181
body_graph.delete_unused_nodes(body_graph.outputs)
11821182
self.reset_nodes(related_nodes)
11831183

1184-
def delete_nodes_without_dependency(self, to_delete):
1185-
"""Delete nodes in `to_delete` without third-party dependency."""
1184+
def safe_remove_nodes(self, to_delete):
1185+
"""Delete nodes in `to_delete` without third-party node consuming it."""
11861186
delete_set = set(to_delete)
1187-
for n in to_delete:
1187+
for n in delete_set:
11881188
out_consumers = set()
11891189
for out in n.output:
11901190
out_consumers |= set(self.find_output_consumers(out))

tf2onnx/rewriter/leakyrelu_rewriter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def rewrite_leakyrelu(g, ops):
4343
ops.append(leakyrelu)
4444
g.replace_all_inputs(ops, max_node.output[0], leakyrelu.output[0])
4545
to_delete = [max_node, mul_node]
46-
g.delete_nodes_without_dependency(to_delete)
46+
g.safe_remove_nodes(to_delete)
4747

4848
return ops
4949

tf2onnx/rewriter/random_uniform.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def rewrite_random_uniform_fold_const(g, ops):
6363
to_delete = list(set(match.get_nodes()))
6464
new_node = create_onnx_random_uniform_op(g, tmax, tmin, ru_op, output, to_delete)
6565
g.replace_all_inputs(ops, output.output[0], new_node.output[0])
66-
g.delete_nodes_without_dependency(to_delete)
66+
g.safe_remove_nodes(to_delete)
6767

6868
return ops
6969

tf2onnx/rewriter/thresholded_relu_rewriter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,5 +45,5 @@ def rewrite_thresholded_relu(g, ops):
4545
dtypes=[g.get_dtype(mul_node.output[0])])
4646
g.replace_all_inputs(ops, mul_node.output[0], thresholded_relu.output[0])
4747
to_delete = [cast_node, mul_node]
48-
g.delete_nodes_without_dependency(to_delete)
48+
g.safe_remove_nodes(to_delete)
4949
return ops

tf2onnx/tfonnx.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,8 @@ def rewrite_transpose(g, ops):
141141
dims = [i for i in range(len(shape) - 1, -1, -1)]
142142
output.set_attr("perm", dims)
143143
g.remove_input(output, output.input[1])
144-
to_delete = [n for n in set(match.get_nodes()) if n != output]
145-
g.delete_nodes_without_dependency(to_delete)
144+
to_delete = [n for n in match.get_nodes() if n != output]
145+
g.safe_remove_nodes(to_delete)
146146
return ops
147147

148148

@@ -174,7 +174,7 @@ def rewrite_random_normal(g, ops):
174174
attr={"shape": shape, "mean": mean, "scale": 1.0, "dtype": dtype})
175175

176176
g.replace_all_inputs(ops, output.output[0], new_node.output[0])
177-
g.delete_nodes_without_dependency(set(match.get_nodes()))
177+
g.safe_remove_nodes(match.get_nodes())
178178
return ops
179179

180180

@@ -206,7 +206,7 @@ def rewrite_dropout(g, ops):
206206
dtypes=[g.get_dtype(inputs2.input[0])]
207207
)
208208
g.replace_all_inputs(ops, outputs.output[0], new_node.output[0])
209-
g.delete_nodes_without_dependency(set(match.get_nodes()))
209+
g.safe_remove_nodes(match.get_nodes())
210210

211211
# remove dropout if its ratio is 1.0
212212
for node in g.get_nodes():
@@ -291,8 +291,8 @@ def rewrite_flatten(g, ops):
291291

292292
g.set_shape(out_name, input_shape[:-2] + [new_dim])
293293
g.replace_all_inputs(ops, reshape_node.output[0], out_name)
294-
to_delete = [n for n in set(match.get_nodes()) if n != input_node]
295-
g.delete_nodes_without_dependency(to_delete)
294+
to_delete = [n for n in match.get_nodes() if n != input_node]
295+
g.safe_remove_nodes(to_delete)
296296

297297
return ops
298298

@@ -649,7 +649,7 @@ def run_rewriters(g, funcs, continue_on_error):
649649
else:
650650
raise ex
651651

652-
if logger.isEnabledFor(constants.VERBOSE):
652+
if utils.is_debug_mode():
653653
broken_outputs = g.check_integrity()
654654
if broken_outputs:
655655
logging.error(

tf2onnx/verbose_logging.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,14 @@
1515

1616
from . import constants
1717

18-
_logging.addLevelName(constants.VERBOSE, "VERBOSE")
18+
VERBOSE = 15
19+
20+
_logging.addLevelName(VERBOSE, "VERBOSE")
1921

2022

2123
def _verbose(self, message, *args, **kwargs):
22-
if self.isEnabledFor(constants.VERBOSE):
23-
self._log(constants.VERBOSE, message, args, **kwargs) # pylint: disable=protected-access
24+
if self.isEnabledFor(VERBOSE):
25+
self._log(VERBOSE, message, args, **kwargs) # pylint: disable=protected-access
2426

2527

2628
def getLogger(name=None): # pylint: disable=invalid-name, function-redefined
@@ -45,7 +47,7 @@ def basicConfig(**kwargs): # pylint: disable=invalid-name, function-redefined
4547
set_tf_verbosity(_logging.getLogger().getEffectiveLevel())
4648

4749

48-
_LOG_LEVELS = [FATAL, ERROR, WARNING, INFO, constants.VERBOSE, DEBUG]
50+
_LOG_LEVELS = [FATAL, ERROR, WARNING, INFO, VERBOSE, DEBUG]
4951

5052

5153
def get_verbosity_level(verbosity, base_level=INFO):

0 commit comments

Comments
 (0)