Skip to content

Commit 290b157

Browse files
committed
Added optimization and test for opset>=9
1 parent b305209 commit 290b157

File tree

2 files changed

+48
-3
lines changed

2 files changed

+48
-3
lines changed

tests/test_optimizers.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1203,6 +1203,37 @@ def test_upsample_all_ones_removed(self):
12031203
"Upsample",
12041204
0)
12051205

1206+
@check_opset_min_version(9, ">= 9 scales is in input[1]")
1207+
def test_upsample_all_ones_removed_in_input(self):
1208+
shape = (1, 1, 32, 32)
1209+
const_tensor = helper.make_tensor(
1210+
name="S",
1211+
data_type=TensorProto.FLOAT,
1212+
dims=(1,4),
1213+
vals=np.array([1.0, 1.0, 1.0, 1.0], dtype=np.float32))
1214+
node0 = helper.make_node("Constant", [], ["S"], value=const_tensor)
1215+
node1 = helper.make_node(
1216+
op_type="Upsample",
1217+
inputs=["X", "S"],
1218+
outputs=["Y"],
1219+
name="upsample1")
1220+
1221+
graph = helper.make_graph(
1222+
[node0, node1],
1223+
"test_upsample_all_ones",
1224+
[helper.make_tensor_value_info("X", TensorProto.FLOAT, shape)],
1225+
[helper.make_tensor_value_info("Y", TensorProto.FLOAT, shape)],
1226+
)
1227+
1228+
model_proto = self.make_model(graph, producer_name="onnx-tests")
1229+
1230+
self.run_and_compare(
1231+
["Y"],
1232+
{"X": np.random.randn(*shape).astype(np.float32)},
1233+
model_proto,
1234+
"Upsample",
1235+
0)
1236+
12061237

12071238
if __name__ == "__main__":
12081239
unittest_main()

tf2onnx/optimizer/upsample_optimizer.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
from __future__ import unicode_literals
66

7+
import numpy as np
8+
79
from .optimizer_base import GraphOptimizerBase
810

911
# pylint: disable=logging-not-lazy,unused-argument,missing-docstring,unused-variable,arguments-differ
@@ -26,12 +28,24 @@ def _optimize_at_current_graph_level(self, graph):
2628
# replace upsample node with all ones in scale with identity node
2729
for n in self._g.get_nodes():
2830
if n.type == "Upsample":
31+
node_changed = False
2932
# upsample in opset <=8 has scales in attributes
3033
if self._g.opset <= 8:
3134
scales = n.get_attr_value("scales")
3235
if scales and all([float(s) == 1. for s in scales]):
3336
n.type = "Identity"
34-
self.logger.debug("replacing " + n.name +
35-
" with Identity operation ")
36-
# upsample in opset > 8 has scales in input[1]
37+
node_changed = True
38+
# upsample in opset >= 9 has scales in input[1]
39+
if self._g.opset >= 9 and len(n.input) == 2:
40+
scales_input = n.inputs[1]
41+
42+
if scales_input.is_const() and \
43+
np.all(scales_input.get_tensor_value(as_list=False) == 1.):
44+
n.type = "Identity"
45+
n.input = [n.input[0]]
46+
node_changed = True
47+
if node_changed:
48+
self.logger.debug("replacing " + n.name +
49+
" with Identity operation ")
50+
3751
return self._g

0 commit comments

Comments
 (0)