@@ -299,6 +299,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
299
299
std::vector<int > strides = ctx.Attr <std::vector<int >>(" strides" );
300
300
std::vector<int > paddings = ctx.Attr <std::vector<int >>(" paddings" );
301
301
std::vector<int > dilations = ctx.Attr <std::vector<int >>(" dilations" );
302
+ bool fuse_relu = ctx.Attr <bool >(" fuse_relu" );
302
303
int groups = ctx.Attr <int >(" groups" );
303
304
304
305
// TODO(pzelazko-intel) add support for group convolution and dilation
@@ -351,11 +352,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
351
352
bias_tz = paddle::framework::vectorize2int (bias->dims ());
352
353
auto bias_md = platform::MKLDNNMemDesc (
353
354
bias_tz, platform::MKLDNNGetDataType<T>(), memory::format::x);
354
- conv_pd = ConvFwdPrimitiveDesc (src_md, weights_md, bias_md, dst_md,
355
- strides, paddings, mkldnn_engine);
355
+ conv_pd =
356
+ ConvFwdPrimitiveDesc (src_md, weights_md, bias_md, dst_md, strides,
357
+ paddings, mkldnn_engine, fuse_relu);
356
358
} else {
357
359
conv_pd = ConvFwdPrimitiveDesc (src_md, weights_md, dst_md, strides,
358
- paddings, mkldnn_engine);
360
+ paddings, mkldnn_engine, fuse_relu );
359
361
}
360
362
// Save conv_pd/src_memory/weights_memory for backward pass
361
363
dev_ctx.SetBlob (key_conv_pd, conv_pd);
@@ -405,11 +407,26 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
405
407
}
406
408
407
409
private:
410
+ mkldnn::primitive_attr AddRelu () const {
411
+ // Fusion with ReLU layer is executed through the PostOps feature. Create a
412
+ // PostOps object and configure it to execute an eltwise relu operation.
413
+ mkldnn::primitive_attr conv_attr;
414
+ constexpr float scale = 1 .0f ;
415
+ constexpr float negative_slope = 0 .0f ;
416
+ constexpr float placeholder = 0 .0f ;
417
+ mkldnn::post_ops post_operations;
418
+ post_operations.append_eltwise (scale, mkldnn::algorithm::eltwise_relu,
419
+ negative_slope, placeholder);
420
+ conv_attr.set_post_ops (post_operations);
421
+ return conv_attr;
422
+ }
423
+
408
424
std::unique_ptr<mkldnn::convolution_forward::primitive_desc>
409
425
ConvFwdPrimitiveDesc (const memory::desc& src, const memory::desc& weights,
410
426
const memory::desc& dst, const std::vector<int >& strides,
411
427
const std::vector<int >& paddings,
412
- const mkldnn::engine& engine) const {
428
+ const mkldnn::engine& engine,
429
+ const bool fuse_relu) const {
413
430
memory::dims stride_dims = {strides[0 ], strides[1 ]};
414
431
memory::dims padding_dims = {paddings[0 ], paddings[1 ]};
415
432
@@ -418,8 +435,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
418
435
dst, stride_dims, padding_dims, padding_dims,
419
436
mkldnn::padding_kind::zero);
420
437
421
- auto p_conv_pd =
422
- new mkldnn::convolution_forward::primitive_desc (conv_desc, engine);
438
+ mkldnn::primitive_attr conv_attr;
439
+ if (fuse_relu) {
440
+ conv_attr = AddRelu ();
441
+ }
442
+
443
+ auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc (
444
+ conv_desc, conv_attr, engine);
423
445
424
446
return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>(
425
447
p_conv_pd);
@@ -430,7 +452,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
430
452
const memory::desc& bias, const memory::desc& dst,
431
453
const std::vector<int >& strides,
432
454
const std::vector<int >& paddings,
433
- const mkldnn::engine& engine) const {
455
+ const mkldnn::engine& engine,
456
+ const bool fuse_relu) const {
434
457
memory::dims stride_dims = {strides[0 ], strides[1 ]};
435
458
memory::dims padding_dims = {paddings[0 ], paddings[1 ]};
436
459
@@ -439,8 +462,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
439
462
bias, dst, stride_dims, padding_dims, padding_dims,
440
463
mkldnn::padding_kind::zero);
441
464
442
- auto p_conv_pd =
443
- new mkldnn::convolution_forward::primitive_desc (conv_desc, engine);
465
+ mkldnn::primitive_attr conv_attr;
466
+ if (fuse_relu) {
467
+ conv_attr = AddRelu ();
468
+ }
469
+
470
+ auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc (
471
+ conv_desc, conv_attr, engine);
444
472
445
473
return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>(
446
474
p_conv_pd);
0 commit comments