Skip to content

Commit 78f9829

Browse files
committed
conv bn fuse pass
Review fixes: addressed review comments from hshen14; fix test=develop; fixed error in broadcast handling and cleaned up code; renamed bias -> eltwise; added a macro to shorten the code; formatting.
1 parent 8cd17c0 commit 78f9829

File tree

6 files changed

+520
-9
lines changed

6 files changed

+520
-9
lines changed

paddle/fluid/framework/ir/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ pass_library(fc_lstm_fuse_pass inference)
3838
pass_library(embedding_fc_lstm_fuse_pass inference)
3939
pass_library(fc_gru_fuse_pass inference)
4040
pass_library(seq_concat_fc_fuse_pass inference)
41+
pass_library(conv_bn_fuse_pass inference)
4142

4243
cc_library(fuse_elewise_add_act_pass SRCS fuse_elewise_add_act_pass.cc DEPS pass graph_pattern_detector )
4344

Lines changed: 315 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,315 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/framework/ir/conv_bn_fuse_pass.h"
#include <cmath>
#include <functional>
#include <string>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/platform/enforce.h"
22+
23+
namespace paddle {
24+
namespace framework {
25+
namespace ir {
26+
27+
// Binds every node of a detected conv+BN subgraph to a same-named local
// variable via GET_IR_NODE_FROM_SUBGRAPH. After expansion the following are
// in scope:
//   operators:    conv, batch_norm
//   conv inputs:  conv_weight        conv outputs: conv_out
//   BN inputs:    bn_scale, bn_bias, bn_mean, bn_variance
//   BN outputs:   bn_out, bn_mean_out, bn_variance_out,
//                 bn_saved_mean, bn_saved_variance
// NOTE(review): presumably GET_IR_NODE_FROM_SUBGRAPH also requires a
// `subgraph` variable in the enclosing scope (as in the handlers below) —
// confirm against graph_pattern_detector.h.
#define GET_CONV_BN_NODES(pattern_name)                                      \
  /* OPERATORS */                                                            \
  GET_IR_NODE_FROM_SUBGRAPH(conv, conv, pattern_name);                       \
  GET_IR_NODE_FROM_SUBGRAPH(batch_norm, batch_norm, pattern_name);           \
  /* CONV inputs */                                                          \
  GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight, pattern_name);         \
  /* CONV outputs */                                                         \
  GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, pattern_name);               \
  /* BN inputs */                                                            \
  GET_IR_NODE_FROM_SUBGRAPH(bn_scale, bn_scale, pattern_name);               \
  GET_IR_NODE_FROM_SUBGRAPH(bn_bias, bn_bias, pattern_name);                 \
  GET_IR_NODE_FROM_SUBGRAPH(bn_mean, bn_mean, pattern_name);                 \
  GET_IR_NODE_FROM_SUBGRAPH(bn_variance, bn_variance, pattern_name);         \
  /* BN outputs */                                                           \
  GET_IR_NODE_FROM_SUBGRAPH(bn_out, bn_out, pattern_name); /* Out */         \
  GET_IR_NODE_FROM_SUBGRAPH(bn_mean_out, bn_mean_out, pattern_name);         \
  GET_IR_NODE_FROM_SUBGRAPH(bn_variance_out, bn_variance_out, pattern_name); \
  GET_IR_NODE_FROM_SUBGRAPH(bn_saved_mean, bn_saved_mean, pattern_name);     \
  GET_IR_NODE_FROM_SUBGRAPH(bn_saved_variance, bn_saved_variance, pattern_name)
46+
47+
// Applies the unary function `f` to every element of `vec` and returns the
// results in a freshly allocated CPU tensor of the same shape. The input
// tensor is not modified. Assumes float element type.
LoDTensor tensor_apply(const LoDTensor& vec, float (*f)(float)) {
  LoDTensor result;
  result.Resize(vec.dims());
  const float* src = vec.data<float>();
  float* dst = result.mutable_data<float>(platform::CPUPlace());
  const int64_t count = vec.numel();
  for (int64_t idx = 0; idx < count; ++idx) {
    dst[idx] = f(src[idx]);
  }
  return result;
}
57+
58+
// Applies the unary function `f` to every element of `vec`, overwriting the
// tensor's contents in place. Assumes float element type on CPU.
void tensor_apply_inplace(LoDTensor* vec, float (*f)(float)) {
  float* ptr = vec->mutable_data<float>(platform::CPUPlace());
  const int64_t count = vec->numel();
  for (int64_t idx = 0; idx < count; ++idx) {
    ptr[idx] = f(ptr[idx]);
  }
}
64+
65+
template <typename BinaryOperation>
66+
LoDTensor tensor_apply_eltwise(const LoDTensor& vec_a, const LoDTensor& vec_b,
67+
BinaryOperation f) {
68+
PADDLE_ENFORCE_EQ(vec_a.dims(), vec_b.dims());
69+
LoDTensor vec_y;
70+
vec_y.Resize(vec_a.dims());
71+
const float* a = vec_a.data<float>();
72+
const float* b = vec_b.data<float>();
73+
float* y = vec_y.mutable_data<float>(platform::CPUPlace());
74+
for (int64_t i = 0; i < vec_a.numel(); i++) {
75+
y[i] = f(a[i], b[i]);
76+
}
77+
return vec_y;
78+
}
79+
80+
template <typename BinaryOperation>
81+
LoDTensor tensor_apply_eltwise_broadcast(const LoDTensor& vec_a,
82+
const LoDTensor& vec_b,
83+
BinaryOperation f) {
84+
PADDLE_ENFORCE_EQ(vec_a.dims().size(), 2);
85+
PADDLE_ENFORCE_EQ(vec_b.dims().size(), 2);
86+
PADDLE_ENFORCE_EQ(vec_a.dims()[0], vec_b.dims()[0]);
87+
PADDLE_ENFORCE_EQ(vec_b.dims()[1], 1);
88+
LoDTensor vec_y;
89+
vec_y.Resize(vec_a.dims());
90+
const float* a = vec_a.data<float>();
91+
const float* b = vec_b.data<float>();
92+
float* y = vec_y.mutable_data<float>(platform::CPUPlace());
93+
size_t a_height = vec_a.dims()[0];
94+
size_t a_width = vec_a.dims()[1];
95+
for (size_t h = 0; h < a_height; h++) {
96+
for (size_t w = 0; w < a_width; ++w) {
97+
*(y++) = f(*(a++), b[h]);
98+
}
99+
}
100+
return vec_y;
101+
}
102+
103+
// reshape to two dimensions {A, B * C * ...}
104+
void make_tensor_2d(LoDTensor* tensor_to_reshape) {
105+
auto dims_count = tensor_to_reshape->dims().size();
106+
PADDLE_ENFORCE_GT(dims_count, 0);
107+
108+
int size2 = 1;
109+
for (int i = 1; i < dims_count; i++) {
110+
size2 *= tensor_to_reshape->dims()[i];
111+
}
112+
tensor_to_reshape->Resize(make_ddim({tensor_to_reshape->dims()[0], size2}));
113+
}
114+
115+
// Scales the conv weights row-wise by `tmp` (one factor per output channel):
// weights[a][...] *= tmp[a]. The weight tensor's original N-D shape is
// restored afterwards; `tmp` is left reshaped to {A, 1}.
void recompute_conv_weights(LoDTensor* weights, LoDTensor* tmp) {
  // Remember the original weight shape {A, B, C, ...} so it can be restored.
  const auto original_dims = weights->dims();
  // Collapse weights into {A, B * C * ...} and tmp into {A, 1} so the
  // multiply can be broadcast row-by-row.
  make_tensor_2d(weights);
  make_tensor_2d(tmp);

  *weights =
      tensor_apply_eltwise_broadcast(*weights, *tmp, std::multiplies<float>());
  // Restore the original N-D shape.
  weights->Resize(original_dims);
}
128+
129+
void recompute_bias_and_weights(const Scope* scope,
130+
ir::Node* conv_weight, //
131+
const ir::Node& bn_scale, //
132+
const LoDTensor& bn_bias_tensor, //
133+
const ir::Node& bn_mean, //
134+
const ir::Node& bn_variance, //
135+
LoDTensor* eltwise_y_in_tensor) {
136+
// Re-compute bias of conv2d from BN
137+
PADDLE_ENFORCE_EQ(eltwise_y_in_tensor->dims(), bn_bias_tensor.dims());
138+
139+
auto* scale_tensor = scope->FindVar(bn_scale.Name())->GetMutable<LoDTensor>();
140+
auto* variance_tensor =
141+
scope->FindVar(bn_variance.Name())->GetMutable<LoDTensor>();
142+
auto* mean_tensor = scope->FindVar(bn_mean.Name())->GetMutable<LoDTensor>();
143+
144+
auto std_tensor = LoDTensor();
145+
std_tensor.Resize(bn_bias_tensor.dims());
146+
std_tensor =
147+
tensor_apply(*variance_tensor, [](float x) { return x + 1e-5f; });
148+
149+
tensor_apply_inplace(&std_tensor, std::sqrt);
150+
auto tmp_tensor =
151+
tensor_apply_eltwise(*scale_tensor, std_tensor, std::divides<float>());
152+
auto tensor_minus = tensor_apply_eltwise(*eltwise_y_in_tensor, *mean_tensor,
153+
std::minus<float>());
154+
auto tensor_mul =
155+
tensor_apply_eltwise(tensor_minus, tmp_tensor, std::multiplies<float>());
156+
*eltwise_y_in_tensor =
157+
tensor_apply_eltwise(tensor_mul, bn_bias_tensor, std::plus<float>());
158+
159+
// Re-compute weight of conv2d from BN
160+
auto* current_param =
161+
scope->FindVar(conv_weight->Name())->GetMutable<LoDTensor>();
162+
recompute_conv_weights(current_param, &tmp_tensor);
163+
}
164+
165+
// Detects conv2d -> batch_norm chains (no intermediate elementwise_add) and
// folds the BN into the convolution: the conv weights are rescaled in place
// and the BN op is replaced by a single elementwise_add carrying the folded
// bias. Returns the transformed graph.
std::unique_ptr<ir::Graph> ConvBNFusePass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  PADDLE_ENFORCE(graph.get());
  FusePassBase::Init(name_scope_, graph.get());

