Skip to content

Commit 11f5d0c

Browse files
NHZlXPaddle CI
authored andcommitted
Merge pull request #12761 from NHZlX:global_pooling_trt
1 parent 7123f0c commit 11f5d0c

File tree

1 file changed

+7
-36
lines changed

1 file changed

+7
-36
lines changed

paddle/fluid/inference/tensorrt/convert/test_pool2d_op.cc

Lines changed: 7 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,18 @@ namespace paddle {
2020
namespace inference {
2121
namespace tensorrt {
2222

23-
TEST(Pool2dOpConverter, main) {
23+
void test_pool2d(bool global_pooling) {
2424
framework::Scope scope;
2525
std::unordered_set<std::string> parameters;
2626
TRTConvertValidation validator(5, parameters, scope, 1 << 15);
2727

2828
// The ITensor's Dims should not contain the batch size.
2929
// So, the ITensor's Dims of input and output should be C * H * W.
3030
validator.DeclInputVar("pool2d-X", nvinfer1::Dims3(3, 4, 4));
31-
validator.DeclOutputVar("pool2d-Out", nvinfer1::Dims3(3, 2, 2));
31+
if (global_pooling)
32+
validator.DeclOutputVar("pool2d-Out", nvinfer1::Dims3(3, 1, 1));
33+
else
34+
validator.DeclOutputVar("pool2d-Out", nvinfer1::Dims3(3, 2, 2));
3235

3336
// Prepare Op description
3437
framework::OpDesc desc;
@@ -40,7 +43,6 @@ TEST(Pool2dOpConverter, main) {
4043
std::vector<int> strides({2, 2});
4144
std::vector<int> paddings({0, 0});
4245
std::string pooling_t = "max";
43-
bool global_pooling = false;
4446

4547
desc.SetAttr("pooling_type", pooling_t);
4648
desc.SetAttr("ksize", ksize);
@@ -55,40 +57,9 @@ TEST(Pool2dOpConverter, main) {
5557
validator.Execute(3);
5658
}
5759

58-
TEST(Pool2dOpConverter, test_global_pooling) {
59-
framework::Scope scope;
60-
std::unordered_set<std::string> parameters;
61-
TRTConvertValidation validator(5, parameters, scope, 1 << 15);
62-
63-
// The ITensor's Dims should not contain the batch size.
64-
// So, the ITensor's Dims of input and output should be C * H * W.
65-
validator.DeclInputVar("pool2d-X", nvinfer1::Dims3(3, 4, 4));
66-
validator.DeclOutputVar("pool2d-Out", nvinfer1::Dims3(3, 1, 1));
67-
68-
// Prepare Op description
69-
framework::OpDesc desc;
70-
desc.SetType("pool2d");
71-
desc.SetInput("X", {"pool2d-X"});
72-
desc.SetOutput("Out", {"pool2d-Out"});
73-
74-
std::vector<int> ksize({2, 2});
75-
std::vector<int> strides({2, 2});
76-
std::vector<int> paddings({0, 0});
77-
std::string pooling_t = "max";
78-
bool global_pooling = true;
60+
TEST(Pool2dOpConverter, normal) { test_pool2d(false); }
7961

80-
desc.SetAttr("pooling_type", pooling_t);
81-
desc.SetAttr("ksize", ksize);
82-
desc.SetAttr("strides", strides);
83-
desc.SetAttr("paddings", paddings);
84-
desc.SetAttr("global_pooling", global_pooling);
85-
86-
LOG(INFO) << "set OP";
87-
validator.SetOp(*desc.Proto());
88-
LOG(INFO) << "execute";
89-
90-
validator.Execute(3);
91-
}
62+
TEST(Pool2dOpConverter, test_global_pooling) { test_pool2d(true); }
9263

9364
} // namespace tensorrt
9465
} // namespace inference

0 commit comments

Comments
 (0)