Skip to content

Commit 8b5fba2

Browse files
committed
add test case for pad and reducemean handler
1 parent ba08651 commit 8b5fba2

File tree

1 file changed

+32
-0
lines changed

1 file changed

+32
-0
lines changed

tests/test_optimizers.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,38 @@ def test_transpose_add_with_conv_2(self):
530530
"W": np.random.randn(1, 1, 3, 3).astype(np.float32)},
531531
model_proto, remaining_transpose_num=0)
532532

533+
def test_transpose_pad(self):
534+
node0 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
535+
node1 = helper.make_node("Pad", ["Y"], ["Z"], pads=[1, 0, 1, 3, 0, 0, 2, 0], name="pad")
536+
node2 = helper.make_node("Transpose", ["Z"], ["res"], perm=[0, 3, 1, 2], name="trans_2")
537+
538+
graph = helper.make_graph(
539+
[node0, node1, node2],
540+
"transpose-pad-test",
541+
[helper.make_tensor_value_info("X", TensorProto.FLOAT, (1, 3, 4, 5))],
542+
[helper.make_tensor_value_info("res", TensorProto.FLOAT, (2, 6, 4, 8))],
543+
)
544+
545+
model_proto = helper.make_model(graph, producer_name="onnx-tests")
546+
self.run_transpose_compare(["res"], {"X": np.random.randn(1, 3, 4, 5).astype(np.float32)},
547+
model_proto, remaining_transpose_num=0)
548+
549+
def test_transpose_reducemean(self):
550+
node0 = helper.make_node("Transpose", ["X"], ["Y"], perm=[0, 2, 3, 1], name="trans_1")
551+
node1 = helper.make_node("ReduceMean", ["Y"], ["Z"], axes=[1, 2], keepdims=1, name="reducemean")
552+
node2 = helper.make_node("Transpose", ["Z"], ["res"], perm=[0, 3, 1, 2], name="trans_2")
553+
554+
graph = helper.make_graph(
555+
[node0, node1, node2],
556+
"transpose-reducemean-test",
557+
[helper.make_tensor_value_info("X", TensorProto.FLOAT, (1, 3, 4, 5))],
558+
[helper.make_tensor_value_info("res", TensorProto.FLOAT, (1, 3, 1, 1))],
559+
)
560+
561+
model_proto = helper.make_model(graph, producer_name="onnx-tests")
562+
self.run_transpose_compare(["res"], {"X": np.random.randn(1, 3, 4, 5).astype(np.float32)},
563+
model_proto, remaining_transpose_num=0)
564+
533565
def test_trans_output_as_graph_outputs(self):
534566
"""
535567
If transpose's output is graph's output, don't optimize it.

0 commit comments

Comments
 (0)