  // Scope that holds the parameter tensors (conv weights, BN statistics).
  auto* scope = param_scope();
  PADDLE_ENFORCE(scope);

  GraphPatternDetector gpd;
  auto* conv_input =
      gpd.mutable_pattern()
          ->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
          ->AsInput()
          ->assert_is_op_input("conv2d", "Input");
  patterns::ConvBN conv_bn_pattern(gpd.mutable_pattern(), name_scope_);
  // Match the variant WITHOUT an elementwise_add between conv and BN.
  conv_bn_pattern(conv_input, false /*with_eltwise_add*/);

  int found_conv_bn_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "handle ConvBN fuse";

    // Binds (via GET_CONV_BN_NODES):
    // conv, batch_norm,
    // conv_weight, conv_out,
    // bn_scale, bn_bias, bn_mean, bn_variance,
    // bn_out, bn_mean_out, bn_variance_out, bn_saved_mean, bn_saved_variance
    GET_CONV_BN_NODES(conv_bn_pattern);

    // Create eltwise_y (conv bias) variable
    VarDesc eltwise_y_in_desc(
        patterns::PDNodeName(name_scope_, "eltwise_y_in"));
    auto* eltwise_y_in_node = g->CreateVarNode(&eltwise_y_in_desc);
    auto* eltwise_y_in_tensor =
        scope->Var(eltwise_y_in_node->Name())->GetMutable<LoDTensor>();

