Skip to content

Commit 16b1beb

Browse files
authored
Merge pull request #13486 from sfraczek/sfraczek/conv-bn-fuse-pass
Sfraczek/conv bn fuse pass
2 parents 5d5587f + 3fcca40 commit 16b1beb

File tree

6 files changed

+532
-9
lines changed

6 files changed

+532
-9
lines changed

paddle/fluid/framework/ir/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
pass_library(fc_lstm_fuse_pass inference)
pass_library(embedding_fc_lstm_fuse_pass inference)
pass_library(fc_gru_fuse_pass inference)
pass_library(seq_concat_fc_fuse_pass inference)
# New in this change: conv2d + batch_norm fuse pass for inference.
pass_library(conv_bn_fuse_pass inference)

cc_library(fuse_elewise_add_act_pass SRCS fuse_elewise_add_act_pass.cc DEPS pass graph_pattern_detector )

Lines changed: 327 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,327 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/framework/ir/conv_bn_fuse_pass.h"
16+
#include <functional>
17+
#include <string>
18+
#include <vector>
19+
#include "paddle/fluid/framework/lod_tensor.h"
20+
#include "paddle/fluid/operators/math/cpu_vec.h"
21+
#include "paddle/fluid/platform/enforce.h"
22+
23+
namespace paddle {
24+
namespace framework {
25+
namespace ir {
26+
27+
// Binds every named node of the ConvBN pattern to a local ir::Node* variable
// of the same name. Expanded inside the pattern-detector handler lambdas
// below so both passes share one declaration list.
#define GET_CONV_BN_NODES(pattern_name)                                      \
  /* OPERATORS */                                                            \
  GET_IR_NODE_FROM_SUBGRAPH(conv, conv, pattern_name);                       \
  GET_IR_NODE_FROM_SUBGRAPH(batch_norm, batch_norm, pattern_name);           \
  /* CONV inputs */                                                          \
  GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight, pattern_name);         \
  /* CONV outputs */                                                         \
  GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, pattern_name);               \
  /* BN inputs */                                                            \
  GET_IR_NODE_FROM_SUBGRAPH(bn_scale, bn_scale, pattern_name);               \
  GET_IR_NODE_FROM_SUBGRAPH(bn_bias, bn_bias, pattern_name);                 \
  GET_IR_NODE_FROM_SUBGRAPH(bn_mean, bn_mean, pattern_name);                 \
  GET_IR_NODE_FROM_SUBGRAPH(bn_variance, bn_variance, pattern_name);         \
  /* BN outputs */                                                           \
  GET_IR_NODE_FROM_SUBGRAPH(bn_out, bn_out, pattern_name); /* Out */         \
  GET_IR_NODE_FROM_SUBGRAPH(bn_mean_out, bn_mean_out, pattern_name);         \
  GET_IR_NODE_FROM_SUBGRAPH(bn_variance_out, bn_variance_out, pattern_name); \
  GET_IR_NODE_FROM_SUBGRAPH(bn_saved_mean, bn_saved_mean, pattern_name);     \
  GET_IR_NODE_FROM_SUBGRAPH(bn_saved_variance, bn_saved_variance, pattern_name)
46+
47+
template <typename UnaryOperation>
48+
LoDTensor tensor_apply(const LoDTensor& vec, UnaryOperation f) {
49+
LoDTensor vec_y;
50+
vec_y.Resize(vec.dims());
51+
const float* x = vec.data<float>();
52+
float* y = vec_y.mutable_data<float>(platform::CPUPlace());
53+
for (int64_t i = 0; i < vec.numel(); i++) {
54+
y[i] = f(x[i]);
55+
}
56+
return vec_y;
57+
}
58+
59+
void tensor_apply_inplace(LoDTensor* vec, float (*f)(float)) {
60+
float* data = vec->mutable_data<float>(platform::CPUPlace());
61+
for (int64_t i = 0; i < vec->numel(); i++) {
62+
data[i] = f(data[i]);
63+
}
64+
}
65+
66+
template <typename BinaryOperation>
67+
LoDTensor tensor_apply_eltwise(const LoDTensor& vec_a, const LoDTensor& vec_b,
68+
BinaryOperation f) {
69+
PADDLE_ENFORCE_EQ(vec_a.dims(), vec_b.dims());
70+
LoDTensor vec_y;
71+
vec_y.Resize(vec_a.dims());
72+
const float* a = vec_a.data<float>();
73+
const float* b = vec_b.data<float>();
74+
float* y = vec_y.mutable_data<float>(platform::CPUPlace());
75+
for (int64_t i = 0; i < vec_a.numel(); i++) {
76+
y[i] = f(a[i], b[i]);
77+
}
78+
return vec_y;
79+
}
80+
81+
template <typename BinaryOperation>
82+
LoDTensor tensor_apply_eltwise_broadcast(const LoDTensor& vec_a,
83+
const LoDTensor& vec_b,
84+
BinaryOperation f) {
85+
PADDLE_ENFORCE_EQ(vec_a.dims().size(), 2);
86+
PADDLE_ENFORCE_EQ(vec_b.dims().size(), 2);
87+
PADDLE_ENFORCE_EQ(vec_a.dims()[0], vec_b.dims()[0]);
88+
PADDLE_ENFORCE_EQ(vec_b.dims()[1], 1);
89+
LoDTensor vec_y;
90+
vec_y.Resize(vec_a.dims());
91+
const float* a = vec_a.data<float>();
92+
const float* b = vec_b.data<float>();
93+
float* y = vec_y.mutable_data<float>(platform::CPUPlace());
94+
size_t a_height = vec_a.dims()[0];
95+
size_t a_width = vec_a.dims()[1];
96+
for (size_t h = 0; h < a_height; h++) {
97+
for (size_t w = 0; w < a_width; ++w) {
98+
*(y++) = f(*(a++), b[h]);
99+
}
100+
}
101+
return vec_y;
102+
}
103+
104+
// reshape to two dimensions {A, B * C * ...}
105+
void make_tensor_2d(LoDTensor* tensor_to_reshape) {
106+
auto dims_count = tensor_to_reshape->dims().size();
107+
PADDLE_ENFORCE_GT(dims_count, 0);
108+
109+
int size2 = 1;
110+
for (int i = 1; i < dims_count; i++) {
111+
size2 *= tensor_to_reshape->dims()[i];
112+
}
113+
tensor_to_reshape->Resize(make_ddim({tensor_to_reshape->dims()[0], size2}));
114+
}
115+
116+
// Scales the conv weights channel-wise: weights[a, ...] *= tmp[a].
// Both tensors are temporarily viewed as 2-D; the weights keep their
// original shape on return.
void recompute_conv_weights(LoDTensor* weights, LoDTensor* tmp) {
  // Remember the original weights layout {A, B, C, ...}.
  const auto original_shape = weights->dims();
  // Work on 2-D views: weights become {A, B*C*...}, tmp becomes {A, 1}.
  make_tensor_2d(weights);
  make_tensor_2d(tmp);

  *weights =
      tensor_apply_eltwise_broadcast(*weights, *tmp, std::multiplies<float>());
  // Restore the original {A, B, C, ...} layout.
  weights->Resize(original_shape);
}
129+
130+
void recompute_bias_and_weights(const Scope* scope,
131+
ir::Node* conv_weight, //
132+
const ir::Node& bn_scale, //
133+
const LoDTensor& bn_bias_tensor, //
134+
const ir::Node& bn_mean, //
135+
const ir::Node& bn_variance, //
136+
LoDTensor* eltwise_y_in_tensor, //
137+
float epsilon) {
138+
// Re-compute bias of conv2d from BN
139+
PADDLE_ENFORCE_EQ(eltwise_y_in_tensor->dims(), bn_bias_tensor.dims());
140+
141+
auto* scale_tensor = scope->FindVar(bn_scale.Name())->GetMutable<LoDTensor>();
142+
auto* variance_tensor =
143+
scope->FindVar(bn_variance.Name())->GetMutable<LoDTensor>();
144+
auto* mean_tensor = scope->FindVar(bn_mean.Name())->GetMutable<LoDTensor>();
145+
146+
auto std_tensor = LoDTensor();
147+
std_tensor.Resize(bn_bias_tensor.dims());
148+
std_tensor =
149+
tensor_apply(*variance_tensor, [&](float x) { return x + epsilon; });
150+
151+
using EigenVectorArrayMap =
152+
Eigen::Map<Eigen::Array<float, Eigen::Dynamic, 1>>;
153+
154+
EigenVectorArrayMap std_vec(
155+
std_tensor.mutable_data<float>(platform::CPUPlace()), std_tensor.numel(),
156+
1);
157+
std_vec = std_vec.sqrt();
158+
auto tmp_tensor =
159+
tensor_apply_eltwise(*scale_tensor, std_tensor, std::divides<float>());
160+
auto tensor_minus = tensor_apply_eltwise(*eltwise_y_in_tensor, *mean_tensor,
161+
std::minus<float>());
162+
auto tensor_mul =
163+
tensor_apply_eltwise(tensor_minus, tmp_tensor, std::multiplies<float>());
164+
*eltwise_y_in_tensor =
165+
tensor_apply_eltwise(tensor_mul, bn_bias_tensor, std::plus<float>());
166+
167+
// Re-compute weight of conv2d from BN
168+
auto* current_param =
169+
scope->FindVar(conv_weight->Name())->GetMutable<LoDTensor>();
170+
recompute_conv_weights(current_param, &tmp_tensor);
171+
}
172+
173+
// Fuses conv2d (without bias) + batch_norm: folds the BN statistics into the
// conv weights and a freshly created elementwise_add bias, then removes the
// batch_norm op and its outputs from the graph.
std::unique_ptr<ir::Graph> ConvBNFusePass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  PADDLE_ENFORCE(graph.get());
  FusePassBase::Init(name_scope_, graph.get());

