Skip to content

Commit cb3f061

Browse files
committed
Disconnected inputs like comment pointed out, added basic test
1 parent c4c46dc commit cb3f061

File tree

2 files changed

+23
-2
lines changed

2 files changed

+23
-2
lines changed

tests/test_optimizers.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1177,6 +1177,25 @@ def test_cast_back_to_back_non_const_mixed_types(self):
11771177
self.run_and_compare(["res", "res2", "res3"], {"u": np.random.randn(1, 2, 3).astype(np.float32)}, model_proto,
11781178
"Cast", 5)
11791179

1180+
def test_upsample_all_ones_removed(self):
1181+
node1 = helper.make_node("Upsample", ["X"], ["Y"], scales=[1, 1, 1, 1], name="upsample1")
1182+
1183+
graph = helper.make_graph(
1184+
[node1],
1185+
"test_upsample_all_ones",
1186+
[helper.make_tensor_value_info("X", TensorProto.FLOAT, (32, 16))],
1187+
[helper.make_tensor_value_info("Y", TensorProto.FLOAT, (32, 16))],
1188+
)
1189+
1190+
model_proto = self.make_model(graph, producer_name="onnx-tests")
1191+
1192+
self.run_and_compare(
1193+
["Y"],
1194+
{"X": np.random.randn(32, 16).astype(np.float32)},
1195+
model_proto,
1196+
"Upsample",
1197+
0)
1198+
11801199

11811200
if __name__ == "__main__":
11821201
unittest_main()

tf2onnx/optimizer/upsample_optimizer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111

1212
class UpsampleOptimizer(GraphOptimizerBase):
13-
"""Resize Optimizer."""
13+
"""Upsample Optimizer."""
1414

1515
def __init__(self): # pylint: disable=useless-super-delegation
1616
super(UpsampleOptimizer, self).__init__()
@@ -21,12 +21,14 @@ def _optimize(self, graph):
2121
self._optimize_at_current_graph_level)
2222

2323
def _optimize_at_current_graph_level(self, graph):
24-
# replace resize operations with all ones in scale with Identity nodes
24+
# replace upsample node with all ones in scale with identity node
2525
for n in graph.get_nodes():
2626
if n.type == "Upsample":
2727
scales = n.get_attr_value("scales")
2828
if all([s == 1 for s in scales]):
2929
n.type = "Identity"
30+
if len(n.input) > 0:
31+
n.input = [n.input[0]]
3032
self.logger.debug("replacing " + n.name +
3133
" with Identity operation")
3234
return graph

0 commit comments

Comments
 (0)