Skip to content

Commit c4c46dc

Browse files
committed
Added optimization step to remove upsample layers with all ones in scale
1 parent ad7fe46 commit c4c46dc

File tree

2 files changed

+34
-0
lines changed

2 files changed

+34
-0
lines changed

tf2onnx/optimizer/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,13 @@
1515
from .transpose_optimizer import TransposeOptimizer
1616
from .loop_optimizer import LoopOptimizer
1717
from .back_to_back_optimizer import BackToBackOptimizer
18+
from .upsample_optimizer import UpsampleOptimizer
1819
from .. import logging
1920

2021
# optimizer sequence need to be considered carefully
2122
_optimizers = OrderedDict([
2223
("optimize_transpose", TransposeOptimizer),
24+
("remove_redundant_upsample", UpsampleOptimizer),
2325
("fold_constants", ConstFoldOptimizer),
2426
("loop_optimizer", LoopOptimizer),
2527
# merge_duplication should be used after optimize_transpose
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
"""Resize Optimizer.
2+
Replace resize operations with all ones in scale with Identity nodes
3+
"""
4+
5+
from __future__ import unicode_literals
6+
7+
from .optimizer_base import GraphOptimizerBase
8+
9+
# pylint: disable=logging-not-lazy,unused-argument,missing-docstring,unused-variable,arguments-differ
10+
11+
12+
class UpsampleOptimizer(GraphOptimizerBase):
13+
"""Resize Optimizer."""
14+
15+
def __init__(self): # pylint: disable=useless-super-delegation
16+
super(UpsampleOptimizer, self).__init__()
17+
18+
def _optimize(self, graph):
19+
return self._apply_optimization(
20+
graph,
21+
self._optimize_at_current_graph_level)
22+
23+
def _optimize_at_current_graph_level(self, graph):
24+
# replace resize operations with all ones in scale with Identity nodes
25+
for n in graph.get_nodes():
26+
if n.type == "Upsample":
27+
scales = n.get_attr_value("scales")
28+
if all([s == 1 for s in scales]):
29+
n.type = "Identity"
30+
self.logger.debug("replacing " + n.name +
31+
" with Identity operation")
32+
return graph

0 commit comments

Comments
 (0)