@@ -20,15 +20,18 @@ namespace paddle {
20
20
namespace inference {
21
21
namespace tensorrt {
22
22
23
- TEST (Pool2dOpConverter, main ) {
23
+ void test_pool2d ( bool global_pooling ) {
24
24
framework::Scope scope;
25
25
std::unordered_set<std::string> parameters;
26
26
TRTConvertValidation validator (5 , parameters, scope, 1 << 15 );
27
27
28
28
// The ITensor's Dims should not contain the batch size.
29
29
// So, the ITensor's Dims of input and output should be C * H * W.
30
30
validator.DeclInputVar (" pool2d-X" , nvinfer1::Dims3 (3 , 4 , 4 ));
31
- validator.DeclOutputVar (" pool2d-Out" , nvinfer1::Dims3 (3 , 2 , 2 ));
31
+ if (global_pooling)
32
+ validator.DeclOutputVar (" pool2d-Out" , nvinfer1::Dims3 (3 , 1 , 1 ));
33
+ else
34
+ validator.DeclOutputVar (" pool2d-Out" , nvinfer1::Dims3 (3 , 2 , 2 ));
32
35
33
36
// Prepare Op description
34
37
framework::OpDesc desc;
@@ -40,7 +43,6 @@ TEST(Pool2dOpConverter, main) {
40
43
std::vector<int > strides ({2 , 2 });
41
44
std::vector<int > paddings ({0 , 0 });
42
45
std::string pooling_t = " max" ;
43
- bool global_pooling = false ;
44
46
45
47
desc.SetAttr (" pooling_type" , pooling_t );
46
48
desc.SetAttr (" ksize" , ksize);
@@ -55,40 +57,9 @@ TEST(Pool2dOpConverter, main) {
55
57
validator.Execute (3 );
56
58
}
57
59
58
- TEST (Pool2dOpConverter, test_global_pooling) {
59
- framework::Scope scope;
60
- std::unordered_set<std::string> parameters;
61
- TRTConvertValidation validator (5 , parameters, scope, 1 << 15 );
62
-
63
- // The ITensor's Dims should not contain the batch size.
64
- // So, the ITensor's Dims of input and output should be C * H * W.
65
- validator.DeclInputVar (" pool2d-X" , nvinfer1::Dims3 (3 , 4 , 4 ));
66
- validator.DeclOutputVar (" pool2d-Out" , nvinfer1::Dims3 (3 , 1 , 1 ));
67
-
68
- // Prepare Op description
69
- framework::OpDesc desc;
70
- desc.SetType (" pool2d" );
71
- desc.SetInput (" X" , {" pool2d-X" });
72
- desc.SetOutput (" Out" , {" pool2d-Out" });
73
-
74
- std::vector<int > ksize ({2 , 2 });
75
- std::vector<int > strides ({2 , 2 });
76
- std::vector<int > paddings ({0 , 0 });
77
- std::string pooling_t = " max" ;
78
- bool global_pooling = true ;
60
+ TEST (Pool2dOpConverter, normal) { test_pool2d (false ); }
79
61
80
- desc.SetAttr (" pooling_type" , pooling_t );
81
- desc.SetAttr (" ksize" , ksize);
82
- desc.SetAttr (" strides" , strides);
83
- desc.SetAttr (" paddings" , paddings);
84
- desc.SetAttr (" global_pooling" , global_pooling);
85
-
86
- LOG (INFO) << " set OP" ;
87
- validator.SetOp (*desc.Proto ());
88
- LOG (INFO) << " execute" ;
89
-
90
- validator.Execute (3 );
91
- }
62
+ TEST (Pool2dOpConverter, test_global_pooling) { test_pool2d (true ); }
92
63
93
64
} // namespace tensorrt
94
65
} // namespace inference
0 commit comments