Skip to content

Commit 155f542

Browse files
Merge pull request #1234 from onnx/tom/StrictOptimizers
Made unittests require optimizers to not fail
2 parents 9759d70 + c4ebecc commit 155f542

File tree

3 files changed

+16
-11
lines changed

3 files changed

+16
-11
lines changed

tests/backend_test_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ def run_test_case(self, func, feed_dict, input_names_with_port, output_names_wit
190190
const_node_values=const_node_values,
191191
initialized_tables=initialized_tables,
192192
**process_args)
193-
g = optimizer.optimize_graph(g)
193+
g = optimizer.optimize_graph(g, catch_errors=False)
194194
actual = self.run_backend(g, output_names_with_port, onnx_feed_dict, large_model)
195195

196196
for expected_val, actual_val in zip(expected, actual):

tf2onnx/optimizer/__init__.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ def _get_optimizers():
3636
return _optimizers
3737

3838

39-
def optimize_graph(graph):
40-
""" Optimize graph, return optimized graph. No throw. """
39+
def optimize_graph(graph, catch_errors=True):
40+
""" Optimize graph, return optimized graph. No throw if catch_errors is true"""
4141
logger = logging.getLogger(__name__)
4242
logger.info("Optimizing ONNX model")
4343

@@ -47,17 +47,21 @@ def optimize_graph(graph):
4747
while continue_flag:
4848
continue_flag = False
4949
for name, factory in opts.items():
50-
try:
51-
logger.verbose("Apply %s", name)
52-
current = copy.deepcopy(graph)
50+
logger.verbose("Apply %s", name)
51+
if catch_errors:
52+
try:
53+
current = copy.deepcopy(graph)
54+
opt = factory()
55+
graph = opt.optimize(current) or graph
56+
continue_flag = continue_flag or opt.graph_been_opt
57+
except Exception: # pylint: disable=broad-except
58+
# if current optimizer fails, continue with other optimizers
59+
logger.warning("Failed to apply %s", name, exc_info=1)
60+
else:
5361
opt = factory()
54-
graph = opt.optimize(current) or graph
62+
graph = opt.optimize(graph)
5563
continue_flag = continue_flag or opt.graph_been_opt
5664

57-
except Exception: # pylint: disable=broad-except
58-
# if current optimizer fails, continue with other optimizers
59-
logger.warning("Failed to apply %s", name, exc_info=1)
60-
6165
try:
6266
graph.topological_sort(graph.get_nodes())
6367
except Exception: # pylint: disable=broad-except

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,7 @@ def _split_handler(self, trans, node):
571571

572572
def _squeeze_handler(self, trans, node):
573573
def _calculate_new_attr(ori_perm, ori_squeeze_axes):
574+
ori_squeeze_axes = [i if i >= 0 else i + 4 for i in ori_squeeze_axes]
574575
new_squeeze_axes = sorted([ori_perm[i] for i in ori_squeeze_axes])
575576
# calculate output shape after trans and squeeze
576577
input_shape = "abcd"

0 commit comments

Comments
 (0)