  // Scope holding the parameter tensors (weights, BN statistics).
  auto* scope = param_scope();
  PADDLE_ENFORCE(scope);

  GraphPatternDetector gpd;
  auto* conv_input =
      gpd.mutable_pattern()
          ->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
          ->AsInput()
          ->assert_is_op_input("conv2d", "Input");
  patterns::ConvBN conv_bn_pattern(gpd.mutable_pattern(), name_scope_);
  conv_bn_pattern(conv_input, false /*with_eltwise_add*/);

  int found_conv_bn_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "handle ConvBN fuse";

    // Binds: conv, batch_norm, conv_weight, conv_out, bn_scale, bn_bias,
    // bn_mean, bn_variance, bn_out, bn_mean_out, bn_variance_out,
    // bn_saved_mean, bn_saved_variance.
    GET_CONV_BN_NODES(conv_bn_pattern);

    // Create eltwise_y (conv bias) variable.
    VarDesc eltwise_y_in_desc(
        patterns::PDNodeName(name_scope_, "eltwise_y_in"));
    auto* eltwise_y_in_node = g->CreateVarNode(&eltwise_y_in_desc);
    auto* eltwise_y_in_tensor =
        scope->Var(eltwise_y_in_node->Name())->GetMutable<LoDTensor>();

    // Get batch norm bias.
    auto* bn_bias_tensor =
        scope->FindVar(bn_bias->Name())->GetMutable<LoDTensor>();

