Skip to content

Commit 50c5e9b

Browse files
committed
reshape_2d used from ddim.h
test=develop
1 parent 55d6950 commit 50c5e9b

File tree

1 file changed

+1
-13
lines changed

1 file changed

+1
-13
lines changed

paddle/fluid/framework/ir/conv_bn_fuse_pass.cc

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -44,18 +44,6 @@ namespace ir {
4444
GET_IR_NODE_FROM_SUBGRAPH(bn_saved_mean, bn_saved_mean, pattern_name); \
4545
GET_IR_NODE_FROM_SUBGRAPH(bn_saved_variance, bn_saved_variance, pattern_name)
4646

47-
// reshape to two dimensions {A, B * C * ...}
48-
DDim make_dims_2d(DDim dims) {
49-
auto dims_count = dims.size();
50-
PADDLE_ENFORCE_GT(dims_count, 0);
51-
52-
int size2 = 1;
53-
for (int i = 1; i < dims_count; i++) {
54-
size2 *= dims[i];
55-
}
56-
return make_ddim({dims[0], size2});
57-
}
58-
5947
void recompute_bias_and_weights(const Scope* scope,
6048
ir::Node* conv_weight, //
6149
const ir::Node& bn_scale, //
@@ -104,7 +92,7 @@ void recompute_bias_and_weights(const Scope* scope,
10492
// Re-compute weight of conv2d from BN
10593
auto* weights = scope->FindVar(conv_weight->Name())->GetMutable<LoDTensor>();
10694
auto weights_shape = weights->dims();
107-
auto weights_shape_2d = make_dims_2d(weights_shape);
95+
auto weights_shape_2d = flatten_to_2d(weights_shape, 1);
10896

10997
EigenMatrixArrayMap weights_array_2d(
11098
weights->mutable_data<float>(platform::CPUPlace()), weights_shape_2d[0],

0 commit comments

Comments
 (0)