Skip to content

Commit 3fcca40

Browse files
committed
Use Eigen for the sqrt computation and replace the hard-coded 1e-5 with the batch_norm epsilon attribute
test=develop
1 parent 78f9829 commit 3fcca40

File tree

1 file changed

+18
-6
lines changed

1 file changed

+18
-6
lines changed

paddle/fluid/framework/ir/conv_bn_fuse_pass.cc

Lines changed: 18 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -44,7 +44,8 @@ namespace ir {
4444
GET_IR_NODE_FROM_SUBGRAPH(bn_saved_mean, bn_saved_mean, pattern_name); \
4545
GET_IR_NODE_FROM_SUBGRAPH(bn_saved_variance, bn_saved_variance, pattern_name)
4646

47-
LoDTensor tensor_apply(const LoDTensor& vec, float (*f)(float)) {
47+
template <typename UnaryOperation>
48+
LoDTensor tensor_apply(const LoDTensor& vec, UnaryOperation f) {
4849
LoDTensor vec_y;
4950
vec_y.Resize(vec.dims());
5051
const float* x = vec.data<float>();
@@ -132,7 +133,8 @@ void recompute_bias_and_weights(const Scope* scope,
132133
const LoDTensor& bn_bias_tensor, //
133134
const ir::Node& bn_mean, //
134135
const ir::Node& bn_variance, //
135-
LoDTensor* eltwise_y_in_tensor) {
136+
LoDTensor* eltwise_y_in_tensor, //
137+
float epsilon) {
136138
// Re-compute bias of conv2d from BN
137139
PADDLE_ENFORCE_EQ(eltwise_y_in_tensor->dims(), bn_bias_tensor.dims());
138140

@@ -144,9 +146,15 @@ void recompute_bias_and_weights(const Scope* scope,
144146
auto std_tensor = LoDTensor();
145147
std_tensor.Resize(bn_bias_tensor.dims());
146148
std_tensor =
147-
tensor_apply(*variance_tensor, [](float x) { return x + 1e-5f; });
149+
tensor_apply(*variance_tensor, [&](float x) { return x + epsilon; });
148150

149-
tensor_apply_inplace(&std_tensor, std::sqrt);
151+
using EigenVectorArrayMap =
152+
Eigen::Map<Eigen::Array<float, Eigen::Dynamic, 1>>;
153+
154+
EigenVectorArrayMap std_vec(
155+
std_tensor.mutable_data<float>(platform::CPUPlace()), std_tensor.numel(),
156+
1);
157+
std_vec = std_vec.sqrt();
150158
auto tmp_tensor =
151159
tensor_apply_eltwise(*scale_tensor, std_tensor, std::divides<float>());
152160
auto tensor_minus = tensor_apply_eltwise(*eltwise_y_in_tensor, *mean_tensor,
@@ -207,8 +215,10 @@ std::unique_ptr<ir::Graph> ConvBNFusePass::ApplyImpl(
207215
eltwise_y_in_tensor->numel(), 0.0f);
208216

209217
// update weights and biases
218+
float epsilon = boost::get<float>(batch_norm->Op()->GetAttr("epsilon"));
210219
recompute_bias_and_weights(scope, conv_weight, *bn_scale, *bn_bias_tensor,
211-
*bn_mean, *bn_variance, eltwise_y_in_tensor);
220+
*bn_mean, *bn_variance, eltwise_y_in_tensor,
221+
epsilon);
212222

213223
// Create an elementwise add node
214224
OpDesc desc;
@@ -282,8 +292,10 @@ std::unique_ptr<ir::Graph> ConvEltwiseAddBNFusePass::ApplyImpl(
282292
scope->FindVar(bn_bias->Name())->GetMutable<LoDTensor>();
283293

284294
// update weights and biases
295+
float epsilon = boost::get<float>(batch_norm->Op()->GetAttr("epsilon"));
285296
recompute_bias_and_weights(scope, conv_weight, *bn_scale, *bn_bias_tensor,
286-
*bn_mean, *bn_variance, eltwise_y_in_tensor);
297+
*bn_mean, *bn_variance, eltwise_y_in_tensor,
298+
epsilon);
287299

288300
// Update the elementwise_add node
289301
eltwise->Op()->SetAttr("axis", 1);

0 commit comments

Comments
 (0)