Commit 8cbefd1

Sand3r- authored and Superjomn committed
Fuse Conv+BN+SkipConnectionAdd+ReLU with transpiler temporarily (#13350)
1 parent f00081a commit 8cbefd1
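
The fusion is wired up through the inference transpiler rather than a dedicated graph pass (hence "temporarily" in the title). A minimal usage sketch follows; it assumes the transpiler class is exposed as fluid.InferenceTranspiler, that the MKL-DNN branch of transpile() is gated by the FLAGS_use_mkldnn environment flag, and that the model directory name is illustrative:

# Hedged sketch: apply the transpiler-based fusion to a saved inference program.
import os
os.environ['FLAGS_use_mkldnn'] = 'True'  # assumed gate for the use_mkldnn branch

import paddle.fluid as fluid

place = fluid.CPUPlace()
exe = fluid.Executor(place)

# Load a saved inference model (directory name is illustrative).
[program, feed_names, fetch_targets] = fluid.io.load_inference_model(
    'resnet_inference_model', exe)

t = fluid.InferenceTranspiler()  # assumed export name
t.transpile(program, place)      # with MKL-DNN on, runs the conv+eltwise(+relu) fusion passes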

3 files changed: +87 −27 lines changed

paddle/fluid/operators/conv_mkldnn_op.cc

Lines changed: 31 additions & 25 deletions
@@ -300,6 +300,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
     std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
     bool fuse_relu = ctx.Attr<bool>("fuse_relu");
+    bool fuse_eltwise = ctx.Attr<bool>("fuse_eltwise");
     int groups = ctx.Attr<int>("groups");
 
     // TODO: add support for dilation
@@ -366,12 +367,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       bias_tz = paddle::framework::vectorize2int(bias->dims());
       auto bias_md = platform::MKLDNNMemDesc(
           bias_tz, platform::MKLDNNGetDataType<T>(), memory::format::x);
-      conv_pd =
-          ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md, strides,
-                               paddings, mkldnn_engine, fuse_relu);
+      conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md,
+                                     strides, paddings, mkldnn_engine,
+                                     fuse_relu, fuse_eltwise);
     } else {
-      conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides,
-                                     paddings, mkldnn_engine, fuse_relu);
+      conv_pd =
+          ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings,
+                               mkldnn_engine, fuse_relu, fuse_eltwise);
     }
     // Save conv_pd/src_memory/weights_memory for backward pass
     dev_ctx.SetBlob(key_conv_pd, conv_pd);
@@ -421,16 +423,26 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
   }
 
  private:
-  mkldnn::primitive_attr AddRelu() const {
-    // Fusion with ReLU layer is executed through the PostOps feature. Create a
-    // PostOps object and configure it to execute an eltwise relu operation.
+  mkldnn::primitive_attr CreatePostOps(bool fuse_relu,
+                                       bool fuse_eltwise) const {
     mkldnn::primitive_attr conv_attr;
-    constexpr float scale = 1.0f;
-    constexpr float negative_slope = 0.0f;
-    constexpr float placeholder = 0.0f;
     mkldnn::post_ops post_operations;
-    post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu,
-                                   negative_slope, placeholder);
+    // Fusion with Elementwise layer relies on adding a sum post-operation with
+    // the scale parameter. It is assumed that when fuse_eltwise is true, the
+    // Output tensor contains the data coming from residual connection. The
+    // result of this post_op is: Output = scale * Output + Conv_Out.
+    if (fuse_eltwise) {
+      post_operations.append_sum(1.0f);
+    }
+    // Fusion with ReLU layer is executed through the PostOps feature. Create a
+    // PostOps object and configure it to execute an eltwise relu operation.
+    if (fuse_relu) {
+      constexpr float scale = 1.0f;
+      constexpr float negative_slope = 0.0f;
+      constexpr float placeholder = 0.0f;
+      post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu,
+                                     negative_slope, placeholder);
+    }
     conv_attr.set_post_ops(post_operations);
     return conv_attr;
   }
@@ -439,8 +451,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
   ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights,
                        const memory::desc& dst, const std::vector<int>& strides,
                        const std::vector<int>& paddings,
-                       const mkldnn::engine& engine,
-                       const bool fuse_relu) const {
+                       const mkldnn::engine& engine, const bool fuse_relu,
+                       const bool fuse_eltwise) const {
     memory::dims stride_dims = {strides[0], strides[1]};
     memory::dims padding_dims = {paddings[0], paddings[1]};
 
@@ -449,10 +461,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
         dst, stride_dims, padding_dims, padding_dims,
         mkldnn::padding_kind::zero);
 
-    mkldnn::primitive_attr conv_attr;
-    if (fuse_relu) {
-      conv_attr = AddRelu();
-    }
+    mkldnn::primitive_attr conv_attr = CreatePostOps(fuse_relu, fuse_eltwise);
 
     auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
         conv_desc, conv_attr, engine);
@@ -466,8 +475,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
                        const memory::desc& bias, const memory::desc& dst,
                        const std::vector<int>& strides,
                        const std::vector<int>& paddings,
-                       const mkldnn::engine& engine,
-                       const bool fuse_relu) const {
+                       const mkldnn::engine& engine, const bool fuse_relu,
+                       const bool fuse_eltwise) const {
     memory::dims stride_dims = {strides[0], strides[1]};
     memory::dims padding_dims = {paddings[0], paddings[1]};
 
@@ -476,10 +485,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
         bias, dst, stride_dims, padding_dims, padding_dims,
         mkldnn::padding_kind::zero);
 
-    mkldnn::primitive_attr conv_attr;
-    if (fuse_relu) {
-      conv_attr = AddRelu();
-    }
+    mkldnn::primitive_attr conv_attr = CreatePostOps(fuse_relu, fuse_eltwise);
 
     auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
         conv_desc, conv_attr, engine);
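
For reference, the comment in CreatePostOps above describes the fused computation as Output = scale * Output + Conv_Out, followed by an optional ReLU. A minimal NumPy sketch of that functional behaviour (an illustration of the math only, not the MKL-DNN code path; all names are illustrative):

import numpy as np

def fused_conv_postops_reference(conv_out, residual, fuse_eltwise, fuse_relu, scale=1.0):
    """Functional reference for the post-ops built by CreatePostOps()."""
    out = conv_out
    if fuse_eltwise:
        # sum post-op: the Output tensor already holds the residual data,
        # and the convolution result is accumulated into it.
        out = scale * residual + conv_out
    if fuse_relu:
        # eltwise_relu post-op with negative_slope == 0 is a plain ReLU.
        out = np.maximum(out, 0.0)
    return out

# ResNet-style residual tail: relu(conv(x) + shortcut)
conv_out = np.random.randn(1, 8, 4, 4).astype('float32')
shortcut = np.random.randn(1, 8, 4, 4).astype('float32')
y = fused_conv_postops_reference(conv_out, shortcut, fuse_eltwise=True, fuse_relu=True)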

paddle/fluid/operators/conv_op.cc

Lines changed: 5 additions & 0 deletions
@@ -164,6 +164,11 @@ void Conv2DOpMaker::Make() {
       .SetDefault(false);
   AddAttr<bool>("fuse_relu", "(bool, default false) Only used in mkldnn kernel")
       .SetDefault(false);
+  AddAttr<bool>("fuse_eltwise",
+                "(bool, default false) Only used in mkldnn kernel. Used "
+                "whenever convolution output is connected via skip connection "
+                "to a previous layer.")
+      .SetDefault(false);
   AddAttr<std::string>(
       "data_format",
       "(string, default NCHW) Only used in "
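
Once the transpiler pass below has run, this attribute should be present and set on the fused conv2d ops. A small hedged sketch for verifying that on a transpiled Program; op.has_attr()/op.attr() as the Python accessors are assumptions about the fluid Operator API, not something shown in this diff:

# Hedged sketch: count conv2d ops that carry fuse_eltwise == True after transpiling.
def count_fused_convs(program):
    fused = 0
    for block in program.blocks:
        for op in block.ops:
            # op.has_attr / op.attr are assumed accessors on the fluid Operator wrapper
            if op.type == 'conv2d' and op.has_attr('fuse_eltwise') and op.attr('fuse_eltwise'):
                fused += 1
    return fused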

python/paddle/fluid/transpiler/inference_transpiler.py

Lines changed: 51 additions & 2 deletions
@@ -65,8 +65,43 @@ def transpile(self, program, place, scope=None):
         if use_mkldnn:
             self._fuse_conv_bias_mkldnn(program)
             self._fuse_conv_relu_mkldnn(program)
+            self._fuse_conv_eltwise_mkldnn(program)
+            self._fuse_conv_relu_mkldnn(
+                program)  # ResNet residual block merging
             self._fuse_bn_relu_mkldnn(program)
 
+    def _fuse_conv_eltwise_mkldnn(self, program):
+        '''
+        Transpile the program fusing elementwise_add into conv for MKLDNN
+        program. Elementwise add following convolution OP can be fused by adding
+        'fuse_eltwise' attribute to convolution OP and replacing its output
+        Tensor with second parameter of elementwise_add.
+        The result of fuse is:
+            - before:
+                - conv->elementwise_add->any_other_op
+            - after:
+                - conv->any_other_op
+        :param program: program to transpile
+        :type program: Program
+        '''
+        self.block = program.block(0)
+
+        i = 0
+        while i < len(self.block.ops):
+            current_op = self.block.ops[i]
+            if current_op.type in ['conv2d']:
+                next_op = self.block.ops[i + 1]
+                if next_op.type == 'elementwise_add':
+                    self._fuse_conv_eltwise(current_op, next_op)
+                    self.block._remove_op(i + 1)  # Remove elementwise_add
+            i = i + 1
+        self._adjust_input()
+        self._remove_unused_var()
+        # TODO(luotao): use clone() method to flush the program.desc in force,
+        # since some large program.desc will not be flushed immediately.
+        # And a better solution will be considered later.
+        program = program.clone()
+
     def _fuse_conv_relu_mkldnn(self, program):
         '''
         Transpile the program by fused relu activation for MKLDNN program.
@@ -88,9 +123,9 @@ def _fuse_conv_relu_mkldnn(self, program):
             if current_op.type in ['conv2d']:
                 next_op = self.block.ops[i + 1]
                 if next_op.type == 'relu':
-                    # modify conv OP to include relu
+                    # modify bnorm OP to include relu
                     current_op.set_attr("fuse_relu", True)
-                    # remove conv OP
+                    # remove relu OP
                     self.block._remove_op(i + 1)
             i = i + 1
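
The _fuse_conv_eltwise_mkldnn pass above, followed by the second _fuse_conv_relu_mkldnn call, is meant to collapse a ResNet-style conv2d -> elementwise_add -> relu tail into a single conv2d carrying fuse_eltwise and fuse_relu. A toy, framework-free sketch of that op-sequence rewrite (plain dicts instead of fluid Operators; the real passes also rewire tensor names, which this toy skips):

import copy

def fuse(ops):
    """Toy model: fold a following elementwise_add / relu into the preceding conv2d."""
    ops = copy.deepcopy(ops)
    out = []
    for op in ops:
        prev = out[-1] if out else None
        if prev and prev['type'] == 'conv2d' and op['type'] == 'elementwise_add':
            prev['attrs']['fuse_eltwise'] = True   # conv accumulates the shortcut data
        elif prev and prev['type'] == 'conv2d' and op['type'] == 'relu':
            prev['attrs']['fuse_relu'] = True      # relu becomes an MKL-DNN post-op
        else:
            out.append(op)
    return out

before = [{'type': 'conv2d', 'attrs': {}},
          {'type': 'elementwise_add', 'attrs': {}},
          {'type': 'relu', 'attrs': {}},
          {'type': 'pool2d', 'attrs': {}}]
after = fuse(before)
# after == [{'type': 'conv2d', 'attrs': {'fuse_eltwise': True, 'fuse_relu': True}},
#           {'type': 'pool2d', 'attrs': {}}]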

@@ -409,6 +444,20 @@ def _fuse_conv_bias(self, index, conv_op, elementwise_add_op):
             outputs={"Output": out_var},
             attrs=attrs)
 
+    def _fuse_conv_eltwise(self, conv_op, eltwise_op):
+        '''
+        fuse the conv op with elementwise_add
+
+        :param conv_op: convolution operator
+        :type conv_op: Operator
+        :param eltwise_op: operator adding data from skip connection
+        :type eltwise_op: Operator
+        '''
+
+        conv_op.set_attr("fuse_eltwise", True)
+        self.input_map[conv_op.output("Output")[0]] = eltwise_op.input("Y")[0]
+        self.input_map[eltwise_op.output("Out")[0]] = eltwise_op.input("Y")[0]
+
     def _adjust_input(self):
         for i in range(len(self.block.ops)):
             current_op = self.block.ops[i]
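
The two input_map entries set in _fuse_conv_eltwise redirect every later reference to the old conv output or the old elementwise_add output to the add's Y input, i.e. the tensor holding the shortcut data that the fused conv is expected to accumulate into; _adjust_input() (whose body is truncated in the context above) then applies that renaming to the inputs of the remaining ops. A minimal, framework-free sketch of that renaming step, using plain dicts instead of fluid Operators and illustrative names:

# Toy model of the renaming driven by input_map.
input_map = {'conv_out': 'shortcut', 'sum_out': 'shortcut'}  # set by _fuse_conv_eltwise

relu_inputs = {'X': 'sum_out'}  # the op that used to consume the elementwise_add result
relu_inputs = {slot: input_map.get(name, name) for slot, name in relu_inputs.items()}
assert relu_inputs == {'X': 'shortcut'}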
