Skip to content

Commit 4730a4b

Browse files
pzelazko-intel authored and luotao1 committed
MKLDNN pool2d OP kernel added (#8879)
* MKLDNN pool2d OP kernel added
* conv2d and pool2d MKLDNN kernels renamed
* MKLDNN conv2d kernel refactoring
1 parent ccc5418 commit 4730a4b

File tree

6 files changed

+417
-186
lines changed

6 files changed

+417
-186
lines changed

paddle/fluid/operators/conv_mkldnn_op.cc

Lines changed: 126 additions & 163 deletions
Original file line number | Diff line number | Diff line change
@@ -12,58 +12,21 @@
1212
See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

15-
#include "mkldnn.hpp"
16-
#include "paddle/fluid/framework/tensor.h"
1715
#include "paddle/fluid/operators/conv_op.h"
1816
#include "paddle/fluid/platform/mkldnn_helper.h"
1917

2018
namespace paddle {
2119
namespace operators {
2220

23-
using paddle::framework::Tensor;
24-
using paddle::platform::MKLDNNDeviceContext;
25-
using paddle::platform::MKLDNNMemDesc;
26-
27-
using mkldnn::memory; // Note: paddle has also "memory" namespace
28-
using mkldnn::primitive;
29-
using mkldnn::convolution_forward;
30-
using mkldnn::convolution_backward_weights;
31-
using mkldnn::convolution_backward_data;
32-
using mkldnn::convolution_direct;
33-
using mkldnn::prop_kind;
34-
using mkldnn::padding_kind;
35-
using mkldnn::stream;
36-
37-
namespace {
38-
std::unique_ptr<mkldnn::convolution_forward::primitive_desc>
39-
ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights,
40-
const memory::desc& dst, const std::vector<int>& strides,
41-
const std::vector<int>& paddings,
42-
const mkldnn::engine& engine);
43-
44-
convolution_backward_weights::primitive_desc ConvBwdWeightsPrimitiveDesc(
45-
const memory::desc& src, const memory::desc& diff_weights,
46-
const memory::desc& diff_dst, const std::vector<int>& strides,
47-
const std::vector<int>& paddings,
48-
const convolution_forward::primitive_desc& conv_pd,
49-
const mkldnn::engine& engine);
50-
51-
convolution_backward_data::primitive_desc ConvBwdDataPrimitiveDesc(
52-
const memory::desc& diff_src, const memory::desc& weights,
53-
const memory::desc& diff_dst, const std::vector<int>& strides,
54-
const std::vector<int>& paddings,
55-
const convolution_forward::primitive_desc& conv_pd,
56-
const mkldnn::engine& engine);
57-
} // anonymous namespace
58-
5921
template <typename T>
60-
class ConvOpMkldnnKernel : public paddle::framework::OpKernel<T> {
22+
class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
6123
public:
6224
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
6325
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
6426
"It must use CPUPlace.");
6527

66-
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
28+
auto& dev_ctx =
29+
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
6730
const auto& mkldnn_engine = dev_ctx.GetEngine();
6831

6932
auto* input = ctx.Input<Tensor>("Input");
@@ -88,7 +51,6 @@ class ConvOpMkldnnKernel : public paddle::framework::OpKernel<T> {
8851

8952
const T* input_data = input->data<T>();
9053
const T* filter_data = filter->data<T>();
91-
// allocate memory for output
9254
T* output_data = output->mutable_data<T>(ctx.GetPlace());
9355

9456
PADDLE_ENFORCE(input->dims().size() == 4,
@@ -102,48 +64,69 @@ class ConvOpMkldnnKernel : public paddle::framework::OpKernel<T> {
10264
std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
10365

10466
// TODO(pzelazko-intel): support more formats
105-
// memory descriptors for convolution src/weight/dst
106-
auto conv_src_md =
107-
MKLDNNMemDesc(src_tz, memory::data_type::f32, memory::format::nchw);
108-
auto conv_weights_md =
109-
MKLDNNMemDesc(weights_tz, memory::data_type::f32, memory::format::oihw);
110-
auto conv_dst_md =
111-
MKLDNNMemDesc(dst_tz, memory::data_type::f32, memory::format::nchw);
112-
113-
// create memory primitives
114-
auto conv_src_memory =
115-
memory({conv_src_md, mkldnn_engine}, (void*)input_data);
116-
auto conv_weights_memory =
117-
memory({conv_weights_md, mkldnn_engine}, (void*)filter_data);
118-
auto conv_dst_memory = memory({conv_dst_md, mkldnn_engine}, output_data);
119-
120-
std::unique_ptr<convolution_forward::primitive_desc> conv_pd =
121-
ConvFwdPrimitiveDesc(conv_src_md, conv_weights_md, conv_dst_md, strides,
122-
paddings, mkldnn_engine);
123-
124-
// save p_conv_pd into dev_ctx to be referred in backward path
125-
auto p_conv_pd = conv_pd.get();
126-
std::shared_ptr<void> conv_pd_value = std::move(conv_pd);
127-
dev_ctx.SetBlob(key_conv_pd, conv_pd_value);
67+
auto src_md = platform::MKLDNNMemDesc(
68+
src_tz, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);
69+
auto weights_md =
70+
platform::MKLDNNMemDesc(weights_tz, mkldnn::memory::data_type::f32,
71+
mkldnn::memory::format::oihw);
72+
auto dst_md = platform::MKLDNNMemDesc(
73+
dst_tz, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);
74+
75+
auto src_memory =
76+
mkldnn::memory({src_md, mkldnn_engine}, (void*)input_data);
77+
auto weights_memory =
78+
mkldnn::memory({weights_md, mkldnn_engine}, (void*)filter_data);
79+
auto dst_memory = mkldnn::memory({dst_md, mkldnn_engine}, output_data);
80+
81+
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd =
82+
ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings,
83+
mkldnn_engine);
84+
85+
// save conv_pd into global device context to be referred in backward path
86+
dev_ctx.SetBlob(key_conv_pd, conv_pd);
12887

12988
// create convolution op primitive
130-
auto conv_prim = convolution_forward(*p_conv_pd, conv_src_memory,
131-
conv_weights_memory, conv_dst_memory);
89+
auto conv_prim = mkldnn::convolution_forward(*conv_pd, src_memory,
90+
weights_memory, dst_memory);
91+
92+
// push primitive to stream and wait until it's executed
93+
std::vector<mkldnn::primitive> pipeline{conv_prim};
94+
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
95+
}
13296

133-
// push op to stream and wait MKLDNN until it's executed
134-
std::vector<primitive> pipeline{conv_prim};
135-
stream(stream::kind::eager).submit(pipeline).wait();
97+
private:
98+
std::unique_ptr<mkldnn::convolution_forward::primitive_desc>
99+
ConvFwdPrimitiveDesc(const mkldnn::memory::desc& src,
100+
const mkldnn::memory::desc& weights,
101+
const mkldnn::memory::desc& dst,
102+
const std::vector<int>& strides,
103+
const std::vector<int>& paddings,
104+
const mkldnn::engine& engine) const {
105+
mkldnn::memory::dims stride_dims = {strides[0], strides[1]};
106+
mkldnn::memory::dims padding_dims = {paddings[0], paddings[1]};
107+
108+
auto conv_desc = mkldnn::convolution_forward::desc(
109+
mkldnn::prop_kind::forward, mkldnn::convolution_direct, src, weights,
110+
dst, stride_dims, padding_dims, padding_dims,
111+
mkldnn::padding_kind::zero);
112+
113+
auto p_conv_pd =
114+
new mkldnn::convolution_forward::primitive_desc(conv_desc, engine);
115+
116+
return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>(
117+
p_conv_pd);
136118
}
137119
};
138120

139121
template <typename T>
140-
class ConvGradOpMkldnnKernel : public paddle::framework::OpKernel<T> {
122+
class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
141123
public:
142124
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
143125
PADDLE_ENFORCE(paddle::platform::is_cpu_place(ctx.GetPlace()),
144126
"It must use CPUPlace.");
145127

146-
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
128+
auto& dev_ctx =
129+
ctx.template device_context<platform::MKLDNNDeviceContext>();
147130
const auto& mkldnn_engine = dev_ctx.GetEngine();
148131

149132
const Tensor* input = ctx.Input<Tensor>("Input");
@@ -170,7 +153,6 @@ class ConvGradOpMkldnnKernel : public paddle::framework::OpKernel<T> {
170153
T* input_grad_data = nullptr;
171154
T* filter_grad_data = nullptr;
172155

173-
// allocate memory for gradient of input/filter
174156
if (input_grad) {
175157
input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
176158
}
@@ -184,130 +166,111 @@ class ConvGradOpMkldnnKernel : public paddle::framework::OpKernel<T> {
184166
std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
185167

186168
// TODO(pzelazko-intel): support more formats
187-
auto conv_src_md =
188-
MKLDNNMemDesc(src_tz, memory::data_type::f32, memory::format::nchw);
189-
auto conv_diff_src_md =
190-
MKLDNNMemDesc(src_tz, memory::data_type::f32, memory::format::nchw);
191-
auto conv_weights_md =
192-
MKLDNNMemDesc(weights_tz, memory::data_type::f32, memory::format::oihw);
193-
auto conv_diff_weights_md =
194-
MKLDNNMemDesc(weights_tz, memory::data_type::f32, memory::format::oihw);
195-
auto conv_diff_dst_md =
196-
MKLDNNMemDesc(dst_tz, memory::data_type::f32, memory::format::nchw);
169+
auto src_md = platform::MKLDNNMemDesc(
170+
src_tz, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);
171+
auto diff_src_md = platform::MKLDNNMemDesc(
172+
src_tz, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);
173+
auto weights_md =
174+
platform::MKLDNNMemDesc(weights_tz, mkldnn::memory::data_type::f32,
175+
mkldnn::memory::format::oihw);
176+
auto diff_weights_md =
177+
platform::MKLDNNMemDesc(weights_tz, mkldnn::memory::data_type::f32,
178+
mkldnn::memory::format::oihw);
179+
auto diff_dst_md = platform::MKLDNNMemDesc(
180+
dst_tz, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);
197181

198182
// create memory
199-
auto conv_diff_dst_memory =
200-
memory({conv_diff_weights_md, mkldnn_engine}, (void*)output_grad_data);
183+
auto diff_dst_memory = mkldnn::memory({diff_weights_md, mkldnn_engine},
184+
(void*)output_grad_data);
201185
// Retrieve conv_pd from device context
202-
std::shared_ptr<void> conv_pd;
203-
convolution_forward::primitive_desc* p_conv_pd;
204-
205-
conv_pd = dev_ctx.GetBlob(key_conv_pd);
186+
auto conv_pd =
187+
std::static_pointer_cast<mkldnn::convolution_forward::primitive_desc>(
188+
dev_ctx.GetBlob(key_conv_pd));
206189
PADDLE_ENFORCE(conv_pd != nullptr,
207190
"Fail to find conv_pd in device context");
208-
p_conv_pd =
209-
static_cast<convolution_forward::primitive_desc*>(conv_pd.get());
210191

211192
// create backward conv primitive for weights
212193
if (filter_grad) {
213194
// create primitive descriptor
214-
convolution_backward_weights::primitive_desc conv_bwd_weights_pd =
215-
ConvBwdWeightsPrimitiveDesc(conv_src_md, conv_diff_weights_md,
216-
conv_diff_dst_md, strides, paddings,
217-
*p_conv_pd, mkldnn_engine);
195+
mkldnn::convolution_backward_weights::primitive_desc conv_bwd_weights_pd =
196+
ConvBwdWeightsPrimitiveDesc(src_md, diff_weights_md, diff_dst_md,
197+
strides, paddings, *conv_pd,
198+
mkldnn_engine);
218199

219200
// create memory
220-
auto conv_diff_weights_memory = memory(
221-
{conv_diff_weights_md, mkldnn_engine}, (void*)filter_grad_data);
222-
auto conv_src_memory =
223-
memory({conv_src_md, mkldnn_engine}, (void*)input_data);
201+
auto diff_weights_memory = mkldnn::memory(
202+
{diff_weights_md, mkldnn_engine}, (void*)filter_grad_data);
203+
auto src_memory =
204+
mkldnn::memory({src_md, mkldnn_engine}, (void*)input_data);
224205

225206
// create backward conv primitive for weights
226-
auto conv_bwd_weights_prim = convolution_backward_weights(
227-
conv_bwd_weights_pd, conv_src_memory, conv_diff_dst_memory,
228-
conv_diff_weights_memory);
207+
auto conv_bwd_weights_prim = mkldnn::convolution_backward_weights(
208+
conv_bwd_weights_pd, src_memory, diff_dst_memory,
209+
diff_weights_memory);
229210

230211
// push primitive and execute it
231-
std::vector<primitive> pipeline{conv_bwd_weights_prim};
232-
stream(stream::kind::eager).submit(pipeline).wait();
212+
std::vector<mkldnn::primitive> pipeline{conv_bwd_weights_prim};
213+
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
233214
}
234215

235216
if (input_grad) {
236217
// create primitive descriptor
237-
convolution_backward_data::primitive_desc conv_bwd_data_pd =
238-
ConvBwdDataPrimitiveDesc(conv_diff_src_md, conv_weights_md,
239-
conv_diff_dst_md, strides, paddings,
240-
*p_conv_pd, mkldnn_engine);
218+
mkldnn::convolution_backward_data::primitive_desc conv_bwd_data_pd =
219+
ConvBwdDataPrimitiveDesc(diff_src_md, weights_md, diff_dst_md,
220+
strides, paddings, *conv_pd, mkldnn_engine);
241221

242222
// create memory
243-
auto conv_diff_src_memory =
244-
memory({conv_diff_src_md, mkldnn_engine}, (void*)input_grad_data);
245-
auto conv_weights_memory =
246-
memory({conv_weights_md, mkldnn_engine}, (void*)filter_data);
223+
auto diff_src_memory =
224+
mkldnn::memory({diff_src_md, mkldnn_engine}, (void*)input_grad_data);
225+
auto weights_memory =
226+
mkldnn::memory({weights_md, mkldnn_engine}, (void*)filter_data);
247227

248228
// create backward conv primitive for data
249-
auto conv_bwd_data_prim =
250-
convolution_backward_data(conv_bwd_data_pd, conv_diff_dst_memory,
251-
conv_weights_memory, conv_diff_src_memory);
229+
auto conv_bwd_data_prim = mkldnn::convolution_backward_data(
230+
conv_bwd_data_pd, diff_dst_memory, weights_memory, diff_src_memory);
252231

253-
// push primitive and execute it
254-
std::vector<primitive> pipeline{conv_bwd_data_prim};
255-
stream(stream::kind::eager).submit(pipeline).wait();
232+
// push primitive to stream and wait until it's executed
233+
std::vector<mkldnn::primitive> pipeline{conv_bwd_data_prim};
234+
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
256235
}
257236
} // Compute()
237+
238+
private:
239+
mkldnn::convolution_backward_weights::primitive_desc
240+
ConvBwdWeightsPrimitiveDesc(
241+
const mkldnn::memory::desc& src, const mkldnn::memory::desc& diff_weights,
242+
const mkldnn::memory::desc& diff_dst, const std::vector<int>& strides,
243+
const std::vector<int>& paddings,
244+
const mkldnn::convolution_forward::primitive_desc& conv_pd,
245+
const mkldnn::engine& engine) const {
246+
auto conv_bwd_weights_desc = mkldnn::convolution_backward_weights::desc(
247+
mkldnn::convolution_direct, src, diff_weights, diff_dst, strides,
248+
paddings, paddings, mkldnn::padding_kind::zero);
249+
return mkldnn::convolution_backward_weights::primitive_desc(
250+
conv_bwd_weights_desc, engine, conv_pd);
251+
}
252+
253+
mkldnn::convolution_backward_data::primitive_desc ConvBwdDataPrimitiveDesc(
254+
const mkldnn::memory::desc& diff_src, const mkldnn::memory::desc& weights,
255+
const mkldnn::memory::desc& diff_dst, const std::vector<int>& strides,
256+
const std::vector<int>& paddings,
257+
const mkldnn::convolution_forward::primitive_desc& conv_pd,
258+
const mkldnn::engine& engine) const {
259+
auto conv_bwd_data_desc = mkldnn::convolution_backward_data::desc(
260+
mkldnn::convolution_direct, diff_src, weights, diff_dst, strides,
261+
paddings, paddings, mkldnn::padding_kind::zero);
262+
return mkldnn::convolution_backward_data::primitive_desc(conv_bwd_data_desc,
263+
engine, conv_pd);
264+
}
258265
};
259266

260-
namespace {
261-
std::unique_ptr<convolution_forward::primitive_desc> ConvFwdPrimitiveDesc(
262-
const memory::desc& src, const memory::desc& weights,
263-
const memory::desc& dst, const std::vector<int>& strides,
264-
const std::vector<int>& paddings, const mkldnn::engine& engine) {
265-
mkldnn::memory::dims stride_dims = {strides[0], strides[1]};
266-
mkldnn::memory::dims padding_dims = {paddings[0], paddings[1]};
267-
268-
auto conv_desc = mkldnn::convolution_forward::desc(
269-
mkldnn::prop_kind::forward, mkldnn::convolution_direct, src, weights, dst,
270-
stride_dims, padding_dims, padding_dims, mkldnn::padding_kind::zero);
271-
272-
auto p_conv_pd = new convolution_forward::primitive_desc(conv_desc, engine);
273-
274-
return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>(
275-
p_conv_pd);
276-
}
277-
278-
convolution_backward_weights::primitive_desc ConvBwdWeightsPrimitiveDesc(
279-
const memory::desc& src, const memory::desc& diff_weights,
280-
const memory::desc& diff_dst, const std::vector<int>& strides,
281-
const std::vector<int>& paddings,
282-
const convolution_forward::primitive_desc& conv_pd,
283-
const mkldnn::engine& engine) {
284-
auto conv_bwd_weights_desc = convolution_backward_weights::desc(
285-
convolution_direct, src, diff_weights, diff_dst, strides, paddings,
286-
paddings, padding_kind::zero);
287-
return convolution_backward_weights::primitive_desc(conv_bwd_weights_desc,
288-
engine, conv_pd);
289-
}
290-
291-
convolution_backward_data::primitive_desc ConvBwdDataPrimitiveDesc(
292-
const memory::desc& diff_src, const memory::desc& weights,
293-
const memory::desc& diff_dst, const std::vector<int>& strides,
294-
const std::vector<int>& paddings,
295-
const convolution_forward::primitive_desc& conv_pd,
296-
const mkldnn::engine& engine) {
297-
auto conv_bwd_data_desc = convolution_backward_data::desc(
298-
convolution_direct, diff_src, weights, diff_dst, strides, paddings,
299-
paddings, padding_kind::zero);
300-
return convolution_backward_data::primitive_desc(conv_bwd_data_desc, engine,
301-
conv_pd);
302-
}
303-
} // anonymous namespace
304267
} // namespace operators
305268
} // namespace paddle
306269

307270
namespace ops = paddle::operators;
308271

309272
REGISTER_OP_KERNEL(conv2d, MKLDNN, ::paddle::platform::CPUPlace,
310-
ops::ConvOpMkldnnKernel<float>);
273+
ops::ConvMKLDNNOpKernel<float>);
311274

312275
REGISTER_OP_KERNEL(conv2d_grad, MKLDNN, ::paddle::platform::CPUPlace,
313-
ops::ConvGradOpMkldnnKernel<float>);
276+
ops::ConvMKLDNNGradOpKernel<float>);

0 commit comments

Comments
 (0)