@@ -55,6 +55,41 @@ TEST(Pool2dOpConverter, main) {
55
55
validator.Execute (3 );
56
56
}
57
57
58
+ TEST (Pool2dOpConverter, test_global_pooling) {
59
+ framework::Scope scope;
60
+ std::unordered_set<std::string> parameters;
61
+ TRTConvertValidation validator (5 , parameters, scope, 1 << 15 );
62
+
63
+ // The ITensor's Dims should not contain the batch size.
64
+ // So, the ITensor's Dims of input and output should be C * H * W.
65
+ validator.DeclInputVar (" pool2d-X" , nvinfer1::Dims3 (3 , 4 , 4 ));
66
+ validator.DeclOutputVar (" pool2d-Out" , nvinfer1::Dims3 (3 , 1 , 1 ));
67
+
68
+ // Prepare Op description
69
+ framework::OpDesc desc;
70
+ desc.SetType (" pool2d" );
71
+ desc.SetInput (" X" , {" pool2d-X" });
72
+ desc.SetOutput (" Out" , {" pool2d-Out" });
73
+
74
+ std::vector<int > ksize ({2 , 2 });
75
+ std::vector<int > strides ({2 , 2 });
76
+ std::vector<int > paddings ({0 , 0 });
77
+ std::string pooling_t = " max" ;
78
+ bool global_pooling = true ;
79
+
80
+ desc.SetAttr (" pooling_type" , pooling_t );
81
+ desc.SetAttr (" ksize" , ksize);
82
+ desc.SetAttr (" strides" , strides);
83
+ desc.SetAttr (" paddings" , paddings);
84
+ desc.SetAttr (" global_pooling" , global_pooling);
85
+
86
+ LOG (INFO) << " set OP" ;
87
+ validator.SetOp (*desc.Proto ());
88
+ LOG (INFO) << " execute" ;
89
+
90
+ validator.Execute (3 );
91
+ }
92
+
58
93
} // namespace tensorrt
59
94
} // namespace inference
60
95
} // namespace paddle
0 commit comments