Skip to content

Commit b182821

Browse files
committed
add test for transpose opt's enhancement: replacing trans with reshape
1 parent c456a14 commit b182821

File tree

1 file changed

+37
-0
lines changed

1 file changed

+37
-0
lines changed

tests/test_optimizers.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,43 @@ def test_trans_output_as_graph_outputs(self):
188188

189189
self.assertTrue(trans_cnt == 1, msg="Expect 1 Transpose ops left, but actually " + str(trans_cnt) + " left")
190190

191+
def test_trans_can_be_replaced_with_reshape1(self):
192+
# test trans-NHWC
193+
input_shapes_np = [(2, 3, 4, 1), (2, 1, 1, 4), (2, 3, 4, 1)]
194+
input_shapes = [(2, 3, 4, 1), (2, 1, 1, 4), (2, -1, -1, 1)]
195+
perm = (0, 3, 1, 2)
196+
for input_shape_np, input_shape in zip(input_shapes_np, input_shapes):
197+
result_shape = [input_shape[i] for i in perm]
198+
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans")
199+
graph = helper.make_graph(
200+
[node1],
201+
"test_trans_can_be_replaced_with_reshape",
202+
[helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape)],
203+
[helper.make_tensor_value_info("Y", TensorProto.FLOAT, result_shape)],
204+
)
205+
206+
model_proto = helper.make_model(graph, producer_name="onnx-tests")
207+
self.run_transpose_compare(["Y"], {"X": np.random.randn(*input_shape_np).astype(np.float32)},
208+
model_proto, remaining_transpose_num=0)
209+
210+
def test_trans_can_be_replaced_with_reshape2(self):
211+
# test trans-NCHW
212+
input_shapes_np = [(2, 1, 3, 4), (2, 4, 1, 1), (2, 1, 3, 4)]
213+
input_shapes = [(2, 1, 3, 4), (2, 4, 1, 1), (2, 1, -1, -1)]
214+
perm = (0, 2, 3, 1)
215+
for input_shape_np, input_shape in zip(input_shapes_np, input_shapes):
216+
result_shape = [input_shape[i] for i in perm]
217+
node1 = helper.make_node("Transpose", ["X"], ["Y"], perm=perm, name="trans")
218+
graph = helper.make_graph(
219+
[node1],
220+
"test_trans_can_be_replaced_with_reshape",
221+
[helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape)],
222+
[helper.make_tensor_value_info("Y", TensorProto.FLOAT, result_shape)],
223+
)
224+
225+
model_proto = helper.make_model(graph, producer_name="onnx-tests")
226+
self.run_transpose_compare(["Y"], {"X": np.random.randn(*input_shape_np).astype(np.float32)},
227+
model_proto, remaining_transpose_num=0)
191228
# Tranpose Optimizer Tests End
192229

193230
# Identity Optimizer Tests Start

0 commit comments

Comments
 (0)