Skip to content

Commit 7123f0c

Browse files
committed
add pool2d test for global_pooling true
1 parent de6da10 commit 7123f0c

File tree

1 file changed

+35
-0
lines changed

1 file changed

+35
-0
lines changed

paddle/fluid/inference/tensorrt/convert/test_pool2d_op.cc

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,41 @@ TEST(Pool2dOpConverter, main) {
5555
validator.Execute(3);
5656
}
5757

58+
TEST(Pool2dOpConverter, test_global_pooling) {
59+
framework::Scope scope;
60+
std::unordered_set<std::string> parameters;
61+
TRTConvertValidation validator(5, parameters, scope, 1 << 15);
62+
63+
// The ITensor's Dims should not contain the batch size.
64+
// So, the ITensor's Dims of input and output should be C * H * W.
65+
validator.DeclInputVar("pool2d-X", nvinfer1::Dims3(3, 4, 4));
66+
validator.DeclOutputVar("pool2d-Out", nvinfer1::Dims3(3, 1, 1));
67+
68+
// Prepare Op description
69+
framework::OpDesc desc;
70+
desc.SetType("pool2d");
71+
desc.SetInput("X", {"pool2d-X"});
72+
desc.SetOutput("Out", {"pool2d-Out"});
73+
74+
std::vector<int> ksize({2, 2});
75+
std::vector<int> strides({2, 2});
76+
std::vector<int> paddings({0, 0});
77+
std::string pooling_t = "max";
78+
bool global_pooling = true;
79+
80+
desc.SetAttr("pooling_type", pooling_t);
81+
desc.SetAttr("ksize", ksize);
82+
desc.SetAttr("strides", strides);
83+
desc.SetAttr("paddings", paddings);
84+
desc.SetAttr("global_pooling", global_pooling);
85+
86+
LOG(INFO) << "set OP";
87+
validator.SetOp(*desc.Proto());
88+
LOG(INFO) << "execute";
89+
90+
validator.Execute(3);
91+
}
92+
5893
} // namespace tensorrt
5994
} // namespace inference
6095
} // namespace paddle

0 commit comments

Comments
 (0)