Skip to content

Commit 2594eca

Browse files
committed
refactor
1 parent 5617517 commit 2594eca

File tree

2 files changed

+16
-0
lines changed

2 files changed

+16
-0
lines changed

tests/test_optimizers.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,21 @@ def test_transpose_with_shape(self):
147147
self.run_transpose_compare(["Z"], {"X": np.random.randn(2, 3, 4, 5).astype(np.float32)},
148148
model_proto, remaining_transpose_num=0)
149149

150+
def test_transpose_with_identity(self):
151+
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans")
152+
node2 = helper.make_node("Identity", ["Y"], ["Z"], name="identity")
153+
154+
graph = helper.make_graph(
155+
[node1, node2],
156+
"transpose_with_identity",
157+
[helper.make_tensor_value_info("X", TensorProto.FLOAT, (2, 3, 4, 5))],
158+
[helper.make_tensor_value_info("Z", TensorProto.FLOAT, (2, 4, 5, 3))],
159+
)
160+
161+
model_proto = helper.make_model(graph, producer_name="onnx-tests")
162+
self.run_transpose_compare(["Z"], {"X": np.random.randn(2, 3, 4, 5).astype(np.float32)},
163+
model_proto, remaining_transpose_num=1)
164+
150165
# Tranpose Optimizer Tests End
151166

152167
# Identity Optimizer Tests Start

tf2onnx/optimizer/transpose_optimizer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,7 @@ def _mul_handler(self, trans, node):
385385
multiplier_input_id = i
386386
multiplier_input_node = input_node
387387

388+
# node's inputs may come from one same node. if so the multiplier_input_node may be none
388389
if multiplier_input_node is None or not multiplier_input_node.is_const():
389390
return False
390391
multiplier = multiplier_input_node.get_tensor_value(as_list=False)

0 commit comments

Comments
 (0)