@@ -44,18 +44,6 @@ namespace ir {
44
44
GET_IR_NODE_FROM_SUBGRAPH (bn_saved_mean, bn_saved_mean, pattern_name); \
45
45
GET_IR_NODE_FROM_SUBGRAPH (bn_saved_variance, bn_saved_variance, pattern_name)
46
46
47
- // reshape to two dimensions {A, B * C * ...}
48
- DDim make_dims_2d (DDim dims) {
49
- auto dims_count = dims.size ();
50
- PADDLE_ENFORCE_GT (dims_count, 0 );
51
-
52
- int size2 = 1 ;
53
- for (int i = 1 ; i < dims_count; i++) {
54
- size2 *= dims[i];
55
- }
56
- return make_ddim ({dims[0 ], size2});
57
- }
58
-
59
47
void recompute_bias_and_weights (const Scope* scope,
60
48
ir::Node* conv_weight, //
61
49
const ir::Node& bn_scale, //
@@ -104,7 +92,7 @@ void recompute_bias_and_weights(const Scope* scope,
104
92
// Re-compute weight of conv2d from BN
105
93
auto * weights = scope->FindVar (conv_weight->Name ())->GetMutable <LoDTensor>();
106
94
auto weights_shape = weights->dims ();
107
- auto weights_shape_2d = make_dims_2d (weights_shape);
95
+ auto weights_shape_2d = flatten_to_2d (weights_shape, 1 );
108
96
109
97
EigenMatrixArrayMap weights_array_2d (
110
98
weights->mutable_data <float >(platform::CPUPlace ()), weights_shape_2d[0 ],
0 commit comments