Skip to content

Commit eb31fc3

Browse files
authored
Merge pull request #407 from zhijxu-MS/tmp_branch_for_PR3
fix some transpose optimizer bugs
2 parents b2d3f97 + bd79f99 commit eb31fc3

File tree

4 files changed

+29
-3
lines changed

4 files changed

+29
-3
lines changed

tests/backend_test_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def _run_backend(self, g, outputs, input_dict):
9191
raise ValueError("unknown backend")
9292
return y
9393

94-
def run_test_case(self, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-07, atol=0.,
94+
def run_test_case(self, feed_dict, input_names_with_port, output_names_with_port, rtol=1e-07, atol=1e-5,
9595
convert_var_to_const=True, constant_fold=True, check_value=True, check_shape=False,
9696
check_dtype=True, process_args=None, onnx_feed_dict=None, graph_validator=None):
9797
# optional - passed to process_tf_graph

tests/test_backend.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1118,6 +1118,7 @@ def test_sign(self):
11181118
self._run_test_case([_OUTPUT], {_INPUT: x_val})
11191119
tf.reset_default_graph()
11201120

1121+
@check_target("rs6", "onehot")
11211122
def test_onehot0(self):
11221123
x_val = np.array([0, 1, 2], dtype=np.int32)
11231124
depth = 5
@@ -1138,6 +1139,7 @@ def test_onehot1(self):
11381139
_ = tf.identity(x_, name=_TFOUTPUT)
11391140
self._run_test_case([_OUTPUT], {_INPUT: x_val})
11401141

1142+
@check_target("rs6", "onehot")
11411143
def test_onehot2(self):
11421144
for axis in [-1, 0, 1]:
11431145
tf.reset_default_graph()
@@ -1148,6 +1150,7 @@ def test_onehot2(self):
11481150
_ = tf.identity(x_, name=_TFOUTPUT)
11491151
self._run_test_case([_OUTPUT], {_INPUT: x_val})
11501152

1153+
@check_target("rs6", "onehot")
11511154
@check_opset_min_version(9, "onehot")
11521155
def test_onehot3(self):
11531156
# rank 1

tests/test_optimizers.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def run_and_compare(self, output_names_with_port, onnx_feed_dict, origin_proto,
4545
raise ValueError("only onnxruntime is supported to test transpose optimizer")
4646

4747
for expected_val, actual_val in zip(expected, actual):
48-
self.assertAllClose(expected_val, actual_val, rtol=rtol, atol=0.)
48+
self.assertAllClose(expected_val, actual_val, rtol=rtol, atol=1e-5)
4949
self.assertEqual(expected_val.dtype, actual_val.dtype)
5050
self.assertEqual(expected_val.shape, actual_val.shape)
5151

@@ -147,6 +147,21 @@ def test_transpose_with_shape(self):
147147
self.run_transpose_compare(["Z"], {"X": np.random.randn(2, 3, 4, 5).astype(np.float32)},
148148
model_proto, remaining_transpose_num=0)
149149

150+
def test_transpose_with_identity(self):
151+
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans")
152+
node2 = helper.make_node("Identity", ["Y"], ["Z"], name="identity")
153+
154+
graph = helper.make_graph(
155+
[node1, node2],
156+
"transpose_with_identity",
157+
[helper.make_tensor_value_info("X", TensorProto.FLOAT, (2, 3, 4, 5))],
158+
[helper.make_tensor_value_info("Z", TensorProto.FLOAT, (2, 4, 5, 3))],
159+
)
160+
161+
model_proto = helper.make_model(graph, producer_name="onnx-tests")
162+
self.run_transpose_compare(["Z"], {"X": np.random.randn(2, 3, 4, 5).astype(np.float32)},
163+
model_proto, remaining_transpose_num=1)
164+
150165
# Tranpose Optimizer Tests End
151166

152167
# Identity Optimizer Tests Start

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,8 +340,13 @@ def _transpose_handler(self, trans, node):
340340
ops = self._g.get_nodes()
341341
self._g.replace_all_inputs(ops, node.output[0], trans.input[0])
342342

343+
shape = self._g.get_shape(node.output[0])
344+
dtype = self._g.get_dtype(node.output[0])
343345
self._g.remove_node(trans.name)
344346
self._g.remove_node(node.name)
347+
if node.output[0] in self._g.outputs:
348+
self._g.make_node("Identity", [trans.input[0]],
349+
outputs=node.output, shapes=[shape], dtypes=[dtype])
345350
return True
346351
return False
347352

@@ -380,7 +385,8 @@ def _mul_handler(self, trans, node):
380385
multiplier_input_id = i
381386
multiplier_input_node = input_node
382387

383-
if not multiplier_input_node.is_const():
388+
# node's inputs may come from one same node. if so the multiplier_input_node may be none
389+
if multiplier_input_node is None or not multiplier_input_node.is_const():
384390
return False
385391
multiplier = multiplier_input_node.get_tensor_value(as_list=False)
386392

@@ -408,6 +414,8 @@ def _mul_handler(self, trans, node):
408414
return False
409415

410416
def _identity_handler(self, trans, node):
417+
if node.output[0] in self._g.outputs:
418+
return False
411419
ops = self._g.get_nodes()
412420
self._g.replace_all_inputs(ops, node.output[0], trans.output[0])
413421
self._g.remove_node(node.name)

0 commit comments

Comments
 (0)