    // Initialize eltwise_y to zeros; recompute_bias_and_weights() folds the
    // BN statistics into it.
    eltwise_y_in_tensor->Resize(bn_bias_tensor->dims());
    std::fill_n(eltwise_y_in_tensor->mutable_data<float>(platform::CPUPlace()),
                eltwise_y_in_tensor->numel(), 0.0f);

    // Update weights and biases.
    float epsilon = boost::get<float>(batch_norm->Op()->GetAttr("epsilon"));
    recompute_bias_and_weights(scope, conv_weight, *bn_scale, *bn_bias_tensor,
                               *bn_mean, *bn_variance, eltwise_y_in_tensor,
                               epsilon);

    // Create an elementwise_add node that applies the fused bias.
    OpDesc desc;
    desc.SetInput("X", std::vector<std::string>({conv_out->Name()}));
    desc.SetInput("Y", std::vector<std::string>({eltwise_y_in_node->Name()}));
    desc.SetOutput("Out", std::vector<std::string>({bn_out->Name()}));
    desc.SetType("elementwise_add");
    desc.SetAttr("axis", 1);
    // Propagate the conv's use_mkldnn flag. Guarded: boost::get on a missing
    // attribute would throw, and not every conv2d carries this attribute.
    bool use_mkldnn =
        conv->Op()->HasAttr("use_mkldnn")
            ? boost::get<bool>(conv->Op()->GetAttr("use_mkldnn"))
            : false;
    desc.SetAttr("use_mkldnn", use_mkldnn);
    auto eltwise_op = g->CreateOpNode(&desc);  // OpDesc will be copied.

    GraphSafeRemoveNodes(graph.get(), {bn_scale, bn_bias, bn_mean, bn_variance,
                                       batch_norm, bn_mean_out, bn_variance_out,
                                       bn_saved_mean, bn_saved_variance});

    PADDLE_ENFORCE(subgraph.count(conv_input));
    IR_NODE_LINK_TO(conv_out, eltwise_op);
    IR_NODE_LINK_TO(eltwise_y_in_node, eltwise_op);
    IR_NODE_LINK_TO(eltwise_op, bn_out);

