Skip to content

Commit 6bfba19

Browse files
authored
Merge pull request #734 from jignparm/jignparm/fix_optimizer
Fix optimizer for Transpose
2 parents c2227ea + 3c34ae0 commit 6bfba19

File tree

2 files changed

+15
-14
lines changed

2 files changed

+15
-14
lines changed

tests/test_optimizers.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77
from __future__ import print_function
88
from __future__ import unicode_literals
99

10-
import unittest
11-
1210
import numpy as np
1311
from onnx import helper, TensorProto, OperatorSetIdProto
1412
from tf2onnx import utils
@@ -1099,10 +1097,9 @@ def test_transpose_back_to_back_non_const(self):
10991097

11001098
model_proto = self.make_model(graph, producer_name="onnx-tests")
11011099
self.run_transpose_compare(["res"], {"u": np.random.randn(5, 5, 5, 5).astype(np.float32)},
1102-
model_proto, remaining_transpose_num=2)
1100+
model_proto, remaining_transpose_num=1)
11031101

1104-
# @check_opset_min_version(9, "string type tensor")
1105-
@unittest.skip("FIXME: disabled because of crash on linux/ortnightly")
1102+
@check_opset_min_version(9, "string type tensor")
11061103
def test_cast_back_to_back_non_const_mixed_types(self):
11071104
node0 = helper.make_node("Cast", ["u"], ["v"], to=11, name="cast_0") # double
11081105
node1 = helper.make_node("Cast", ["v"], ["w"], to=6, name="cast_1") # int32
@@ -1113,11 +1110,13 @@ def test_cast_back_to_back_non_const_mixed_types(self):
11131110
node5 = helper.make_node("Cast", ["w2"], ["res2"], to=7, name="cast_5") # int64
11141111

11151112
node6 = helper.make_node("Cast", ["x"], ["x2"], to=9, name="cast_6") # bool
1116-
node7 = helper.make_node("Cast", ["x2"], ["x3"], to=8, name="cast_7") # string
1117-
node8 = helper.make_node("Cast", ["x3"], ["res3"], to=3, name="cast_8") # int8
1113+
# TODO: uncomment below after fix
1114+
# https://github.com/microsoft/onnxruntime/issues/2338
1115+
# node7 = helper.make_node("Cast", ["x2"], ["x3"], to=8, name="cast_7") # string
1116+
node8 = helper.make_node("Cast", ["x2"], ["res3"], to=3, name="cast_8") # int8
11181117

11191118
graph = helper.make_graph(
1120-
[node0, node1, node2, node3, node4, node5, node6, node7, node8],
1119+
[node0, node1, node2, node3, node4, node5, node6, node8],
11211120
"test-cast-back-to-back-non-const",
11221121
[helper.make_tensor_value_info("u", TensorProto.FLOAT, (1, 2, 3))],
11231122
[helper.make_tensor_value_info("res", TensorProto.INT64, (1, 2, 3)),

tf2onnx/optimizer/back_to_back_optimizer.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -121,9 +121,8 @@ def _optimize_cast(g, node, consumer_nodes):
121121
g.remove_node(node.name)
122122
return q2
123123

124-
# TODO: reactivate after fixing interference with transpose_optimizer
125124
@staticmethod
126-
@_register_func("_Transpose")
125+
@_register_func("Transpose")
127126
def _optimize_transpose(g, node, consumer_nodes):
128127
t1 = list(node.get_attr('perm').ints)
129128
q2 = []
@@ -132,13 +131,16 @@ def _optimize_transpose(g, node, consumer_nodes):
132131
t2 = list(node2.get_attr('perm').ints)
133132
new_perm = [t1[i] for i in t2]
134133
# check if node2 can be removed. otherwise only update
135-
if new_perm == list(range(len(t2))) \
136-
and not set(node2.output) & set(g.outputs):
134+
if new_perm == list(range(len(t2))):
137135
# both nodes can be deleted
136+
shape = g.get_shape(node2.output[0])
137+
dtype = g.get_dtype(node2.output[0])
138138
node2_consumers = g.find_output_consumers(node2.output[0])
139-
for consumer in node2_consumers:
140-
consumer.input[0] = node.input[0]
139+
g.replace_all_inputs(node2_consumers, node2.output[0], node.input[0])
141140
g.remove_node(node2.name)
141+
if set(node2.output) & set(g.outputs):
142+
g.make_node("Identity", [node.input[0]],
143+
outputs=node2.output, shapes=[shape], dtypes=[dtype])
142144
else:
143145
node2.set_attr('perm', [t1[i] for i in t2])
144146
q2.append(node2.output[0])

0 commit comments

Comments
 (0)