@@ -300,6 +300,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
300300 std::vector<int > paddings = ctx.Attr <std::vector<int >>(" paddings" );
301301 std::vector<int > dilations = ctx.Attr <std::vector<int >>(" dilations" );
302302 bool fuse_relu = ctx.Attr <bool >(" fuse_relu" );
303+ bool fuse_eltwise = ctx.Attr <bool >(" fuse_eltwise" );
303304 int groups = ctx.Attr <int >(" groups" );
304305
305306 // TODO: add support for dilation
@@ -366,12 +367,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
366367 bias_tz = paddle::framework::vectorize2int (bias->dims ());
367368 auto bias_md = platform::MKLDNNMemDesc (
368369 bias_tz, platform::MKLDNNGetDataType<T>(), memory::format::x);
369- conv_pd =
370- ConvFwdPrimitiveDesc (src_md, weights_md, bias_md, dst_md, strides ,
371- paddings, mkldnn_engine, fuse_relu);
370+ conv_pd = ConvFwdPrimitiveDesc (src_md, weights_md, bias_md, dst_md,
371+ strides, paddings, mkldnn_engine ,
372+ fuse_relu, fuse_eltwise );
372373 } else {
373- conv_pd = ConvFwdPrimitiveDesc (src_md, weights_md, dst_md, strides,
374- paddings, mkldnn_engine, fuse_relu);
374+ conv_pd =
375+ ConvFwdPrimitiveDesc (src_md, weights_md, dst_md, strides, paddings,
376+ mkldnn_engine, fuse_relu, fuse_eltwise);
375377 }
376378 // Save conv_pd/src_memory/weights_memory for backward pass
377379 dev_ctx.SetBlob (key_conv_pd, conv_pd);
@@ -421,16 +423,26 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
421423 }
422424
423425 private:
424- mkldnn::primitive_attr AddRelu () const {
425- // Fusion with ReLU layer is executed through the PostOps feature. Create a
426- // PostOps object and configure it to execute an eltwise relu operation.
426+ mkldnn::primitive_attr CreatePostOps (bool fuse_relu,
427+ bool fuse_eltwise) const {
427428 mkldnn::primitive_attr conv_attr;
428- constexpr float scale = 1 .0f ;
429- constexpr float negative_slope = 0 .0f ;
430- constexpr float placeholder = 0 .0f ;
431429 mkldnn::post_ops post_operations;
432- post_operations.append_eltwise (scale, mkldnn::algorithm::eltwise_relu,
433- negative_slope, placeholder);
430+ // Fusion with Elementwise layer relies on adding a sum post-operation with
431+ // the scale parameter. It is assumed that when fuse_eltwise is true, the
432+ // Output tensor contains the data coming from residual connection. The
433+ // result of this post_op is: Output = scale * Output + Conv_Out.
434+ if (fuse_eltwise) {
435+ post_operations.append_sum (1 .0f );
436+ }
437+ // Fusion with ReLU layer is executed through the PostOps feature. Create a
438+ // PostOps object and configure it to execute an eltwise relu operation.
439+ if (fuse_relu) {
440+ constexpr float scale = 1 .0f ;
441+ constexpr float negative_slope = 0 .0f ;
442+ constexpr float placeholder = 0 .0f ;
443+ post_operations.append_eltwise (scale, mkldnn::algorithm::eltwise_relu,
444+ negative_slope, placeholder);
445+ }
434446 conv_attr.set_post_ops (post_operations);
435447 return conv_attr;
436448 }
@@ -439,8 +451,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
439451 ConvFwdPrimitiveDesc (const memory::desc& src, const memory::desc& weights,
440452 const memory::desc& dst, const std::vector<int >& strides,
441453 const std::vector<int >& paddings,
442- const mkldnn::engine& engine,
443- const bool fuse_relu ) const {
454+ const mkldnn::engine& engine, const bool fuse_relu,
455+ const bool fuse_eltwise ) const {
444456 memory::dims stride_dims = {strides[0 ], strides[1 ]};
445457 memory::dims padding_dims = {paddings[0 ], paddings[1 ]};
446458
@@ -449,10 +461,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
449461 dst, stride_dims, padding_dims, padding_dims,
450462 mkldnn::padding_kind::zero);
451463
452- mkldnn::primitive_attr conv_attr;
453- if (fuse_relu) {
454- conv_attr = AddRelu ();
455- }
464+ mkldnn::primitive_attr conv_attr = CreatePostOps (fuse_relu, fuse_eltwise);
456465
457466 auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc (
458467 conv_desc, conv_attr, engine);
@@ -466,8 +475,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
466475 const memory::desc& bias, const memory::desc& dst,
467476 const std::vector<int >& strides,
468477 const std::vector<int >& paddings,
469- const mkldnn::engine& engine,
470- const bool fuse_relu ) const {
478+ const mkldnn::engine& engine, const bool fuse_relu,
479+ const bool fuse_eltwise ) const {
471480 memory::dims stride_dims = {strides[0 ], strides[1 ]};
472481 memory::dims padding_dims = {paddings[0 ], paddings[1 ]};
473482
@@ -476,10 +485,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
476485 bias, dst, stride_dims, padding_dims, padding_dims,
477486 mkldnn::padding_kind::zero);
478487
479- mkldnn::primitive_attr conv_attr;
480- if (fuse_relu) {
481- conv_attr = AddRelu ();
482- }
488+ mkldnn::primitive_attr conv_attr = CreatePostOps (fuse_relu, fuse_eltwise);
483489
484490 auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc (
485491 conv_desc, conv_attr, engine);
0 commit comments