    // Get batch norm bias
    auto* bn_bias_tensor =
        scope->FindVar(bn_bias->Name())->GetMutable<LoDTensor>();

    // Initialize eltwise_y to zeros; recompute_bias_and_weights() below
    // turns it into the folded BN bias.
    eltwise_y_in_tensor->Resize(bn_bias_tensor->dims());
    std::fill_n(eltwise_y_in_tensor->mutable_data<float>(platform::CPUPlace()),
                eltwise_y_in_tensor->numel(), 0.0f);

    // update weights and biases
    recompute_bias_and_weights(scope, conv_weight, *bn_scale, *bn_bias_tensor,
                               *bn_mean, *bn_variance, eltwise_y_in_tensor);

    // Create an elementwise add node that applies the folded bias and writes
    // straight into what used to be the BN output.
    OpDesc desc;
    desc.SetInput("X", std::vector<std::string>({conv_out->Name()}));
    desc.SetInput("Y", std::vector<std::string>({eltwise_y_in_node->Name()}));
    desc.SetOutput("Out", std::vector<std::string>({bn_out->Name()}));
    desc.SetType("elementwise_add");
    desc.SetAttr("axis", 1);
    // NOTE(review): GetAttr presumably fails if the conv op lacks a
    // "use_mkldnn" attribute — confirm the attribute is always present.
    bool a = boost::get<bool>(conv->Op()->GetAttr("use_mkldnn"));
    desc.SetAttr("use_mkldnn", a);
    auto eltwise_op = g->CreateOpNode(&desc);  // OpDesc will be copied.

    // Drop the BN op and all of its now-unused inputs/outputs.
    GraphSafeRemoveNodes(graph.get(), {bn_scale, bn_bias, bn_mean, bn_variance,
                                       batch_norm, bn_mean_out, bn_variance_out,
                                       bn_saved_mean, bn_saved_variance});

    PADDLE_ENFORCE(subgraph.count(conv_input));
    // Rewire: conv_out + eltwise_y -> elementwise_add -> bn_out.
    IR_NODE_LINK_TO(conv_out, eltwise_op);
    IR_NODE_LINK_TO(eltwise_y_in_node, eltwise_op);
    IR_NODE_LINK_TO(eltwise_op, bn_out);

