Commit dfbd1cc
Merge pull request #13209 from Sand3r-/mgallus/conv-relu-fuse
[MKLDNN] Fuse Conv+BatchNorm + ReLU
2 parents 2ed7982 + 5d34ef6 commit dfbd1cc

File tree

3 files changed: +75, -12 lines changed

paddle/fluid/operators/conv_mkldnn_op.cc

Lines changed: 37 additions & 9 deletions
@@ -299,6 +299,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     std::vector<int> strides = ctx.Attr<std::vector<int>>("strides");
     std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
     std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
+    bool fuse_relu = ctx.Attr<bool>("fuse_relu");
     int groups = ctx.Attr<int>("groups");
 
     // TODO(pzelazko-intel) add support for group convolution and dilation
@@ -351,11 +352,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       bias_tz = paddle::framework::vectorize2int(bias->dims());
       auto bias_md = platform::MKLDNNMemDesc(
           bias_tz, platform::MKLDNNGetDataType<T>(), memory::format::x);
-      conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md,
-                                     strides, paddings, mkldnn_engine);
+      conv_pd =
+          ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md, strides,
+                               paddings, mkldnn_engine, fuse_relu);
     } else {
       conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides,
-                                     paddings, mkldnn_engine);
+                                     paddings, mkldnn_engine, fuse_relu);
     }
     // Save conv_pd/src_memory/weights_memory for backward pass
     dev_ctx.SetBlob(key_conv_pd, conv_pd);
@@ -405,11 +407,26 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
   }
 
  private:
+  mkldnn::primitive_attr AddRelu() const {
+    // Fusion with ReLU layer is executed through the PostOps feature. Create a
+    // PostOps object and configure it to execute an eltwise relu operation.
+    mkldnn::primitive_attr conv_attr;
+    constexpr float scale = 1.0f;
+    constexpr float negative_slope = 0.0f;
+    constexpr float placeholder = 0.0f;
+    mkldnn::post_ops post_operations;
+    post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu,
+                                   negative_slope, placeholder);
+    conv_attr.set_post_ops(post_operations);
+    return conv_attr;
+  }
+
   std::unique_ptr<mkldnn::convolution_forward::primitive_desc>
   ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights,
                        const memory::desc& dst, const std::vector<int>& strides,
                        const std::vector<int>& paddings,
-                       const mkldnn::engine& engine) const {
+                       const mkldnn::engine& engine,
+                       const bool fuse_relu) const {
     memory::dims stride_dims = {strides[0], strides[1]};
     memory::dims padding_dims = {paddings[0], paddings[1]};
 
@@ -418,8 +435,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
         dst, stride_dims, padding_dims, padding_dims,
         mkldnn::padding_kind::zero);
 
-    auto p_conv_pd =
-        new mkldnn::convolution_forward::primitive_desc(conv_desc, engine);
+    mkldnn::primitive_attr conv_attr;
+    if (fuse_relu) {
+      conv_attr = AddRelu();
+    }
+
+    auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
+        conv_desc, conv_attr, engine);
 
     return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>(
         p_conv_pd);
@@ -430,7 +452,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
                        const memory::desc& bias, const memory::desc& dst,
                        const std::vector<int>& strides,
                        const std::vector<int>& paddings,
-                       const mkldnn::engine& engine) const {
+                       const mkldnn::engine& engine,
+                       const bool fuse_relu) const {
     memory::dims stride_dims = {strides[0], strides[1]};
     memory::dims padding_dims = {paddings[0], paddings[1]};
 
@@ -439,8 +462,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
         bias, dst, stride_dims, padding_dims, padding_dims,
         mkldnn::padding_kind::zero);
 
-    auto p_conv_pd =
-        new mkldnn::convolution_forward::primitive_desc(conv_desc, engine);
+    mkldnn::primitive_attr conv_attr;
+    if (fuse_relu) {
+      conv_attr = AddRelu();
+    }
+
+    auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
+        conv_desc, conv_attr, engine);
 
     return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>(
         p_conv_pd);
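
Note that the eltwise_relu post-op attached in AddRelu() uses scale = 1.0 and negative_slope = 0.0, so the fused primitive simply applies max(0, x) to the convolution output in the same pass instead of launching a separate ReLU kernel. A minimal NumPy sketch of those semantics (relu_post_op is a hypothetical stand-in for the MKLDNN post-op, not part of this commit):

import numpy as np

def relu_post_op(conv_out, scale=1.0, negative_slope=0.0):
    # eltwise_relu semantics: scale * (x if x > 0 else negative_slope * x)
    return scale * np.where(conv_out > 0, conv_out, negative_slope * conv_out)

# With the defaults used above, the post-op reduces to max(0, x).
conv_out = np.array([[-1.5, 0.0], [2.0, -0.25]], dtype=np.float32)
assert np.array_equal(relu_post_op(conv_out), np.maximum(conv_out, 0))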

paddle/fluid/operators/conv_op.cc

Lines changed: 2 additions & 0 deletions
@@ -162,6 +162,8 @@ void Conv2DOpMaker::Make() {
   AddAttr<bool>("use_mkldnn",
                 "(bool, default false) Only used in mkldnn kernel")
       .SetDefault(false);
+  AddAttr<bool>("fuse_relu", "(bool, default false) Only used in mkldnn kernel")
+      .SetDefault(false);
   AddAttr<std::string>(
       "data_format",
       "(string, default NCHW) Only used in "
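
Because the new attribute defaults to false, existing conv2d programs are unaffected; only the transpiler pass below flips it. A hypothetical check of which ops were fused after transpilation (inference_program is an assumed, already-transpiled program; Operator.attr is the fluid attribute accessor):

for op in inference_program.block(0).ops:
    if op.type == 'conv2d':
        print(op.attr('fuse_relu'))  # True where a trailing relu was folded in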

python/paddle/fluid/transpiler/inference_transpiler.py

Lines changed: 36 additions & 3 deletions
@@ -60,12 +60,46 @@ def transpile(self, program, place, scope=None):
         if not isinstance(scope, core.Scope):
             raise TypeError("scope should be as Scope type or None")
         use_mkldnn = bool(os.getenv("FLAGS_use_mkldnn", False))
+
         self._fuse_batch_norm(program, place, scope)
         if use_mkldnn:
-            self._fuse_relu_mkldnn(program)
             self._fuse_conv_bias_mkldnn(program)
+            self._fuse_conv_relu_mkldnn(program)
+            self._fuse_bn_relu_mkldnn(program)
+
+    def _fuse_conv_relu_mkldnn(self, program):
+        '''
+        Transpile the program by fusing relu activation for an MKLDNN program.
+        A relu activation following a convolution OP can be fused by adding
+        the 'fuse_relu' attribute to the convolution OP.
+        The result of the fuse is:
+        - before:
+          - conv->relu->any_other_op
+        - after:
+          - conv->any_other_op
+        :param program: program to transpile
+        :type program: Program
+        '''
+        self.block = program.block(0)
+
+        i = 0
+        while i < len(self.block.ops) - 1:
+            current_op = self.block.ops[i]
+            if current_op.type in ['conv2d']:
+                next_op = self.block.ops[i + 1]
+                if next_op.type == 'relu':
+                    # modify the conv OP to include relu
+                    current_op.set_attr("fuse_relu", True)
+                    # remove the relu OP
+                    self.block._remove_op(i + 1)
+            i = i + 1
+
+        # TODO(luotao): use clone() method to flush the program.desc in force,
+        # since some large program.desc will not be flushed immediately.
+        # And a better solution will be considered later.
+        program = program.clone()
 
-    def _fuse_relu_mkldnn(self, program):
+    def _fuse_bn_relu_mkldnn(self, program):
         '''
         Transpile the program by fused relu activation for MKLDNN program.
 
@@ -159,7 +193,6 @@ def _fuse_conv_bias_mkldnn(self, program):
                     self._fuse_conv_bias(i, current_op, next_op)
                     self.block._remove_op(i + 1)  # Remove old conv
                     self.block._remove_op(i + 1)  # Remove elementwise_add
-                    i = i + 1
             i = i + 1
 
         self._remove_unused_var()
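
For reference, the new pass is reached through the public transpile() entry point when the FLAGS_use_mkldnn environment flag is set. A minimal usage sketch (inference_program is an assumed, already-built inference program; InferenceTranspiler is the class defined in this file, its import path assumed):

import os
import paddle.fluid as fluid

os.environ['FLAGS_use_mkldnn'] = '1'  # read by transpile() via os.getenv
place = fluid.CPUPlace()
t = fluid.transpiler.InferenceTranspiler()
# runs _fuse_batch_norm, then _fuse_conv_bias_mkldnn,
# _fuse_conv_relu_mkldnn and _fuse_bn_relu_mkldnn
t.transpile(inference_program, place)

This split keeps graph rewriting out of the C++ kernel: the Python pass only sets the fuse_relu attribute and deletes the relu op, while the MKLDNN kernel decides how to honor the attribute.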