    found_conv_bn_count++;
  };

  gpd(graph.get(), handler);

  AddStatis(found_conv_bn_count);
  return graph;
}
251+
252+
// Fuses conv2d + elementwise_add (existing bias) + batch_norm: folds the BN
// statistics into the conv weights and the existing elementwise_add bias,
// then removes the batch_norm op and its outputs from the graph.
std::unique_ptr<ir::Graph> ConvEltwiseAddBNFusePass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  PADDLE_ENFORCE(graph.get());
  FusePassBase::Init(name_scope_, graph.get());

  // Scope holding the parameter tensors (weights, bias, BN statistics).
  auto* scope = param_scope();
  PADDLE_ENFORCE(scope);

  GraphPatternDetector gpd;
  auto* conv_input =
      gpd.mutable_pattern()
          ->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
          ->AsInput()
          ->assert_is_op_input("conv2d", "Input");
  patterns::ConvBN conv_bn_pattern(gpd.mutable_pattern(), name_scope_);
  conv_bn_pattern(conv_input, true /*with_eltwise_add*/);

  int found_conv_bn_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "handle ConvBN fuse";

    // conv, batch_norm,
    // conv_weight, conv_out,
    // bn_scale, bn_bias, bn_mean, bn_variance,
    // bn_out, bn_mean_out, bn_variance_out, bn_saved_mean,bn_saved_variance
    GET_CONV_BN_NODES(conv_bn_pattern);
    // OPERATORS
    GET_IR_NODE_FROM_SUBGRAPH(eltwise, eltwise, conv_bn_pattern);
    // BIAS inputs
    GET_IR_NODE_FROM_SUBGRAPH(eltwise_y_in, eltwise_y_in, conv_bn_pattern);
    // BIAS outputs
    GET_IR_NODE_FROM_SUBGRAPH(eltwise_out, eltwise_out, conv_bn_pattern);

    // Get eltwise_y (conv bias) variable
    auto* eltwise_y_in_tensor =
        scope->FindVar(eltwise_y_in->Name())->GetMutable<LoDTensor>();

    // Get batch norm bias
    auto* bn_bias_tensor =
        scope->FindVar(bn_bias->Name())->GetMutable<LoDTensor>();

    // update weights and biases
    float epsilon = boost::get<float>(batch_norm->Op()->GetAttr("epsilon"));
    recompute_bias_and_weights(scope, conv_weight, *bn_scale, *bn_bias_tensor,
                               *bn_mean, *bn_variance, eltwise_y_in_tensor,
                               epsilon);

    // Update the elementwise_add node: it now writes straight to BN's output.
    eltwise->Op()->SetAttr("axis", 1);
    eltwise->Op()->SetOutput("Out", std::vector<std::string>({bn_out->Name()}));

    GraphSafeRemoveNodes(
        graph.get(),
        {bn_scale, bn_bias, bn_mean, bn_variance, batch_norm, bn_mean_out,
         bn_variance_out, bn_saved_mean, bn_saved_variance, eltwise_out});

    PADDLE_ENFORCE(subgraph.count(conv_input));
    IR_NODE_LINK_TO(eltwise, bn_out);

    found_conv_bn_count++;
  };

  gpd(graph.get(), handler);

  AddStatis(found_conv_bn_count);
  return graph;
}
320+
321+
}  // namespace ir
}  // namespace framework
}  // namespace paddle

// Register both passes with the global pass registry under their string names.
REGISTER_PASS(conv_bn_fuse_pass, paddle::framework::ir::ConvBNFusePass);
REGISTER_PASS(conv_eltwiseadd_bn_fuse_pass,
              paddle::framework::ir::ConvEltwiseAddBNFusePass);
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include <memory>
#include <string>

#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
21+
22+
namespace paddle {
23+
namespace framework {
24+
namespace ir {
25+
26+
/*
27+
* Fuse the Conv and BatchNorm to a ConvBNMKLDNNOp.
28+
*/
29+
class ConvBNFusePass : public FusePassBase {
30+
public:
31+
virtual ~ConvBNFusePass() {}
32+
33+
protected:
34+
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
35+
const std::string name_scope_{"conv_bn_fuse"};
36+
};
37+
38+
class ConvEltwiseAddBNFusePass : public FusePassBase {
39+
public:
40+
virtual ~ConvEltwiseAddBNFusePass() {}
41+
42+
protected:
43+
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
44+
const std::string name_scope_{"conv_eltwiseadd_bn_fuse"};
45+
};
46+
47+
} // namespace ir
48+
} // namespace framework
49+
} // namespace paddle

0 commit comments

Comments (0)