|
| 1 | +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +// |
| 3 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +// you may not use this file except in compliance with the License. |
| 5 | +// You may obtain a copy of the License at |
| 6 | +// |
| 7 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +// |
| 9 | +// Unless required by applicable law or agreed to in writing, software |
| 10 | +// distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +// See the License for the specific language governing permissions and |
| 13 | +// limitations under the License. |
| 14 | + |
#include "paddle/fluid/framework/ir/conv_bn_fuse_pass.h"
#include <cmath>
#include <functional>
#include <string>
#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/platform/enforce.h"
| 22 | + |
| 23 | +namespace paddle { |
| 24 | +namespace framework { |
| 25 | +namespace ir { |
| 26 | + |
// Convenience macro: pulls every node of a matched conv+batch_norm subgraph
// out of the GraphPatternDetector result into local variables of the same
// name (conv, batch_norm, conv_weight, conv_out, bn_*). `pattern_name` is
// the patterns::ConvBN instance that performed the match. Intended for use
// inside the detector handler lambdas below.
#define GET_CONV_BN_NODES(pattern_name)                                    \
  /* OPERATORS */                                                          \
  GET_IR_NODE_FROM_SUBGRAPH(conv, conv, pattern_name);                     \
  GET_IR_NODE_FROM_SUBGRAPH(batch_norm, batch_norm, pattern_name);         \
  /* CONV inputs */                                                        \
  GET_IR_NODE_FROM_SUBGRAPH(conv_weight, conv_weight, pattern_name);       \
  /* CONV outputs */                                                       \
  GET_IR_NODE_FROM_SUBGRAPH(conv_out, conv_out, pattern_name);             \
  /* BN inputs */                                                          \
  GET_IR_NODE_FROM_SUBGRAPH(bn_scale, bn_scale, pattern_name);             \
  GET_IR_NODE_FROM_SUBGRAPH(bn_bias, bn_bias, pattern_name);               \
  GET_IR_NODE_FROM_SUBGRAPH(bn_mean, bn_mean, pattern_name);               \
  GET_IR_NODE_FROM_SUBGRAPH(bn_variance, bn_variance, pattern_name);       \
  /* BN outputs */                                                         \
  GET_IR_NODE_FROM_SUBGRAPH(bn_out, bn_out, pattern_name); /* Out */       \
  GET_IR_NODE_FROM_SUBGRAPH(bn_mean_out, bn_mean_out, pattern_name);       \
  GET_IR_NODE_FROM_SUBGRAPH(bn_variance_out, bn_variance_out, pattern_name); \
  GET_IR_NODE_FROM_SUBGRAPH(bn_saved_mean, bn_saved_mean, pattern_name);   \
  GET_IR_NODE_FROM_SUBGRAPH(bn_saved_variance, bn_saved_variance, pattern_name)
| 46 | + |
| 47 | +template <typename UnaryOperation> |
| 48 | +LoDTensor tensor_apply(const LoDTensor& vec, UnaryOperation f) { |
| 49 | + LoDTensor vec_y; |
| 50 | + vec_y.Resize(vec.dims()); |
| 51 | + const float* x = vec.data<float>(); |
| 52 | + float* y = vec_y.mutable_data<float>(platform::CPUPlace()); |
| 53 | + for (int64_t i = 0; i < vec.numel(); i++) { |
| 54 | + y[i] = f(x[i]); |
| 55 | + } |
| 56 | + return vec_y; |
| 57 | +} |
| 58 | + |
| 59 | +void tensor_apply_inplace(LoDTensor* vec, float (*f)(float)) { |
| 60 | + float* data = vec->mutable_data<float>(platform::CPUPlace()); |
| 61 | + for (int64_t i = 0; i < vec->numel(); i++) { |
| 62 | + data[i] = f(data[i]); |
| 63 | + } |
| 64 | +} |
| 65 | + |
| 66 | +template <typename BinaryOperation> |
| 67 | +LoDTensor tensor_apply_eltwise(const LoDTensor& vec_a, const LoDTensor& vec_b, |
| 68 | + BinaryOperation f) { |
| 69 | + PADDLE_ENFORCE_EQ(vec_a.dims(), vec_b.dims()); |
| 70 | + LoDTensor vec_y; |
| 71 | + vec_y.Resize(vec_a.dims()); |
| 72 | + const float* a = vec_a.data<float>(); |
| 73 | + const float* b = vec_b.data<float>(); |
| 74 | + float* y = vec_y.mutable_data<float>(platform::CPUPlace()); |
| 75 | + for (int64_t i = 0; i < vec_a.numel(); i++) { |
| 76 | + y[i] = f(a[i], b[i]); |
| 77 | + } |
| 78 | + return vec_y; |
| 79 | +} |
| 80 | + |
| 81 | +template <typename BinaryOperation> |
| 82 | +LoDTensor tensor_apply_eltwise_broadcast(const LoDTensor& vec_a, |
| 83 | + const LoDTensor& vec_b, |
| 84 | + BinaryOperation f) { |
| 85 | + PADDLE_ENFORCE_EQ(vec_a.dims().size(), 2); |
| 86 | + PADDLE_ENFORCE_EQ(vec_b.dims().size(), 2); |
| 87 | + PADDLE_ENFORCE_EQ(vec_a.dims()[0], vec_b.dims()[0]); |
| 88 | + PADDLE_ENFORCE_EQ(vec_b.dims()[1], 1); |
| 89 | + LoDTensor vec_y; |
| 90 | + vec_y.Resize(vec_a.dims()); |
| 91 | + const float* a = vec_a.data<float>(); |
| 92 | + const float* b = vec_b.data<float>(); |
| 93 | + float* y = vec_y.mutable_data<float>(platform::CPUPlace()); |
| 94 | + size_t a_height = vec_a.dims()[0]; |
| 95 | + size_t a_width = vec_a.dims()[1]; |
| 96 | + for (size_t h = 0; h < a_height; h++) { |
| 97 | + for (size_t w = 0; w < a_width; ++w) { |
| 98 | + *(y++) = f(*(a++), b[h]); |
| 99 | + } |
| 100 | + } |
| 101 | + return vec_y; |
| 102 | +} |
| 103 | + |
| 104 | +// reshape to two dimensions {A, B * C * ...} |
| 105 | +void make_tensor_2d(LoDTensor* tensor_to_reshape) { |
| 106 | + auto dims_count = tensor_to_reshape->dims().size(); |
| 107 | + PADDLE_ENFORCE_GT(dims_count, 0); |
| 108 | + |
| 109 | + int size2 = 1; |
| 110 | + for (int i = 1; i < dims_count; i++) { |
| 111 | + size2 *= tensor_to_reshape->dims()[i]; |
| 112 | + } |
| 113 | + tensor_to_reshape->Resize(make_ddim({tensor_to_reshape->dims()[0], size2})); |
| 114 | +} |
| 115 | + |
| 116 | +void recompute_conv_weights(LoDTensor* weights, LoDTensor* tmp) { |
| 117 | + // remember the weights tensor shape {A, B, C, ...} |
| 118 | + auto weights_shape = weights->dims(); |
| 119 | + // reduce the weights to 2d {A, B * C * ...} |
| 120 | + make_tensor_2d(weights); |
| 121 | + // make tmp tensor 2d by adding 1 as second dim {A, 1} |
| 122 | + make_tensor_2d(tmp); |
| 123 | + |
| 124 | + *weights = |
| 125 | + tensor_apply_eltwise_broadcast(*weights, *tmp, std::multiplies<float>()); |
| 126 | + // reshape weights to the original dims {A, B, C, ...} |
| 127 | + weights->Resize(weights_shape); |
| 128 | +} |
| 129 | + |
| 130 | +void recompute_bias_and_weights(const Scope* scope, |
| 131 | + ir::Node* conv_weight, // |
| 132 | + const ir::Node& bn_scale, // |
| 133 | + const LoDTensor& bn_bias_tensor, // |
| 134 | + const ir::Node& bn_mean, // |
| 135 | + const ir::Node& bn_variance, // |
| 136 | + LoDTensor* eltwise_y_in_tensor, // |
| 137 | + float epsilon) { |
| 138 | + // Re-compute bias of conv2d from BN |
| 139 | + PADDLE_ENFORCE_EQ(eltwise_y_in_tensor->dims(), bn_bias_tensor.dims()); |
| 140 | + |
| 141 | + auto* scale_tensor = scope->FindVar(bn_scale.Name())->GetMutable<LoDTensor>(); |
| 142 | + auto* variance_tensor = |
| 143 | + scope->FindVar(bn_variance.Name())->GetMutable<LoDTensor>(); |
| 144 | + auto* mean_tensor = scope->FindVar(bn_mean.Name())->GetMutable<LoDTensor>(); |
| 145 | + |
| 146 | + auto std_tensor = LoDTensor(); |
| 147 | + std_tensor.Resize(bn_bias_tensor.dims()); |
| 148 | + std_tensor = |
| 149 | + tensor_apply(*variance_tensor, [&](float x) { return x + epsilon; }); |
| 150 | + |
| 151 | + using EigenVectorArrayMap = |
| 152 | + Eigen::Map<Eigen::Array<float, Eigen::Dynamic, 1>>; |
| 153 | + |
| 154 | + EigenVectorArrayMap std_vec( |
| 155 | + std_tensor.mutable_data<float>(platform::CPUPlace()), std_tensor.numel(), |
| 156 | + 1); |
| 157 | + std_vec = std_vec.sqrt(); |
| 158 | + auto tmp_tensor = |
| 159 | + tensor_apply_eltwise(*scale_tensor, std_tensor, std::divides<float>()); |
| 160 | + auto tensor_minus = tensor_apply_eltwise(*eltwise_y_in_tensor, *mean_tensor, |
| 161 | + std::minus<float>()); |
| 162 | + auto tensor_mul = |
| 163 | + tensor_apply_eltwise(tensor_minus, tmp_tensor, std::multiplies<float>()); |
| 164 | + *eltwise_y_in_tensor = |
| 165 | + tensor_apply_eltwise(tensor_mul, bn_bias_tensor, std::plus<float>()); |
| 166 | + |
| 167 | + // Re-compute weight of conv2d from BN |
| 168 | + auto* current_param = |
| 169 | + scope->FindVar(conv_weight->Name())->GetMutable<LoDTensor>(); |
| 170 | + recompute_conv_weights(current_param, &tmp_tensor); |
| 171 | +} |
| 172 | + |
// Fuses conv2d + batch_norm (conv without bias) into conv2d followed by a
// newly created elementwise_add carrying the folded BN parameters as bias.
// For every match: a bias variable is created and initialized to zero, the
// conv weights/bias are recomputed via recompute_bias_and_weights, the BN
// nodes are removed from the graph, and an elementwise_add op is linked
// between conv_out and the former bn_out.
std::unique_ptr<ir::Graph> ConvBNFusePass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  PADDLE_ENFORCE(graph.get());
  FusePassBase::Init(name_scope_, graph.get());

  // The pass needs the parameter scope to read/write weight tensors.
  auto* scope = param_scope();
  PADDLE_ENFORCE(scope);

  GraphPatternDetector gpd;
  auto* conv_input =
      gpd.mutable_pattern()
          ->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
          ->AsInput()
          ->assert_is_op_input("conv2d", "Input");
  patterns::ConvBN conv_bn_pattern(gpd.mutable_pattern(), name_scope_);
  // false: match conv2d feeding batch_norm directly (no elementwise_add).
  conv_bn_pattern(conv_input, false /*with_eltwise_add*/);

  int found_conv_bn_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "handle ConvBN fuse";

    // Binds local node variables:
    // conv, batch_norm,
    // conv_weight, conv_out,
    // bn_scale, bn_bias, bn_mean, bn_variance,
    // bn_out, bn_mean_out, bn_variance_out, bn_saved_mean, bn_saved_variance
    GET_CONV_BN_NODES(conv_bn_pattern);

    // Create eltwise_y (conv bias) variable — both the graph node and the
    // backing tensor in the scope.
    VarDesc eltwise_y_in_desc(
        patterns::PDNodeName(name_scope_, "eltwise_y_in"));
    auto* eltwise_y_in_node = g->CreateVarNode(&eltwise_y_in_desc);
    auto* eltwise_y_in_tensor =
        scope->Var(eltwise_y_in_node->Name())->GetMutable<LoDTensor>();

    // Get batch norm bias
    auto* bn_bias_tensor =
        scope->FindVar(bn_bias->Name())->GetMutable<LoDTensor>();

    // Initialize eltwise_y to zeros (conv had no bias), shaped like bn_bias.
    eltwise_y_in_tensor->Resize(bn_bias_tensor->dims());
    std::fill_n(eltwise_y_in_tensor->mutable_data<float>(platform::CPUPlace()),
                eltwise_y_in_tensor->numel(), 0.0f);

    // update weights and biases in place with the folded BN parameters
    float epsilon = boost::get<float>(batch_norm->Op()->GetAttr("epsilon"));
    recompute_bias_and_weights(scope, conv_weight, *bn_scale, *bn_bias_tensor,
                               *bn_mean, *bn_variance, eltwise_y_in_tensor,
                               epsilon);

    // Create an elementwise add node: bn_out = conv_out + eltwise_y_in.
    OpDesc desc;
    desc.SetInput("X", std::vector<std::string>({conv_out->Name()}));
    desc.SetInput("Y", std::vector<std::string>({eltwise_y_in_node->Name()}));
    desc.SetOutput("Out", std::vector<std::string>({bn_out->Name()}));
    desc.SetType("elementwise_add");
    // axis=1: broadcast the per-channel bias along the channel dimension.
    desc.SetAttr("axis", 1);
    // Propagate the conv's use_mkldnn attribute to the new op.
    // NOTE(review): boost::get will throw if conv2d lacks "use_mkldnn" —
    // presumably every conv2d OpDesc carries it; confirm.
    bool a = boost::get<bool>(conv->Op()->GetAttr("use_mkldnn"));
    desc.SetAttr("use_mkldnn", a);
    auto eltwise_op = g->CreateOpNode(&desc);  // OpDesc will be copied.

    // Drop all BN-only nodes; conv_out and bn_out stay and are re-linked.
    GraphSafeRemoveNodes(graph.get(), {bn_scale, bn_bias, bn_mean, bn_variance,
                                       batch_norm, bn_mean_out, bn_variance_out,
                                       bn_saved_mean, bn_saved_variance});

    PADDLE_ENFORCE(subgraph.count(conv_input));
    IR_NODE_LINK_TO(conv_out, eltwise_op);
    IR_NODE_LINK_TO(eltwise_y_in_node, eltwise_op);
    IR_NODE_LINK_TO(eltwise_op, bn_out);

    found_conv_bn_count++;
  };

  gpd(graph.get(), handler);

  AddStatis(found_conv_bn_count);
  return graph;
}
| 251 | + |
// Fuses conv2d + elementwise_add (existing conv bias) + batch_norm by
// folding the BN parameters into the existing bias tensor and conv weights.
// Unlike ConvBNFusePass, no new node is created: the existing
// elementwise_add is rewired to produce bn_out and the BN nodes (plus the
// now-bypassed eltwise_out) are removed.
std::unique_ptr<ir::Graph> ConvEltwiseAddBNFusePass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  PADDLE_ENFORCE(graph.get());
  FusePassBase::Init(name_scope_, graph.get());

  // The pass needs the parameter scope to read/write weight tensors.
  auto* scope = param_scope();
  PADDLE_ENFORCE(scope);

  GraphPatternDetector gpd;
  auto* conv_input =
      gpd.mutable_pattern()
          ->NewNode(patterns::PDNodeName(name_scope_, "conv_input"))
          ->AsInput()
          ->assert_is_op_input("conv2d", "Input");
  patterns::ConvBN conv_bn_pattern(gpd.mutable_pattern(), name_scope_);
  // true: match conv2d -> elementwise_add -> batch_norm chains.
  conv_bn_pattern(conv_input, true /*with_eltwise_add*/);

  int found_conv_bn_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "handle ConvBN fuse";

    // Binds local node variables:
    // conv, batch_norm,
    // conv_weight, conv_out,
    // bn_scale, bn_bias, bn_mean, bn_variance,
    // bn_out, bn_mean_out, bn_variance_out, bn_saved_mean,bn_saved_variance
    GET_CONV_BN_NODES(conv_bn_pattern);
    // OPERATORS
    GET_IR_NODE_FROM_SUBGRAPH(eltwise, eltwise, conv_bn_pattern);
    // BIAS inputs
    GET_IR_NODE_FROM_SUBGRAPH(eltwise_y_in, eltwise_y_in, conv_bn_pattern);
    // BIAS outputs
    GET_IR_NODE_FROM_SUBGRAPH(eltwise_out, eltwise_out, conv_bn_pattern);

    // Get eltwise_y (conv bias) variable — already exists in this pattern.
    auto* eltwise_y_in_tensor =
        scope->FindVar(eltwise_y_in->Name())->GetMutable<LoDTensor>();

    // Get batch norm bias
    auto* bn_bias_tensor =
        scope->FindVar(bn_bias->Name())->GetMutable<LoDTensor>();

    // update weights and biases in place with the folded BN parameters
    float epsilon = boost::get<float>(batch_norm->Op()->GetAttr("epsilon"));
    recompute_bias_and_weights(scope, conv_weight, *bn_scale, *bn_bias_tensor,
                               *bn_mean, *bn_variance, eltwise_y_in_tensor,
                               epsilon);

    // Update the elementwise_add node: its output now replaces bn_out.
    eltwise->Op()->SetAttr("axis", 1);
    eltwise->Op()->SetOutput("Out", std::vector<std::string>({bn_out->Name()}));

    // Remove all BN nodes and the obsolete intermediate eltwise_out.
    GraphSafeRemoveNodes(
        graph.get(),
        {bn_scale, bn_bias, bn_mean, bn_variance, batch_norm, bn_mean_out,
         bn_variance_out, bn_saved_mean, bn_saved_variance, eltwise_out});

    PADDLE_ENFORCE(subgraph.count(conv_input));
    IR_NODE_LINK_TO(eltwise, bn_out);

    found_conv_bn_count++;
  };

  gpd(graph.get(), handler);

  AddStatis(found_conv_bn_count);
  return graph;
}
| 320 | + |
| 321 | +} // namespace ir |
| 322 | +} // namespace framework |
| 323 | +} // namespace paddle |
| 324 | + |
// Register both fuse passes with the global pass registry under the names
// used by inference configuration.
REGISTER_PASS(conv_bn_fuse_pass, paddle::framework::ir::ConvBNFusePass);
REGISTER_PASS(conv_eltwiseadd_bn_fuse_pass,
              paddle::framework::ir::ConvEltwiseAddBNFusePass);
0 commit comments