Skip to content

Commit de6da10

Browse files
committed
add support for global pooling for trt
1 parent 115a6e6 commit de6da10

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

paddle/fluid/inference/tensorrt/convert/pool2d_op.cc

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ class Pool2dOpConverter : public OpConverter {
3333
PADDLE_ENFORCE_EQ(op_desc.Output("Out").size(), 1);
3434
auto* input1 = engine_->GetITensor(op_desc.Input("X")[0]);
3535

36+
bool global_pooling = boost::get<bool>(op_desc.GetAttr("global_pooling"));
3637
std::string pool_type =
3738
boost::get<std::string>(op_desc.GetAttr("pooling_type"));
3839
std::vector<int> ksize =
@@ -42,7 +43,13 @@ class Pool2dOpConverter : public OpConverter {
4243
std::vector<int> paddings =
4344
boost::get<std::vector<int>>(op_desc.GetAttr("paddings"));
4445

45-
const nvinfer1::DimsHW nv_ksize(ksize[0], ksize[1]);
46+
nvinfer1::DimsHW nv_ksize(ksize[0], ksize[1]);
47+
if (global_pooling == true) {
48+
nvinfer1::Dims input_shape = input1->getDimensions();
49+
int nbDims = input_shape.nbDims;
50+
nv_ksize.d[0] = input_shape.d[nbDims - 2];
51+
nv_ksize.d[1] = input_shape.d[nbDims - 1];
52+
}
4653
const nvinfer1::DimsHW nv_strides(strides[0], strides[1]);
4754
const nvinfer1::DimsHW nv_paddings(paddings[0], paddings[1]);
4855

paddle/fluid/inference/tensorrt/convert/test_pool2d_op.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,13 @@ TEST(Pool2dOpConverter, main) {
4040
std::vector<int> strides({2, 2});
4141
std::vector<int> paddings({0, 0});
4242
std::string pooling_t = "max";
43+
bool global_pooling = false;
4344

4445
desc.SetAttr("pooling_type", pooling_t);
4546
desc.SetAttr("ksize", ksize);
4647
desc.SetAttr("strides", strides);
4748
desc.SetAttr("paddings", paddings);
49+
desc.SetAttr("global_pooling", global_pooling);
4850

4951
LOG(INFO) << "set OP";
5052
validator.SetOp(*desc.Proto());

0 commit comments

Comments
 (0)