@@ -46,6 +46,25 @@ static std::string gethash(const memory::dims& input_dims,
46
46
dims2str (paddings) + pooling_type + suffix;
47
47
}
48
48
49
+ static inline int ComputeCeiledOutput (int input_size, int kernel_size,
50
+ int padding, int stride) {
51
+ return (input_size - kernel_size + 2 * padding) / stride + 1 ;
52
+ }
53
+
54
+ static inline void CorrectOutputSize (
55
+ const std::vector<int >& src_tz, const std::vector<int >& dst_tz,
56
+ const std::vector<int >& kernel_size, const std::vector<int >& paddings,
57
+ const std::vector<int >& strides,
58
+ std::vector<int >& right_bot_padding) { // NOLINT
59
+ for (size_t i = 0 ; i < right_bot_padding.size (); i++) {
60
+ int desired_size = ComputeCeiledOutput (src_tz[i + 2 ], kernel_size[i],
61
+ paddings[i], strides[i]);
62
+ if (desired_size != dst_tz[i + 2 ]) {
63
+ right_bot_padding[i] += strides[i];
64
+ }
65
+ }
66
+ }
67
+
49
68
template <typename T>
50
69
class PoolMKLDNNOpKernel : public paddle ::framework::OpKernel<T> {
51
70
public:
@@ -103,6 +122,13 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
103
122
auto pool_p =
104
123
std::static_pointer_cast<pooling_forward>(dev_ctx.GetBlob (key_pool_p));
105
124
if (pool_p == nullptr ) {
125
+ const std::vector<int >& padding_left_top (paddings);
126
+ std::vector<int > padding_right_bottom (paddings);
127
+ bool ceil_mode = ctx.Attr <bool >(" ceil_mode" );
128
+ if (ceil_mode) {
129
+ CorrectOutputSize (src_tz, dst_tz, ksize, paddings, strides,
130
+ padding_right_bottom);
131
+ }
106
132
auto src_md = platform::MKLDNNMemDesc (
107
133
src_tz, platform::MKLDNNGetDataType<T>(), input_format);
108
134
@@ -114,8 +140,9 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
114
140
mkldnn::memory::format::any);
115
141
116
142
std::shared_ptr<mkldnn::pooling_forward::primitive_desc> pool_pd =
117
- CreatePrimitiveDesc (src_md, dst_md, strides, paddings, ksize,
118
- pooling_type, mkldnn_engine);
143
+ CreatePrimitiveDesc (src_md, dst_md, strides, padding_left_top,
144
+ padding_right_bottom, ksize, pooling_type,
145
+ mkldnn_engine, ceil_mode);
119
146
120
147
// save pool_pd into global device context to be referred in backward path
121
148
dev_ctx.SetBlob (key_pool_pd, pool_pd);
@@ -171,14 +198,16 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
171
198
private:
172
199
std::unique_ptr<mkldnn::pooling_forward::primitive_desc> CreatePrimitiveDesc (
173
200
const mkldnn::memory::desc& src, const mkldnn::memory::desc& dst,
174
- const std::vector<int >& stride, const std::vector<int >& padding,
175
- const std::vector<int >& kernel, const std::string& pooling_type,
176
- const mkldnn::engine& engine) const {
201
+ const std::vector<int >& stride, const std::vector<int >& padding_left_top,
202
+ const std::vector<int >& padding_right_bot, const std::vector<int >& kernel,
203
+ const std::string& pooling_type, const mkldnn::engine& engine,
204
+ bool ceil_mode) const {
177
205
auto pool_desc = mkldnn::pooling_forward::desc (
178
206
mkldnn::prop_kind::forward,
179
207
pooling_type == " max" ? mkldnn::algorithm::pooling_max
180
208
: mkldnn::algorithm::pooling_avg,
181
- src, dst, stride, kernel, padding, padding, mkldnn::padding_kind::zero);
209
+ src, dst, stride, kernel, padding_left_top, padding_right_bot,
210
+ mkldnn::padding_kind::zero);
182
211
183
212
auto p_pool_pd =
184
213
new mkldnn::pooling_forward::primitive_desc (pool_desc, engine);
0 commit comments