@@ -20,26 +20,28 @@ namespace paddle {
20
20
namespace inference {
21
21
namespace tensorrt {
22
22
23
- void test_pool2d (bool global_pooling) {
23
+ void test_pool2d (bool global_pooling, bool ceil_mode ) {
24
24
framework::Scope scope;
25
25
std::unordered_set<std::string> parameters;
26
26
TRTConvertValidation validator (5 , parameters, scope, 1 << 15 );
27
27
28
28
// The ITensor's Dims should not contain the batch size.
29
29
// So, the ITensor's Dims of input and output should be C * H * W.
30
- validator.DeclInputVar (" pool2d-X" , nvinfer1::Dims3 (3 , 4 , 4 ));
30
+ validator.DeclInputVar (" pool2d-X" , nvinfer1::Dims3 (3 , 13 , 14 ));
31
31
if (global_pooling)
32
32
validator.DeclOutputVar (" pool2d-Out" , nvinfer1::Dims3 (3 , 1 , 1 ));
33
+ else if (ceil_mode)
34
+ validator.DeclOutputVar (" pool2d-Out" , nvinfer1::Dims3 (3 , 6 , 7 ));
33
35
else
34
- validator.DeclOutputVar (" pool2d-Out" , nvinfer1::Dims3 (3 , 2 , 2 ));
36
+ validator.DeclOutputVar (" pool2d-Out" , nvinfer1::Dims3 (3 , 6 , 6 ));
35
37
36
38
// Prepare Op description
37
39
framework::OpDesc desc;
38
40
desc.SetType (" pool2d" );
39
41
desc.SetInput (" X" , {" pool2d-X" });
40
42
desc.SetOutput (" Out" , {" pool2d-Out" });
41
43
42
- std::vector<int > ksize ({2 , 2 });
44
+ std::vector<int > ksize ({3 , 3 });
43
45
std::vector<int > strides ({2 , 2 });
44
46
std::vector<int > paddings ({0 , 0 });
45
47
std::string pooling_t = " max" ;
@@ -49,6 +51,7 @@ void test_pool2d(bool global_pooling) {
49
51
desc.SetAttr (" strides" , strides);
50
52
desc.SetAttr (" paddings" , paddings);
51
53
desc.SetAttr (" global_pooling" , global_pooling);
54
+ desc.SetAttr (" ceil_mode" , ceil_mode);
52
55
53
56
LOG (INFO) << " set OP" ;
54
57
validator.SetOp (*desc.Proto ());
@@ -57,9 +60,10 @@ void test_pool2d(bool global_pooling) {
57
60
validator.Execute (3 );
58
61
}
59
62
60
- TEST (Pool2dOpConverter, normal) { test_pool2d (false ); }
63
+ TEST (Pool2dOpConverter, normal) { test_pool2d (false , false ); }
64
+ TEST (Pool2dOpConverter, test_global_pooling) { test_pool2d (true , false ); }
61
65
62
- TEST (Pool2dOpConverter, test_global_pooling ) { test_pool2d (true ); }
66
+ TEST (Pool2dOpConverter, test_ceil_mode ) { test_pool2d (false , true ); }
63
67
64
68
} // namespace tensorrt
65
69
} // namespace inference
0 commit comments