Skip to content

Commit c07b2a9

Browse files
authored
Merge pull request #13521 from Sand3r-/mgallus/fix-pooling-ceiled-size
Enable MKL-DNN in Analysis Predictor
2 parents d000008 + 0e6b303 commit c07b2a9

File tree

2 files changed

+36
-7
lines changed

2 files changed

+36
-7
lines changed

paddle/fluid/inference/api/analysis_predictor.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,10 @@ bool AnalysisPredictor::Init(
7676
}
7777

7878
OptimizeInferenceProgram();
79-
ctx_ = executor_->Prepare(*inference_program_, 0);
8079
if (config_._use_mkldnn) {
8180
executor_->EnableMKLDNN(*inference_program_);
8281
}
82+
ctx_ = executor_->Prepare(*inference_program_, 0);
8383

8484
VLOG(5) << "to create variables";
8585
PADDLE_ENFORCE(scope_.get());

paddle/fluid/operators/pool_mkldnn_op.cc

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,25 @@ static std::string gethash(const memory::dims& input_dims,
4646
dims2str(paddings) + pooling_type + suffix;
4747
}
4848

49+
static inline int ComputeCeiledOutput(int input_size, int kernel_size,
50+
int padding, int stride) {
51+
return (input_size - kernel_size + 2 * padding) / stride + 1;
52+
}
53+
54+
static inline void CorrectOutputSize(
55+
const std::vector<int>& src_tz, const std::vector<int>& dst_tz,
56+
const std::vector<int>& kernel_size, const std::vector<int>& paddings,
57+
const std::vector<int>& strides,
58+
std::vector<int>& right_bot_padding) { // NOLINT
59+
for (size_t i = 0; i < right_bot_padding.size(); i++) {
60+
int desired_size = ComputeCeiledOutput(src_tz[i + 2], kernel_size[i],
61+
paddings[i], strides[i]);
62+
if (desired_size != dst_tz[i + 2]) {
63+
right_bot_padding[i] += strides[i];
64+
}
65+
}
66+
}
67+
4968
template <typename T>
5069
class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
5170
public:
@@ -103,6 +122,13 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
103122
auto pool_p =
104123
std::static_pointer_cast<pooling_forward>(dev_ctx.GetBlob(key_pool_p));
105124
if (pool_p == nullptr) {
125+
const std::vector<int>& padding_left_top(paddings);
126+
std::vector<int> padding_right_bottom(paddings);
127+
bool ceil_mode = ctx.Attr<bool>("ceil_mode");
128+
if (ceil_mode) {
129+
CorrectOutputSize(src_tz, dst_tz, ksize, paddings, strides,
130+
padding_right_bottom);
131+
}
106132
auto src_md = platform::MKLDNNMemDesc(
107133
src_tz, platform::MKLDNNGetDataType<T>(), input_format);
108134

@@ -114,8 +140,9 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
114140
mkldnn::memory::format::any);
115141

116142
std::shared_ptr<mkldnn::pooling_forward::primitive_desc> pool_pd =
117-
CreatePrimitiveDesc(src_md, dst_md, strides, paddings, ksize,
118-
pooling_type, mkldnn_engine);
143+
CreatePrimitiveDesc(src_md, dst_md, strides, padding_left_top,
144+
padding_right_bottom, ksize, pooling_type,
145+
mkldnn_engine, ceil_mode);
119146

120147
// save pool_pd into global device context to be referred in backward path
121148
dev_ctx.SetBlob(key_pool_pd, pool_pd);
@@ -171,14 +198,16 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
171198
private:
172199
std::unique_ptr<mkldnn::pooling_forward::primitive_desc> CreatePrimitiveDesc(
173200
const mkldnn::memory::desc& src, const mkldnn::memory::desc& dst,
174-
const std::vector<int>& stride, const std::vector<int>& padding,
175-
const std::vector<int>& kernel, const std::string& pooling_type,
176-
const mkldnn::engine& engine) const {
201+
const std::vector<int>& stride, const std::vector<int>& padding_left_top,
202+
const std::vector<int>& padding_right_bot, const std::vector<int>& kernel,
203+
const std::string& pooling_type, const mkldnn::engine& engine,
204+
bool ceil_mode) const {
177205
auto pool_desc = mkldnn::pooling_forward::desc(
178206
mkldnn::prop_kind::forward,
179207
pooling_type == "max" ? mkldnn::algorithm::pooling_max
180208
: mkldnn::algorithm::pooling_avg,
181-
src, dst, stride, kernel, padding, padding, mkldnn::padding_kind::zero);
209+
src, dst, stride, kernel, padding_left_top, padding_right_bot,
210+
mkldnn::padding_kind::zero);
182211

183212
auto p_pool_pd =
184213
new mkldnn::pooling_forward::primitive_desc(pool_desc, engine);

0 commit comments

Comments
 (0)