12
12
See the License for the specific language governing permissions and
13
13
limitations under the License. */
14
14
15
- #include " mkldnn.hpp"
16
- #include " paddle/fluid/framework/tensor.h"
17
15
#include " paddle/fluid/operators/conv_op.h"
18
16
#include " paddle/fluid/platform/mkldnn_helper.h"
19
17
20
18
namespace paddle {
21
19
namespace operators {
22
20
23
- using paddle::framework::Tensor;
24
- using paddle::platform::MKLDNNDeviceContext;
25
- using paddle::platform::MKLDNNMemDesc;
26
-
27
- using mkldnn::memory; // Note: paddle has also "memory" namespace
28
- using mkldnn::primitive;
29
- using mkldnn::convolution_forward;
30
- using mkldnn::convolution_backward_weights;
31
- using mkldnn::convolution_backward_data;
32
- using mkldnn::convolution_direct;
33
- using mkldnn::prop_kind;
34
- using mkldnn::padding_kind;
35
- using mkldnn::stream;
36
-
37
- namespace {
38
- std::unique_ptr<mkldnn::convolution_forward::primitive_desc>
39
- ConvFwdPrimitiveDesc (const memory::desc& src, const memory::desc& weights,
40
- const memory::desc& dst, const std::vector<int >& strides,
41
- const std::vector<int >& paddings,
42
- const mkldnn::engine& engine);
43
-
44
- convolution_backward_weights::primitive_desc ConvBwdWeightsPrimitiveDesc (
45
- const memory::desc& src, const memory::desc& diff_weights,
46
- const memory::desc& diff_dst, const std::vector<int >& strides,
47
- const std::vector<int >& paddings,
48
- const convolution_forward::primitive_desc& conv_pd,
49
- const mkldnn::engine& engine);
50
-
51
- convolution_backward_data::primitive_desc ConvBwdDataPrimitiveDesc (
52
- const memory::desc& diff_src, const memory::desc& weights,
53
- const memory::desc& diff_dst, const std::vector<int >& strides,
54
- const std::vector<int >& paddings,
55
- const convolution_forward::primitive_desc& conv_pd,
56
- const mkldnn::engine& engine);
57
- } // anonymous namespace
58
-
59
21
template <typename T>
60
- class ConvOpMkldnnKernel : public paddle ::framework::OpKernel<T> {
22
+ class ConvMKLDNNOpKernel : public paddle ::framework::OpKernel<T> {
61
23
public:
62
24
void Compute (const paddle::framework::ExecutionContext& ctx) const override {
63
25
PADDLE_ENFORCE (paddle::platform::is_cpu_place (ctx.GetPlace ()),
64
26
" It must use CPUPlace." );
65
27
66
- auto & dev_ctx = ctx.template device_context <MKLDNNDeviceContext>();
28
+ auto & dev_ctx =
29
+ ctx.template device_context <paddle::platform::MKLDNNDeviceContext>();
67
30
const auto & mkldnn_engine = dev_ctx.GetEngine ();
68
31
69
32
auto * input = ctx.Input <Tensor>(" Input" );
@@ -88,7 +51,6 @@ class ConvOpMkldnnKernel : public paddle::framework::OpKernel<T> {
88
51
89
52
const T* input_data = input->data <T>();
90
53
const T* filter_data = filter->data <T>();
91
- // allocate memory for output
92
54
T* output_data = output->mutable_data <T>(ctx.GetPlace ());
93
55
94
56
PADDLE_ENFORCE (input->dims ().size () == 4 ,
@@ -102,48 +64,69 @@ class ConvOpMkldnnKernel : public paddle::framework::OpKernel<T> {
102
64
std::vector<int > dst_tz = paddle::framework::vectorize2int (output->dims ());
103
65
104
66
// TODO(pzelazko-intel): support more formats
105
- // memory descriptors for convolution src/weight/dst
106
- auto conv_src_md =
107
- MKLDNNMemDesc (src_tz, memory::data_type::f32 , memory::format::nchw);
108
- auto conv_weights_md =
109
- MKLDNNMemDesc (weights_tz, memory::data_type::f32 , memory::format::oihw);
110
- auto conv_dst_md =
111
- MKLDNNMemDesc (dst_tz, memory::data_type::f32 , memory::format::nchw);
112
-
113
- // create memory primitives
114
- auto conv_src_memory =
115
- memory ({conv_src_md, mkldnn_engine}, (void *)input_data);
116
- auto conv_weights_memory =
117
- memory ({conv_weights_md, mkldnn_engine}, (void *)filter_data);
118
- auto conv_dst_memory = memory ({conv_dst_md, mkldnn_engine}, output_data);
119
-
120
- std::unique_ptr<convolution_forward::primitive_desc> conv_pd =
121
- ConvFwdPrimitiveDesc (conv_src_md, conv_weights_md, conv_dst_md, strides,
122
- paddings, mkldnn_engine);
123
-
124
- // save p_conv_pd into dev_ctx to be referred in backward path
125
- auto p_conv_pd = conv_pd.get ();
126
- std::shared_ptr<void > conv_pd_value = std::move (conv_pd);
127
- dev_ctx.SetBlob (key_conv_pd, conv_pd_value);
67
+ auto src_md = platform::MKLDNNMemDesc (
68
+ src_tz, mkldnn::memory::data_type::f32 , mkldnn::memory::format::nchw);
69
+ auto weights_md =
70
+ platform::MKLDNNMemDesc (weights_tz, mkldnn::memory::data_type::f32 ,
71
+ mkldnn::memory::format::oihw);
72
+ auto dst_md = platform::MKLDNNMemDesc (
73
+ dst_tz, mkldnn::memory::data_type::f32 , mkldnn::memory::format::nchw);
74
+
75
+ auto src_memory =
76
+ mkldnn::memory ({src_md, mkldnn_engine}, (void *)input_data);
77
+ auto weights_memory =
78
+ mkldnn::memory ({weights_md, mkldnn_engine}, (void *)filter_data);
79
+ auto dst_memory = mkldnn::memory ({dst_md, mkldnn_engine}, output_data);
80
+
81
+ std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd =
82
+ ConvFwdPrimitiveDesc (src_md, weights_md, dst_md, strides, paddings,
83
+ mkldnn_engine);
84
+
85
+ // save conv_pd into global device context to be referred in backward path
86
+ dev_ctx.SetBlob (key_conv_pd, conv_pd);
128
87
129
88
// create convolution op primitive
130
- auto conv_prim = convolution_forward (*p_conv_pd, conv_src_memory,
131
- conv_weights_memory, conv_dst_memory);
89
+ auto conv_prim = mkldnn::convolution_forward (*conv_pd, src_memory,
90
+ weights_memory, dst_memory);
91
+
92
+ // push primitive to stream and wait until it's executed
93
+ std::vector<mkldnn::primitive> pipeline{conv_prim};
94
+ mkldnn::stream (mkldnn::stream::kind::eager).submit (pipeline).wait ();
95
+ }
132
96
133
- // push op to stream and wait MKLDNN until it's executed
134
- std::vector<primitive> pipeline{conv_prim};
135
- stream (stream::kind::eager).submit (pipeline).wait ();
97
+ private:
98
+ std::unique_ptr<mkldnn::convolution_forward::primitive_desc>
99
+ ConvFwdPrimitiveDesc (const mkldnn::memory::desc& src,
100
+ const mkldnn::memory::desc& weights,
101
+ const mkldnn::memory::desc& dst,
102
+ const std::vector<int >& strides,
103
+ const std::vector<int >& paddings,
104
+ const mkldnn::engine& engine) const {
105
+ mkldnn::memory::dims stride_dims = {strides[0 ], strides[1 ]};
106
+ mkldnn::memory::dims padding_dims = {paddings[0 ], paddings[1 ]};
107
+
108
+ auto conv_desc = mkldnn::convolution_forward::desc (
109
+ mkldnn::prop_kind::forward, mkldnn::convolution_direct, src, weights,
110
+ dst, stride_dims, padding_dims, padding_dims,
111
+ mkldnn::padding_kind::zero);
112
+
113
+ auto p_conv_pd =
114
+ new mkldnn::convolution_forward::primitive_desc (conv_desc, engine);
115
+
116
+ return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>(
117
+ p_conv_pd);
136
118
}
137
119
};
138
120
139
121
template <typename T>
140
- class ConvGradOpMkldnnKernel : public paddle ::framework::OpKernel<T> {
122
+ class ConvMKLDNNGradOpKernel : public paddle ::framework::OpKernel<T> {
141
123
public:
142
124
void Compute (const paddle::framework::ExecutionContext& ctx) const override {
143
125
PADDLE_ENFORCE (paddle::platform::is_cpu_place (ctx.GetPlace ()),
144
126
" It must use CPUPlace." );
145
127
146
- auto & dev_ctx = ctx.template device_context <MKLDNNDeviceContext>();
128
+ auto & dev_ctx =
129
+ ctx.template device_context <platform::MKLDNNDeviceContext>();
147
130
const auto & mkldnn_engine = dev_ctx.GetEngine ();
148
131
149
132
const Tensor* input = ctx.Input <Tensor>(" Input" );
@@ -170,7 +153,6 @@ class ConvGradOpMkldnnKernel : public paddle::framework::OpKernel<T> {
170
153
T* input_grad_data = nullptr ;
171
154
T* filter_grad_data = nullptr ;
172
155
173
- // allocate memory for gradient of input/filter
174
156
if (input_grad) {
175
157
input_grad_data = input_grad->mutable_data <T>(ctx.GetPlace ());
176
158
}
@@ -184,130 +166,111 @@ class ConvGradOpMkldnnKernel : public paddle::framework::OpKernel<T> {
184
166
std::vector<int > dst_tz = paddle::framework::vectorize2int (output->dims ());
185
167
186
168
// TODO(pzelazko-intel): support more formats
187
- auto conv_src_md =
188
- MKLDNNMemDesc (src_tz, memory::data_type::f32 , memory::format::nchw);
189
- auto conv_diff_src_md =
190
- MKLDNNMemDesc (src_tz, memory::data_type::f32 , memory::format::nchw);
191
- auto conv_weights_md =
192
- MKLDNNMemDesc (weights_tz, memory::data_type::f32 , memory::format::oihw);
193
- auto conv_diff_weights_md =
194
- MKLDNNMemDesc (weights_tz, memory::data_type::f32 , memory::format::oihw);
195
- auto conv_diff_dst_md =
196
- MKLDNNMemDesc (dst_tz, memory::data_type::f32 , memory::format::nchw);
169
+ auto src_md = platform::MKLDNNMemDesc (
170
+ src_tz, mkldnn::memory::data_type::f32 , mkldnn::memory::format::nchw);
171
+ auto diff_src_md = platform::MKLDNNMemDesc (
172
+ src_tz, mkldnn::memory::data_type::f32 , mkldnn::memory::format::nchw);
173
+ auto weights_md =
174
+ platform::MKLDNNMemDesc (weights_tz, mkldnn::memory::data_type::f32 ,
175
+ mkldnn::memory::format::oihw);
176
+ auto diff_weights_md =
177
+ platform::MKLDNNMemDesc (weights_tz, mkldnn::memory::data_type::f32 ,
178
+ mkldnn::memory::format::oihw);
179
+ auto diff_dst_md = platform::MKLDNNMemDesc (
180
+ dst_tz, mkldnn::memory::data_type::f32 , mkldnn::memory::format::nchw);
197
181
198
182
// create memory
199
- auto conv_diff_dst_memory =
200
- memory ({conv_diff_weights_md, mkldnn_engine}, (void *)output_grad_data);
183
+ auto diff_dst_memory = mkldnn::memory ({diff_weights_md, mkldnn_engine},
184
+ (void *)output_grad_data);
201
185
// Retrieve conv_pd from device context
202
- std::shared_ptr<void > conv_pd;
203
- convolution_forward::primitive_desc* p_conv_pd;
204
-
205
- conv_pd = dev_ctx.GetBlob (key_conv_pd);
186
+ auto conv_pd =
187
+ std::static_pointer_cast<mkldnn::convolution_forward::primitive_desc>(
188
+ dev_ctx.GetBlob (key_conv_pd));
206
189
PADDLE_ENFORCE (conv_pd != nullptr ,
207
190
" Fail to find conv_pd in device context" );
208
- p_conv_pd =
209
- static_cast <convolution_forward::primitive_desc*>(conv_pd.get ());
210
191
211
192
// create backward conv primitive for weights
212
193
if (filter_grad) {
213
194
// create primitive descriptor
214
- convolution_backward_weights::primitive_desc conv_bwd_weights_pd =
215
- ConvBwdWeightsPrimitiveDesc (conv_src_md, conv_diff_weights_md ,
216
- conv_diff_dst_md, strides, paddings,
217
- *p_conv_pd, mkldnn_engine);
195
+ mkldnn:: convolution_backward_weights::primitive_desc conv_bwd_weights_pd =
196
+ ConvBwdWeightsPrimitiveDesc (src_md, diff_weights_md, diff_dst_md ,
197
+ strides, paddings, *conv_pd ,
198
+ mkldnn_engine);
218
199
219
200
// create memory
220
- auto conv_diff_weights_memory = memory (
221
- {conv_diff_weights_md , mkldnn_engine}, (void *)filter_grad_data);
222
- auto conv_src_memory =
223
- memory ({conv_src_md , mkldnn_engine}, (void *)input_data);
201
+ auto diff_weights_memory = mkldnn:: memory (
202
+ {diff_weights_md , mkldnn_engine}, (void *)filter_grad_data);
203
+ auto src_memory =
204
+ mkldnn:: memory ({src_md , mkldnn_engine}, (void *)input_data);
224
205
225
206
// create backward conv primitive for weights
226
- auto conv_bwd_weights_prim = convolution_backward_weights (
227
- conv_bwd_weights_pd, conv_src_memory, conv_diff_dst_memory ,
228
- conv_diff_weights_memory );
207
+ auto conv_bwd_weights_prim = mkldnn:: convolution_backward_weights (
208
+ conv_bwd_weights_pd, src_memory, diff_dst_memory ,
209
+ diff_weights_memory );
229
210
230
211
// push primitive and execute it
231
- std::vector<primitive> pipeline{conv_bwd_weights_prim};
232
- stream (stream::kind::eager).submit (pipeline).wait ();
212
+ std::vector<mkldnn:: primitive> pipeline{conv_bwd_weights_prim};
213
+ mkldnn:: stream (mkldnn:: stream::kind::eager).submit (pipeline).wait ();
233
214
}
234
215
235
216
if (input_grad) {
236
217
// create primitive descriptor
237
- convolution_backward_data::primitive_desc conv_bwd_data_pd =
238
- ConvBwdDataPrimitiveDesc (conv_diff_src_md, conv_weights_md,
239
- conv_diff_dst_md, strides, paddings,
240
- *p_conv_pd, mkldnn_engine);
218
+ mkldnn::convolution_backward_data::primitive_desc conv_bwd_data_pd =
219
+ ConvBwdDataPrimitiveDesc (diff_src_md, weights_md, diff_dst_md,
220
+ strides, paddings, *conv_pd, mkldnn_engine);
241
221
242
222
// create memory
243
- auto conv_diff_src_memory =
244
- memory ({conv_diff_src_md , mkldnn_engine}, (void *)input_grad_data);
245
- auto conv_weights_memory =
246
- memory ({conv_weights_md , mkldnn_engine}, (void *)filter_data);
223
+ auto diff_src_memory =
224
+ mkldnn:: memory ({diff_src_md , mkldnn_engine}, (void *)input_grad_data);
225
+ auto weights_memory =
226
+ mkldnn:: memory ({weights_md , mkldnn_engine}, (void *)filter_data);
247
227
248
228
// create backward conv primitive for data
249
- auto conv_bwd_data_prim =
250
- convolution_backward_data (conv_bwd_data_pd, conv_diff_dst_memory,
251
- conv_weights_memory, conv_diff_src_memory);
229
+ auto conv_bwd_data_prim = mkldnn::convolution_backward_data (
230
+ conv_bwd_data_pd, diff_dst_memory, weights_memory, diff_src_memory);
252
231
253
- // push primitive and execute it
254
- std::vector<primitive> pipeline{conv_bwd_data_prim};
255
- stream (stream::kind::eager).submit (pipeline).wait ();
232
+ // push primitive to stream and wait until it's executed
233
+ std::vector<mkldnn:: primitive> pipeline{conv_bwd_data_prim};
234
+ mkldnn:: stream (mkldnn:: stream::kind::eager).submit (pipeline).wait ();
256
235
}
257
236
} // Compute()
237
+
238
+ private:
239
+ mkldnn::convolution_backward_weights::primitive_desc
240
+ ConvBwdWeightsPrimitiveDesc (
241
+ const mkldnn::memory::desc& src, const mkldnn::memory::desc& diff_weights,
242
+ const mkldnn::memory::desc& diff_dst, const std::vector<int >& strides,
243
+ const std::vector<int >& paddings,
244
+ const mkldnn::convolution_forward::primitive_desc& conv_pd,
245
+ const mkldnn::engine& engine) const {
246
+ auto conv_bwd_weights_desc = mkldnn::convolution_backward_weights::desc (
247
+ mkldnn::convolution_direct, src, diff_weights, diff_dst, strides,
248
+ paddings, paddings, mkldnn::padding_kind::zero);
249
+ return mkldnn::convolution_backward_weights::primitive_desc (
250
+ conv_bwd_weights_desc, engine, conv_pd);
251
+ }
252
+
253
+ mkldnn::convolution_backward_data::primitive_desc ConvBwdDataPrimitiveDesc (
254
+ const mkldnn::memory::desc& diff_src, const mkldnn::memory::desc& weights,
255
+ const mkldnn::memory::desc& diff_dst, const std::vector<int >& strides,
256
+ const std::vector<int >& paddings,
257
+ const mkldnn::convolution_forward::primitive_desc& conv_pd,
258
+ const mkldnn::engine& engine) const {
259
+ auto conv_bwd_data_desc = mkldnn::convolution_backward_data::desc (
260
+ mkldnn::convolution_direct, diff_src, weights, diff_dst, strides,
261
+ paddings, paddings, mkldnn::padding_kind::zero);
262
+ return mkldnn::convolution_backward_data::primitive_desc (conv_bwd_data_desc,
263
+ engine, conv_pd);
264
+ }
258
265
};
259
266
260
- namespace {
261
- std::unique_ptr<convolution_forward::primitive_desc> ConvFwdPrimitiveDesc (
262
- const memory::desc& src, const memory::desc& weights,
263
- const memory::desc& dst, const std::vector<int >& strides,
264
- const std::vector<int >& paddings, const mkldnn::engine& engine) {
265
- mkldnn::memory::dims stride_dims = {strides[0 ], strides[1 ]};
266
- mkldnn::memory::dims padding_dims = {paddings[0 ], paddings[1 ]};
267
-
268
- auto conv_desc = mkldnn::convolution_forward::desc (
269
- mkldnn::prop_kind::forward, mkldnn::convolution_direct, src, weights, dst,
270
- stride_dims, padding_dims, padding_dims, mkldnn::padding_kind::zero);
271
-
272
- auto p_conv_pd = new convolution_forward::primitive_desc (conv_desc, engine);
273
-
274
- return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>(
275
- p_conv_pd);
276
- }
277
-
278
- convolution_backward_weights::primitive_desc ConvBwdWeightsPrimitiveDesc (
279
- const memory::desc& src, const memory::desc& diff_weights,
280
- const memory::desc& diff_dst, const std::vector<int >& strides,
281
- const std::vector<int >& paddings,
282
- const convolution_forward::primitive_desc& conv_pd,
283
- const mkldnn::engine& engine) {
284
- auto conv_bwd_weights_desc = convolution_backward_weights::desc (
285
- convolution_direct, src, diff_weights, diff_dst, strides, paddings,
286
- paddings, padding_kind::zero);
287
- return convolution_backward_weights::primitive_desc (conv_bwd_weights_desc,
288
- engine, conv_pd);
289
- }
290
-
291
- convolution_backward_data::primitive_desc ConvBwdDataPrimitiveDesc (
292
- const memory::desc& diff_src, const memory::desc& weights,
293
- const memory::desc& diff_dst, const std::vector<int >& strides,
294
- const std::vector<int >& paddings,
295
- const convolution_forward::primitive_desc& conv_pd,
296
- const mkldnn::engine& engine) {
297
- auto conv_bwd_data_desc = convolution_backward_data::desc (
298
- convolution_direct, diff_src, weights, diff_dst, strides, paddings,
299
- paddings, padding_kind::zero);
300
- return convolution_backward_data::primitive_desc (conv_bwd_data_desc, engine,
301
- conv_pd);
302
- }
303
- } // anonymous namespace
304
267
} // namespace operators
305
268
} // namespace paddle
306
269
307
270
namespace ops = paddle::operators;
308
271
309
272
REGISTER_OP_KERNEL (conv2d, MKLDNN, ::paddle::platform::CPUPlace,
310
- ops::ConvOpMkldnnKernel <float >);
273
+ ops::ConvMKLDNNOpKernel <float >);
311
274
312
275
REGISTER_OP_KERNEL (conv2d_grad, MKLDNN, ::paddle::platform::CPUPlace,
313
- ops::ConvGradOpMkldnnKernel <float >);
276
+ ops::ConvMKLDNNGradOpKernel <float >);
0 commit comments