@@ -126,6 +126,15 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
126
126
pipeline);
127
127
}
128
128
129
+ std::shared_ptr<mkldnn::memory> AcquireBiasMemoryFromPrimitive (
130
+ const std::shared_ptr<mkldnn::memory> user_bias_memory_p,
131
+ std::vector<mkldnn::primitive>& pipeline) { // NOLINT
132
+ auto user_bias_pd = user_bias_memory_p->get_primitive_desc ();
133
+ auto bias_pd = conv_pd_->bias_primitive_desc ();
134
+ return this ->AcquireMemory (bias_pd, user_bias_pd, user_bias_memory_p,
135
+ " @bias_mem_p" , pipeline);
136
+ }
137
+
129
138
std::shared_ptr<mkldnn::convolution_forward> AcquireConvolution (
130
139
std::shared_ptr<mkldnn::memory> src_memory_p,
131
140
std::shared_ptr<mkldnn::memory> weights_memory_p,
@@ -147,6 +156,28 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
147
156
return conv_p;
148
157
}
149
158
159
+ std::shared_ptr<mkldnn::convolution_forward> AcquireConvolution (
160
+ std::shared_ptr<mkldnn::memory> src_memory_p,
161
+ std::shared_ptr<mkldnn::memory> weights_memory_p,
162
+ std::shared_ptr<mkldnn::memory> bias_memory_p,
163
+ std::shared_ptr<mkldnn::memory> dst_memory_p) {
164
+ auto prim_key = key_ + " @conv_p" ;
165
+ auto conv_p = std::static_pointer_cast<mkldnn::convolution_forward>(
166
+ dev_ctx_.GetBlob (prim_key));
167
+ PADDLE_ENFORCE ((conv_p != nullptr ) || (is_reusing_ == false ),
168
+ " Fail to find convolution primitive in device context" );
169
+ if (conv_p == nullptr ) {
170
+ conv_p = std::make_shared<mkldnn::convolution_forward>(
171
+ *conv_pd_, *(src_memory_p), *(weights_memory_p.get ()),
172
+ *(bias_memory_p.get ()), *(dst_memory_p.get ()));
173
+
174
+ dev_ctx_.SetBlob (prim_key, conv_p);
175
+ } else {
176
+ is_reusing_ = true ;
177
+ }
178
+ return conv_p;
179
+ }
180
+
150
181
std::shared_ptr<mkldnn::convolution_backward_weights>
151
182
AcquireConvolutionBackwardWeights (
152
183
std::shared_ptr<mkldnn::memory> src_memory_p,
@@ -229,6 +260,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
229
260
230
261
auto * input = ctx.Input <Tensor>(" Input" );
231
262
auto * filter = ctx.Input <Tensor>(" Filter" );
263
+ auto * bias = ctx.HasInput (" Bias" ) ? ctx.Input <Tensor>(" Bias" ) : nullptr ;
232
264
auto * output = ctx.Output <Tensor>(" Output" );
233
265
234
266
PADDLE_ENFORCE (input->layout () == DataLayout::kMKLDNN &&
@@ -237,6 +269,17 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
237
269
PADDLE_ENFORCE (filter->layout () == DataLayout::kMKLDNN &&
238
270
filter->format () != memory::format::format_undef,
239
271
" Wrong layout/format set for Filter tensor" );
272
+ PADDLE_ENFORCE (input->dims ().size () == 4 ,
273
+ " Input must be with 4 dimensions, i.e. NCHW" );
274
+ PADDLE_ENFORCE (filter->dims ().size () == 4 ,
275
+ " Filter must be with 4 dimensions, i.e. OIHW" );
276
+ if (bias) {
277
+ PADDLE_ENFORCE (bias->layout () == DataLayout::kMKLDNN &&
278
+ bias->format () != memory::format::format_undef,
279
+ " Wrong layout/format set for Bias tensor" );
280
+ PADDLE_ENFORCE (bias->dims ().size () == 1 ,
281
+ " Bias must only have 1 dimension, i.e. X" );
282
+ }
240
283
241
284
std::vector<int > strides = ctx.Attr <std::vector<int >>(" strides" );
242
285
std::vector<int > paddings = ctx.Attr <std::vector<int >>(" paddings" );
@@ -253,11 +296,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
253
296
const T* filter_data = filter->data <T>();
254
297
T* output_data = output->mutable_data <T>(ctx.GetPlace ());
255
298
256
- PADDLE_ENFORCE (input->dims ().size () == 4 ,
257
- " Input must be with 4 dimensions, i.e. NCHW" );
258
- PADDLE_ENFORCE (filter->dims ().size () == 4 ,
259
- " Filter must be with 4 dimensions, i.e. OIHW" );
260
-
261
299
std::vector<int > src_tz = paddle::framework::vectorize2int (input->dims ());
262
300
std::vector<int > weights_tz =
263
301
paddle::framework::vectorize2int (filter->dims ());
@@ -288,13 +326,23 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
288
326
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
289
327
auto weights_md = platform::MKLDNNMemDesc (
290
328
weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
329
+ std::vector<int > bias_tz; // TODO(mgallus): avoid empty vector creation.
330
+ // Currently used whenever bias is != nullptr.
291
331
auto dst_md = platform::MKLDNNMemDesc (
292
332
dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
293
333
294
334
// create a conv primitive descriptor and save it for usage in backward
295
- std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd =
296
- ConvFwdPrimitiveDesc (src_md, weights_md, dst_md, strides, paddings,
297
- mkldnn_engine);
335
+ std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd;
336
+ if (bias) {
337
+ bias_tz = paddle::framework::vectorize2int (bias->dims ());
338
+ auto bias_md = platform::MKLDNNMemDesc (
339
+ bias_tz, platform::MKLDNNGetDataType<T>(), memory::format::x);
340
+ conv_pd = ConvFwdPrimitiveDesc (src_md, weights_md, bias_md, dst_md,
341
+ strides, paddings, mkldnn_engine);
342
+ } else {
343
+ conv_pd = ConvFwdPrimitiveDesc (src_md, weights_md, dst_md, strides,
344
+ paddings, mkldnn_engine);
345
+ }
298
346
// Save conv_pd/src_memory/weights_memory for backward pass
299
347
dev_ctx.SetBlob (key_conv_pd, conv_pd);
300
348
@@ -315,8 +363,22 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
315
363
handler.AcquireDstMemoryFromPrimitive (to_void_cast<T>(output_data));
316
364
317
365
// create convolution op primitive
318
- auto conv_p = handler.AcquireConvolution (src_memory_p, weights_memory_p,
319
- dst_memory_p);
366
+ std::shared_ptr<mkldnn::convolution_forward> conv_p;
367
+ if (bias) {
368
+ const T* bias_data = bias->data <T>();
369
+ auto user_bias_md = platform::MKLDNNMemDesc (
370
+ {bias_tz}, platform::MKLDNNGetDataType<T>(), memory::format::x);
371
+ auto user_bias_memory_p =
372
+ handler.AcquireBiasMemory (user_bias_md, to_void_cast<T>(bias_data));
373
+
374
+ auto bias_memory_p =
375
+ handler.AcquireBiasMemoryFromPrimitive (user_bias_memory_p, pipeline);
376
+ conv_p = handler.AcquireConvolution (src_memory_p, weights_memory_p,
377
+ bias_memory_p, dst_memory_p);
378
+ } else {
379
+ conv_p = handler.AcquireConvolution (src_memory_p, weights_memory_p,
380
+ dst_memory_p);
381
+ }
320
382
321
383
// push primitive to stream and wait until it's executed
322
384
pipeline.push_back (*conv_p);
@@ -346,6 +408,27 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
346
408
return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>(
347
409
p_conv_pd);
348
410
}
411
+
412
+ std::unique_ptr<mkldnn::convolution_forward::primitive_desc>
413
+ ConvFwdPrimitiveDesc (const memory::desc& src, const memory::desc& weights,
414
+ const memory::desc& bias, const memory::desc& dst,
415
+ const std::vector<int >& strides,
416
+ const std::vector<int >& paddings,
417
+ const mkldnn::engine& engine) const {
418
+ memory::dims stride_dims = {strides[0 ], strides[1 ]};
419
+ memory::dims padding_dims = {paddings[0 ], paddings[1 ]};
420
+
421
+ auto conv_desc = mkldnn::convolution_forward::desc (
422
+ mkldnn::prop_kind::forward, mkldnn::convolution_direct, src, weights,
423
+ bias, dst, stride_dims, padding_dims, padding_dims,
424
+ mkldnn::padding_kind::zero);
425
+
426
+ auto p_conv_pd =
427
+ new mkldnn::convolution_forward::primitive_desc (conv_desc, engine);
428
+
429
+ return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>(
430
+ p_conv_pd);
431
+ }
349
432
};
350
433
351
434
template <typename T>
0 commit comments