Skip to content

Commit 715c20c

Browse files
committed
Upsample has scales in attributes up until opset 8
1 parent 27c2534 commit 715c20c

File tree

2 files changed

+21
-15
lines changed

2 files changed

+21
-15
lines changed

tests/test_optimizers.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1177,26 +1177,28 @@ def test_cast_back_to_back_non_const_mixed_types(self):
11771177
self.run_and_compare(["res", "res2", "res3"], {"u": np.random.randn(1, 2, 3).astype(np.float32)}, model_proto,
11781178
"Cast", 5)
11791179

1180+
@check_opset_max_version(8, "until opset 8 scales is in attributes")
11801181
def test_upsample_all_ones_removed(self):
1182+
shape = (1, 1, 32, 32)
11811183
node1 = helper.make_node(
1182-
"Upsample",
1183-
["X"],
1184-
["Y"],
1184+
op_type="Upsample",
1185+
inputs=["X"],
1186+
outputs=["Y"],
11851187
scales=[1., 1., 1., 1.],
11861188
name="upsample1")
11871189

11881190
graph = helper.make_graph(
11891191
[node1],
11901192
"test_upsample_all_ones",
1191-
[helper.make_tensor_value_info("X", TensorProto.FLOAT, (1, 32, 32, 1))],
1192-
[helper.make_tensor_value_info("Y", TensorProto.FLOAT, (1, 32, 32, 1))],
1193+
[helper.make_tensor_value_info("X", TensorProto.FLOAT, shape)],
1194+
[helper.make_tensor_value_info("Y", TensorProto.FLOAT, shape)],
11931195
)
11941196

11951197
model_proto = self.make_model(graph, producer_name="onnx-tests")
11961198

11971199
self.run_and_compare(
11981200
["Y"],
1199-
{"X": np.random.randn(1, 32, 32, 1).astype(np.float32)},
1201+
{"X": np.random.randn(*shape).astype(np.float32)},
12001202
model_proto,
12011203
"Upsample",
12021204
0)

tf2onnx/optimizer/upsample_optimizer.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
"""
44

55
from __future__ import unicode_literals
6+
from onnx import helper
67

78
from .optimizer_base import GraphOptimizerBase
89

@@ -14,21 +15,24 @@ class UpsampleOptimizer(GraphOptimizerBase):
1415

1516
def __init__(self): # pylint: disable=useless-super-delegation
1617
super(UpsampleOptimizer, self).__init__()
18+
self._g = None
1719

1820
def _optimize(self, graph):
1921
return self._apply_optimization(
2022
graph,
2123
self._optimize_at_current_graph_level)
2224

2325
def _optimize_at_current_graph_level(self, graph):
26+
self._g = graph
2427
# replace upsample node with all ones in scale with identity node
25-
for n in graph.get_nodes():
28+
for n in self._g.get_nodes():
2629
if n.type == "Upsample":
27-
scales = n.get_attr_value("scales")
28-
if all([s == 1 for s in scales]):
29-
n.type = "Identity"
30-
if len(n.input) > 0:
31-
n.input = [n.input[0]]
32-
self.logger.debug("replacing " + n.name +
33-
" with Identity operation")
34-
return graph
30+
# upsample in opset <=8 has scales in attributes
31+
if self._g.opset <= 8:
32+
scales = n.get_attr_value("scales")
33+
if scales and all([float(s) == 1. for s in scales]):
34+
n.type = "Identity"
35+
self.logger.debug("replacing " + n.name +
36+
" with Identity operation ")
37+
# upsample in opset > 8 has scales in input[1]
38+
return self._g

0 commit comments

Comments
 (0)