
Commit 9a15c92

pzelazko-intel authored and luotao1 committed
bnorm+relu fuse for mkldnn (inference) (#11434)
* bnorm+relu fuse for mkldnn
* separate fuse_relu function
* bug fix
* proper while range in inference_transpiler
* description fix
* review fix
* review fix
* unit test for fwd batch norm+relu MKLDNN fuse
1 parent ce5f1e0 commit 9a15c92
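
For orientation, a minimal end-to-end sketch of how the pieces of this commit are meant to fit together, assuming the Fluid API of this era; the toy network is illustrative, while FLAGS_use_mkldnn, fluid.InferenceTranspiler and the fuse_with_relu attribute come from the diffs below.

import os
os.environ["FLAGS_use_mkldnn"] = "1"   # fuse_relu_mkldnn() is a no-op without this flag

import paddle.fluid as fluid

# A toy inference program where a relu directly follows a batch_norm.
data = fluid.layers.data(name='image', shape=[3, 32, 32], dtype='float32')
bn = fluid.layers.batch_norm(input=data, is_test=True, use_mkldnn=True)
out = fluid.layers.relu(bn)

# The transpiler pass added in this commit marks the batch_norm op with
# fuse_with_relu=True and removes the standalone relu op.
t = fluid.InferenceTranspiler()
t.transpile(fluid.default_main_program(), fluid.CPUPlace())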

File tree

8 files changed: +115 -28 lines


benchmark/fluid/args.py

Lines changed: 4 additions & 0 deletions
@@ -122,5 +122,9 @@ def parse_args():
         type=str,
         default="",
         help='Directory that contains all the training recordio files.')
+    parser.add_argument(
+        '--use_inference_transpiler',
+        action='store_true',
+        help='If set, uses inference transpiler to optimize the program.')
     args = parser.parse_args()
     return args

benchmark/fluid/fluid_benchmark.py

File mode changed from 100644 to 100755
Lines changed: 5 additions & 0 deletions
@@ -131,6 +131,11 @@ def train(avg_loss, infer_prog, optimizer, train_reader, test_reader, batch_acc,
     exe = fluid.Executor(place)
     exe.run(startup_prog)

+    # Use inference_transpiler to speedup
+    if args.use_inference_transpiler:
+        t = fluid.InferenceTranspiler()
+        t.transpile(infer_prog, place)
+
     if not args.use_reader_op:
         feed_var_list = [
             var for var in train_prog.global_block().vars.itervalues()

paddle/fluid/operators/batch_norm_mkldnn_op.cc

Lines changed: 2 additions & 0 deletions
@@ -66,6 +66,7 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     const float epsilon = ctx.Attr<float>("epsilon");
     const float momentum = ctx.Attr<float>("momentum");
     const bool is_test = ctx.Attr<bool>("is_test");
+    const bool fuse_with_relu = ctx.Attr<bool>("fuse_with_relu");

     const auto *x = ctx.Input<Tensor>("X");
     const auto *mean = ctx.Input<Tensor>("Mean");
@@ -111,6 +112,7 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {

     unsigned flags = mkldnn::use_scale_shift;
     if (is_test) flags |= mkldnn::use_global_stats;
+    if (fuse_with_relu) flags |= mkldnn::fuse_bn_relu;

     // create mkldnn memory from input x tensor
     auto src_memory =

paddle/fluid/operators/batch_norm_op.cc

Lines changed: 3 additions & 0 deletions
@@ -155,6 +155,9 @@ class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
     AddAttr<bool>("use_mkldnn",
                   "(bool, default false) Only used in mkldnn kernel")
         .SetDefault(false);
+    AddAttr<bool>("fuse_with_relu",
+                  "(bool, default false) Only used in mkldnn kernel")
+        .SetDefault(false);
     AddComment(R"DOC(
 Batch Normalization.


python/paddle/fluid/layers/nn.py

Lines changed: 5 additions & 2 deletions
@@ -1993,7 +1993,8 @@ def batch_norm(input,
                name=None,
                moving_mean_name=None,
                moving_variance_name=None,
-               do_model_average_for_mean_and_var=False):
+               do_model_average_for_mean_and_var=False,
+               fuse_with_relu=False):
     """
     **Batch Normalization Layer**

@@ -2036,6 +2037,7 @@ def batch_norm(input,
         moving_mean_name(string, Default None): The name of moving_mean which store the global Mean.
         moving_variance_name(string, Default None): The name of the moving_variance which store the global Variance.
         do_model_average_for_mean_and_var(bool, Default False): Do model average for mean and variance or not.
+        fuse_with_relu (bool): if True, this OP performs relu after batch norm.

     Returns:
         Variable: A tensor variable which is the result after applying batch normalization on the input.
@@ -2121,7 +2123,8 @@ def batch_norm(input,
             "momentum": momentum,
             "epsilon": epsilon,
             "is_test": is_test,
-            "use_mkldnn": use_mkldnn
+            "use_mkldnn": use_mkldnn,
+            "fuse_with_relu": fuse_with_relu
         })

     return helper.append_activation(batch_norm_out)
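
For illustration, a minimal sketch of the new argument in use, assuming the Fluid layers API of this commit's era; the toy layers and shapes are made up, only use_mkldnn and fuse_with_relu come from the signature documented above.

import paddle.fluid as fluid

data = fluid.layers.data(name='image', shape=[3, 32, 32], dtype='float32')
conv = fluid.layers.conv2d(input=data, num_filters=16, filter_size=3)
# Ask the MKLDNN batch_norm kernel to apply relu itself, so no separate
# relu op is needed after it at inference time.
bn = fluid.layers.batch_norm(
    input=conv, is_test=True, use_mkldnn=True, fuse_with_relu=True)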

python/paddle/fluid/tests/unittests/test_batch_norm_mkldnn_op.py

Lines changed: 12 additions & 0 deletions
@@ -52,5 +52,17 @@ def test_check_output(self):
         self.check_with_place(place, data_format, self.dtype, [2, 3, 4, 5])


+class TestMKLDNNBatchNormOpWithReluInference(TestBatchNormOpInference):
+    def init_kernel_type(self):
+        self.use_mkldnn = True
+        self.fuse_with_relu = True
+
+    def test_check_output(self):
+        place = core.CPUPlace()
+        data_format = "NCHW"
+
+        self.check_with_place(place, data_format, self.dtype, [2, 3, 4, 5])
+
+
 if __name__ == '__main__':
     unittest.main()

python/paddle/fluid/tests/unittests/test_batch_norm_op.py

Lines changed: 10 additions & 1 deletion
@@ -159,6 +159,7 @@ class TestBatchNormOpInference(unittest.TestCase):
     def setUp(self):
         self.dtype = np.float32
         self.use_mkldnn = False
+        self.fuse_with_relu = False
         self.init_kernel_type()

     def __assert_close(self, tensor, np_array, msg, atol=1e-4):
@@ -180,6 +181,8 @@ def check_with_place(self, place, data_layout, dtype, shape):
         scale_shape = [c]

         x_val = np.random.random_sample(x_shape).astype(dtype)
+        # generate some negative values to test case with relu fused
+        x_val = x_val - 0.5
         scale_val = np.random.random_sample(scale_shape).astype(np.float32)
         bias_val = np.random.random_sample(scale_shape).astype(np.float32)

@@ -188,6 +191,8 @@ def check_with_place(self, place, data_layout, dtype, shape):

         y_out = _reference_testing(x_val, scale_val, bias_val, mean, variance,
                                    epsilon, data_layout).astype(dtype)
+        if self.fuse_with_relu:
+            y_out = np.maximum(y_out, 0)

         scope = core.Scope()

@@ -233,6 +238,7 @@ def check_with_place(self, place, data_layout, dtype, shape):
             is_test=True,
             data_layout=data_layout,
             use_mkldnn=self.use_mkldnn,
+            fuse_with_relu=self.fuse_with_relu,
             epsilon=epsilon)

         batch_norm_op.run(scope, place)
@@ -265,6 +271,7 @@ class TestFP16BatchNormOpInference(TestBatchNormOpInference):
     def setUp(self):
         self.dtype = np.float16
         self.use_mkldnn = False
+        self.fuse_with_relu = False
         self.init_kernel_type()

     def test_check_output(self):
@@ -284,6 +291,7 @@ def test_check_output(self):
 class TestBatchNormOpTraining(unittest.TestCase):
     def setUp(self):
         self.use_mkldnn = False
+        self.fuse_with_relu = False
         self.data_formats = ["NCHW", "NHWC"]
         self.init_kernel_type()

@@ -367,7 +375,8 @@ def test_with_place(place, data_layout, shape):
                 "epsilon": epsilon,
                 "is_test": False,
                 "data_layout": data_layout,
-                "use_mkldnn": self.use_mkldnn
+                "use_mkldnn": self.use_mkldnn,
+                "fuse_with_relu": self.fuse_with_relu
             })
         block.create_var(name='y@GRAD', dtype='float32', shape=y.shape)

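To make the checked behaviour concrete, here is a small NumPy sketch of the reference the test builds: standard inference-mode batch norm over NCHW data, followed by the same np.maximum(y, 0) clamp the test applies when fuse_with_relu is set. The helper name and shapes are local to the sketch, not part of the test file.

import numpy as np

def bn_relu_reference(x, scale, bias, mean, variance, epsilon=1e-5):
    # Inference-mode batch norm, broadcasting per-channel stats over NCHW.
    c = x.shape[1]
    shape = (1, c, 1, 1)
    y = (x - mean.reshape(shape)) / np.sqrt(variance.reshape(shape) + epsilon)
    y = y * scale.reshape(shape) + bias.reshape(shape)
    # The fused-relu variant only differs by this clamp.
    return np.maximum(y, 0)

x = np.random.random_sample((2, 3, 4, 5)).astype(np.float32) - 0.5
scale = np.random.random_sample(3).astype(np.float32)
bias = np.random.random_sample(3).astype(np.float32)
mean, variance = np.zeros(3, np.float32), np.ones(3, np.float32)
assert (bn_relu_reference(x, scale, bias, mean, variance) >= 0).all()
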
python/paddle/fluid/transpiler/inference_transpiler.py

Lines changed: 74 additions & 25 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import os
 import numpy as np
 from .. import core
 from ..framework import Program
@@ -20,12 +21,15 @@

 class InferenceTranspiler:
     '''
-    Convert the fluid program to optimized inference program.
-
-    There are several optimizations, only fuse batch normalization is supported now.
+    Convert the fluid program to optimized inference program.
+
+    There are several optimizations:
+
+    - fuse convolution and batch normalization
+    - fuse batch normalization and relu (MKLDNN only)

     Examples:
-
+
     .. code-block:: python

         # As InferenceTranspiler will modify the original program,
@@ -54,19 +58,64 @@ def transpile(self, program, place, scope=None):
         if not isinstance(scope, core.Scope):
             raise TypeError("scope should be as Scope type or None")
         self.fuse_batch_norm(program, place, scope)
+        self.fuse_relu_mkldnn(program)
+
+    def fuse_relu_mkldnn(self, program):
+        '''
+        Transpile the program by fused relu activation for MKLDNN program.
+
+        Relu activation following batch norm OP can be fused by adding
+        :math:`fuse_with_relu` attribute to batch norm OP.
+
+        The result of fuse is:
+
+        - before:
+
+          - batch_norm->relu->any_other_op
+
+        - after:
+
+          - batch_norm->any_other_op
+
+        :param program: program to transpile
+        :type program: Program
+        '''
+        use_mkldnn = bool(os.getenv("FLAGS_use_mkldnn", False))
+        if not use_mkldnn:
+            return
+
+        self.block = program.block(0)
+
+        i = 0
+        while i < len(self.block.ops) - 1:
+            current_op = self.block.ops[i]
+            if current_op.type in ['batch_norm']:
+                next_op = self.block.ops[i + 1]
+                if next_op.type == 'relu':
+                    # modify bnorm OP to include relu
+                    current_op.set_attr("fuse_with_relu", True)
+                    # remove relu OP
+                    self.block.remove_op(i + 1)
+            i = i + 1
+
+        self._remove_unused_var()
+        # TODO(luotao): use clone() method to flush the program.desc in force,
+        # since some large program.desc will not be flushed immediately.
+        # And a better solution will be considered later.
+        program = program.clone()

     def fuse_batch_norm(self, program, place, scope):
         '''
         Transpile the program by fused batch normalization.
-
-        The batch normalization followed the convolution or fully connected layer
-        can be integrated with them. Doing so will give us a forward acceleration,
+
+        The batch normalization followed the convolution or fully connected layer
+        can be integrated with them. Doing so will give us a forward acceleration,
         especially in environments like mobile or embedded.
-
+
         For input :math:`X`:

-        - Conv process:        :math:`X = input * W + bias`
-        - Batch norm process:  :math:`X' = (X - mean) / std`
+        - Conv process:        :math:`X = input * W + bias`
+        - Batch norm process:  :math:`X' = (X - mean) / std`
         - Scale Process:       :math:`Y = a * X' + b`

         After fuse into one operation:
@@ -76,17 +125,17 @@ def fuse_batch_norm(self, program, place, scope):
            Y &= (input * W + bias - mean) / std * a + b \\\\
              &= input * a * W / std + ((bias - mean) / std * a + b)

-        The operator transformation is:
+        The operator transformation is:

         - before:

          - conv->batch_norm->any_other_op (bias == 0)
          - conv->elementwise_add->batch_norm->any_other_op (bias != 0)
-
-        - after:
+
+        - after:

          - conv->elementwise_add->any_other_op
-
+
         The transpile stages are:

         1. insert elementwise_add op when bias == 0.
@@ -99,20 +148,20 @@ def fuse_batch_norm(self, program, place, scope):
            program (Program): program to transpile
            place (Place): inference place
            scope (Scope): inference Scope
-
+
        '''
        self.scope = scope
        self.place = place
        self.block = program.block(0)
-        self.input_map = {}  # store the input names should be adjusted
+        self.input_map = {}  # store the input names should be adjusted

        i = 0
-        while i < len(self.block.ops):
+        while i < len(self.block.ops) - 2:
            current_op = self.block.ops[i]
            # TODO(luotao1): consider only conv2d now. fc would be delt later.
            if current_op.type in ['conv2d']:
-                # TODO(luotao1): consider single chain network now.
-                # For branch network, we counldn't use block.ops[i + 1] as
+                # TODO(luotao1): consider single chain network now.
+                # For branch network, we counldn't use block.ops[i + 1] as
                # the judgment condition.
                next_op = self.block.ops[i + 1]
                # conv2d without bias
@@ -137,17 +186,17 @@ def fuse_batch_norm(self, program, place, scope):

        self._adjust_input()
        self._remove_unused_var()
-        # TODO(luotao): use clone() method to flush the program.desc in force,
-        # since some large program.desc will not be flushed immediately.
+        # TODO(luotao): use clone() method to flush the program.desc in force,
+        # since some large program.desc will not be flushed immediately.
        # And a better solution will be considered later.
        program = program.clone()

    # ====================== private transpiler functions =====================
    def _insert_bias_op(self, index, current_op, bn_op):
        '''
-        Construct elementwise_add operator for adding bias
+        Construct elementwise_add operator for adding bias
        and insert it into program.
-
+
        :param index: insert location of bias_op
        :type index: Int
        :param current_op: current operator (conv or fc)
@@ -175,14 +224,14 @@ def _insert_bias_op(self, index, current_op, bn_op):
    def _fuse_param(self, current_op, bn_op, bias_op, with_bias):
        '''
        fuse the batch_norm_op' parameters to current_op (conv or fc)
-
+
        :param current_op: current operator (conv or fc)
        :type current_op: Operator
        :param bn_op: batch norm operator
        :type bn_op: Operator
        :param bias_op: elementwise_add operator for adding bias
        :type bias_op: Operator
-        :param with_bias: If current operator has bias, with_bias = 1; otherwise 0.
+        :param with_bias: If current operator has bias, with_bias = 1; otherwise 0.
        :type with_bias: Int
        '''

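As a quick check of the folding identity stated in the fuse_batch_norm docstring above, Y = (input * W + bias - mean) / std * a + b = input * a * W / std + ((bias - mean) / std * a + b), here is a throwaway NumPy verification with arbitrary values; all names are local to the sketch.

import numpy as np

rng = np.random.RandomState(0)
inp = rng.randn(8)                     # conv input (one channel, flattened)
W = rng.randn()                        # conv weight
bias, mean = rng.randn(), rng.randn()  # conv bias, batch-norm mean
std = abs(rng.randn()) + 1e-3          # batch-norm std (kept positive)
a, b = rng.randn(), rng.randn()        # batch-norm scale and shift

y_chain = ((inp * W + bias) - mean) / std * a + b              # conv -> bn -> scale
y_fused = inp * (a * W / std) + ((bias - mean) / std * a + b)  # folded into one op

assert np.allclose(y_chain, y_fused)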