Skip to content

Commit 2256fae

Browse files
authored
Merge pull request #13938 from NHZlX/ocr_attention_support
ceil pool mode support for ocr attention model.
2 parents e906c8e + 485ab5b commit 2256fae

File tree

2 files changed

+50
-10
lines changed

2 files changed

+50
-10
lines changed

paddle/fluid/inference/tensorrt/convert/pool2d_op.cc

Lines changed: 40 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,16 +42,22 @@ class Pool2dOpConverter : public OpConverter {
4242
boost::get<std::vector<int>>(op_desc.GetAttr("strides"));
4343
std::vector<int> paddings =
4444
boost::get<std::vector<int>>(op_desc.GetAttr("paddings"));
45+
bool ceil_mode = boost::get<bool>(op_desc.GetAttr("ceil_mode"));
4546

47+
nvinfer1::Dims input_shape = input1->getDimensions();
48+
int nbDims = input_shape.nbDims;
4649
nvinfer1::DimsHW nv_ksize(ksize[0], ksize[1]);
50+
nvinfer1::DimsHW nv_strides(strides[0], strides[1]);
51+
nvinfer1::DimsHW nv_paddings(paddings[0], paddings[1]);
52+
4753
if (global_pooling == true) {
48-
nvinfer1::Dims input_shape = input1->getDimensions();
49-
int nbDims = input_shape.nbDims;
5054
nv_ksize.d[0] = input_shape.d[nbDims - 2];
5155
nv_ksize.d[1] = input_shape.d[nbDims - 1];
56+
nv_strides.h() = 1;
57+
nv_strides.w() = 1;
58+
nv_paddings.h() = 0;
59+
nv_paddings.w() = 0;
5260
}
53-
const nvinfer1::DimsHW nv_strides(strides[0], strides[1]);
54-
const nvinfer1::DimsHW nv_paddings(paddings[0], paddings[1]);
5561

5662
PADDLE_ENFORCE_EQ(input1->getDimensions().nbDims, 3UL);
5763

@@ -64,6 +70,36 @@ class Pool2dOpConverter : public OpConverter {
6470
PADDLE_THROW("TensorRT unsupported pooling type!");
6571
}
6672

73+
if (ceil_mode) {
74+
nvinfer1::DimsHW pre_pad(0, 0);
75+
nvinfer1::DimsHW post_pad(0, 0);
76+
int input_height = input_shape.d[nbDims - 2];
77+
int input_width = input_shape.d[nbDims - 1];
78+
int floor_h_output_size =
79+
(input_height - ksize[0] + 2 * paddings[0]) / strides[0] + 1;
80+
int ceil_h_output_size =
81+
(input_height - ksize[0] + 2 * paddings[0] + strides[0] - 1) /
82+
strides[0] +
83+
1;
84+
85+
int floor_w_output_size =
86+
(input_width - ksize[1] + 2 * paddings[1]) / strides[1] + 1;
87+
int ceil_w_output_size =
88+
(input_width - ksize[1] + 2 * paddings[1] + strides[1] - 1) /
89+
strides[1] +
90+
1;
91+
if (floor_h_output_size != ceil_h_output_size) {
92+
post_pad.h() = strides[0] - 1;
93+
}
94+
95+
if (floor_w_output_size != ceil_w_output_size) {
96+
post_pad.w() = strides[1] - 1;
97+
}
98+
auto* layer = TRT_ENGINE_ADD_LAYER(
99+
engine_, Padding, *const_cast<nvinfer1::ITensor*>(input1), pre_pad,
100+
post_pad);
101+
input1 = layer->getOutput(0);
102+
}
67103
auto* layer = TRT_ENGINE_ADD_LAYER(engine_, Pooling,
68104
*const_cast<nvinfer1::ITensor*>(input1),
69105
nv_pool_type, nv_ksize);

paddle/fluid/inference/tensorrt/convert/test_pool2d_op.cc

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,26 +20,28 @@ namespace paddle {
2020
namespace inference {
2121
namespace tensorrt {
2222

23-
void test_pool2d(bool global_pooling) {
23+
void test_pool2d(bool global_pooling, bool ceil_mode) {
2424
framework::Scope scope;
2525
std::unordered_set<std::string> parameters;
2626
TRTConvertValidation validator(5, parameters, scope, 1 << 15);
2727

2828
// The ITensor's Dims should not contain the batch size.
2929
// So, the ITensor's Dims of input and output should be C * H * W.
30-
validator.DeclInputVar("pool2d-X", nvinfer1::Dims3(3, 4, 4));
30+
validator.DeclInputVar("pool2d-X", nvinfer1::Dims3(3, 13, 14));
3131
if (global_pooling)
3232
validator.DeclOutputVar("pool2d-Out", nvinfer1::Dims3(3, 1, 1));
33+
else if (ceil_mode)
34+
validator.DeclOutputVar("pool2d-Out", nvinfer1::Dims3(3, 6, 7));
3335
else
34-
validator.DeclOutputVar("pool2d-Out", nvinfer1::Dims3(3, 2, 2));
36+
validator.DeclOutputVar("pool2d-Out", nvinfer1::Dims3(3, 6, 6));
3537

3638
// Prepare Op description
3739
framework::OpDesc desc;
3840
desc.SetType("pool2d");
3941
desc.SetInput("X", {"pool2d-X"});
4042
desc.SetOutput("Out", {"pool2d-Out"});
4143

42-
std::vector<int> ksize({2, 2});
44+
std::vector<int> ksize({3, 3});
4345
std::vector<int> strides({2, 2});
4446
std::vector<int> paddings({0, 0});
4547
std::string pooling_t = "max";
@@ -49,6 +51,7 @@ void test_pool2d(bool global_pooling) {
4951
desc.SetAttr("strides", strides);
5052
desc.SetAttr("paddings", paddings);
5153
desc.SetAttr("global_pooling", global_pooling);
54+
desc.SetAttr("ceil_mode", ceil_mode);
5255

5356
LOG(INFO) << "set OP";
5457
validator.SetOp(*desc.Proto());
@@ -57,9 +60,10 @@ void test_pool2d(bool global_pooling) {
5760
validator.Execute(3);
5861
}
5962

60-
TEST(Pool2dOpConverter, normal) { test_pool2d(false); }
63+
TEST(Pool2dOpConverter, normal) { test_pool2d(false, false); }
64+
TEST(Pool2dOpConverter, test_global_pooling) { test_pool2d(true, false); }
6165

62-
TEST(Pool2dOpConverter, test_global_pooling) { test_pool2d(true); }
66+
TEST(Pool2dOpConverter, test_ceil_mode) { test_pool2d(false, true); }
6367

6468
} // namespace tensorrt
6569
} // namespace inference

0 commit comments

Comments
 (0)