    found_conv_bn_count++;
  };

  gpd(graph.get(), handler);

  AddStatis(found_conv_bn_count);
  return graph;
}
241+
242+
// Detects conv2d -> elementwise_add -> batch_norm chains and folds the BN
// into the convolution: conv weights are rescaled in place, the EXISTING
// eltwise bias is updated, and the batch_norm op (plus its outputs) is
// removed. Returns the transformed graph.
std::unique_ptr<ir::Graph> ConvEltwiseAddBNFusePass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  PADDLE_ENFORCE(graph.get());
  FusePassBase::Init(name_scope_, graph.get());

  // Scope that holds the parameter tensors (conv weights, BN statistics).
  auto* scope = param_scope();
  PADDLE_ENFORCE(scope);

  GraphPatternDetector gpd;
  auto* conv_input =
      gpd.mutable_pattern()
          ->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
          ->AsInput()
          ->assert_is_op_input("conv2d", "Input");
  patterns::ConvBN conv_bn_pattern(gpd.mutable_pattern(), name_scope_);
  // Match the variant WITH an elementwise_add between conv and BN.
  conv_bn_pattern(conv_input, true /*with_eltwise_add*/);

  int found_conv_bn_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "handle ConvBN fuse";

    // Binds (via GET_CONV_BN_NODES):
    // conv, batch_norm,
    // conv_weight, conv_out,
    // bn_scale, bn_bias, bn_mean, bn_variance,
    // bn_out, bn_mean_out, bn_variance_out, bn_saved_mean,bn_saved_variance
    GET_CONV_BN_NODES(conv_bn_pattern);
    // OPERATORS
    GET_IR_NODE_FROM_SUBGRAPH(eltwise, eltwise, conv_bn_pattern);
    // BIAS inputs
    GET_IR_NODE_FROM_SUBGRAPH(eltwise_y_in, eltwise_y_in, conv_bn_pattern);
    // BIAS outputs
    GET_IR_NODE_FROM_SUBGRAPH(eltwise_out, eltwise_out, conv_bn_pattern);

    // Get eltwise_y (conv bias) variable; unlike ConvBNFusePass, the bias
    // already exists, so it is updated in place instead of created.
    auto* eltwise_y_in_tensor =
        scope->FindVar(eltwise_y_in->Name())->GetMutable<LoDTensor>();

    // Get batch norm bias
    auto* bn_bias_tensor =
        scope->FindVar(bn_bias->Name())->GetMutable<LoDTensor>();

    // update weights and biases
    recompute_bias_and_weights(scope, conv_weight, *bn_scale, *bn_bias_tensor,
                               *bn_mean, *bn_variance, eltwise_y_in_tensor);

    // Update the elementwise_add node: its output is redirected to what was
    // previously the BN output.
    eltwise->Op()->SetAttr("axis", 1);
    eltwise->Op()->SetOutput("Out", std::vector<std::string>({bn_out->Name()}));

    // Remove the BN op, its parameters/outputs, and the now-bypassed
    // eltwise_out variable.
    GraphSafeRemoveNodes(
        graph.get(),
        {bn_scale, bn_bias, bn_mean, bn_variance, batch_norm, bn_mean_out,
         bn_variance_out, bn_saved_mean, bn_saved_variance, eltwise_out});

    PADDLE_ENFORCE(subgraph.count(conv_input));
    IR_NODE_LINK_TO(eltwise, bn_out);

    found_conv_bn_count++;
  };

  gpd(graph.get(), handler);

  AddStatis(found_conv_bn_count);
  return graph;
}
308+
309+
}  // namespace ir
}  // namespace framework
}  // namespace paddle

// Register both passes with the global pass registry so they can be looked
// up by name by the inference framework.
REGISTER_PASS(conv_bn_fuse_pass, paddle::framework::ir::ConvBNFusePass);
REGISTER_PASS(conv_eltwiseadd_bn_fuse_pass,
              paddle::framework::ir::ConvEltwiseAddBNFusePass);
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include <string>
18+
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
19+
#include "paddle/fluid/framework/ir/graph.h"
20+
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
21+
22+
namespace paddle {
23+
namespace framework {
24+
namespace ir {
25+
26+
/*
27+
* Fuse the Conv and BatchNorm to a ConvBNMKLDNNOp.
28+
*/
29+
class ConvBNFusePass : public FusePassBase {
30+
public:
31+
virtual ~ConvBNFusePass() {}
32+
33+
protected:
34+
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
35+
const std::string name_scope_{"conv_bn_fuse"};
36+
};
37+
38+
class ConvEltwiseAddBNFusePass : public FusePassBase {
39+
public:
40+
virtual ~ConvEltwiseAddBNFusePass() {}
41+
42+
protected:
43+
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
44+
const std::string name_scope_{"conv_eltwiseadd_bn_fuse"};
45+
};
46+
47+
} // namespace ir
48+
} // namespace framework
49+
} // namespace paddle

0 commit comments

Comments
 (0)