// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/fuse_multi_transformer_layer_pass.h"

#include <string>

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"

namespace paddle {
namespace framework {
class Scope;
}  // namespace framework
}  // namespace paddle

namespace paddle {
namespace framework {
namespace ir {
namespace patterns {

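// Builds the pattern for a stack of `num_fused_op` fused_multi_transformer
// (or fused_multi_transformer_int8) ops that share one SrcMask input and are
// chained through their Out -> X edges. Returns a map from readable keys such
// as "fuse_op_0" / "out_0" to the generated PDNode repr strings, so that the
// fusion handler can retrieve the matched nodes by name.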
std::unordered_map<std::string, std::string>
MultiTransformerLayerPattern::operator()(bool enable_int8,
                                         int num_fused_op,
                                         bool is_decoder) {
  std::string fused_multi_transformer_name =
      enable_int8 ? "fused_multi_transformer_int8" : "fused_multi_transformer";

  std::unordered_map<std::string, std::string> node_reprs;

  // x0 and src_mask are the unique inputs of the subgraph
  auto* x0 = pattern->NewNode(x0_repr());
  x0->assert_is_op_input(fused_multi_transformer_name, "X")->AsInput();
  auto* src_mask = pattern->NewNode(src_mask_repr());
  src_mask->assert_is_op_input(fused_multi_transformer_name, "SrcMask")
      ->AsInput();

  for (int i = 0; i < num_fused_op; ++i) {
    auto fuse_op_repr =
        PDNodeName(name_scope_, repr_, id_, "fuse_op_" + std::to_string(i));
    node_reprs["fuse_op_" + std::to_string(i)] = fuse_op_repr;
    auto* fused_multi_transformer =
        pattern->NewNode(fuse_op_repr)
            ->assert_is_op(fused_multi_transformer_name);

    auto out_repr =
        PDNodeName(name_scope_, repr_, id_, "out_" + std::to_string(i));
    node_reprs["out_" + std::to_string(i)] = out_repr;
    auto* out = pattern->NewNode(out_repr)->assert_is_op_output(
        fused_multi_transformer_name, "Out");

    if (is_decoder) {
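      // Decoder pattern: SrcMask additionally feeds a shape -> slice chain
      // whose output is an extra input of the fused op (presumably the
      // decoding time step), so those nodes are matched here as well.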
      auto shape_repr =
          PDNodeName(name_scope_, repr_, id_, "shape_" + std::to_string(i));
      node_reprs["shape_" + std::to_string(i)] = shape_repr;
      auto* shape = pattern->NewNode(shape_repr)->assert_is_op("shape");

      auto shape_out_repr =
          PDNodeName(name_scope_, repr_, id_, "shape_out_" + std::to_string(i));
      node_reprs["shape_out_" + std::to_string(i)] = shape_out_repr;
      auto* shape_out =
          pattern->NewNode(shape_out_repr)->assert_is_op_output("shape", "Out");

      shape->LinksFrom({src_mask}).LinksTo({shape_out});

      auto slice_repr =
          PDNodeName(name_scope_, repr_, id_, "slice_" + std::to_string(i));
      node_reprs["slice_" + std::to_string(i)] = slice_repr;
      auto* slice = pattern->NewNode(slice_repr)->assert_is_op("slice");

      auto slice_out_repr =
          PDNodeName(name_scope_, repr_, id_, "slice_out_" + std::to_string(i));
      node_reprs["slice_out_" + std::to_string(i)] = slice_out_repr;
      auto* slice_out =
          pattern->NewNode(slice_out_repr)->assert_is_op_output("slice", "Out");

      slice->LinksFrom({shape_out}).LinksTo({slice_out});

      fused_multi_transformer->LinksFrom({x0, src_mask, slice_out})
          .LinksTo({out});
    } else {
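      // Encoder pattern: each layer's CacheKV input is produced by a
      // fill_constant_batch_size_like op that takes x0 as its input.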
      auto cache_kv_repr =
          PDNodeName(name_scope_, repr_, id_, "cache_kv_" + std::to_string(i));
      node_reprs["cache_kv_" + std::to_string(i)] = cache_kv_repr;
      auto* cache_kv = pattern->NewNode(cache_kv_repr);
      cache_kv->assert_is_op_input(fused_multi_transformer_name, "CacheKV");
      cache_kv->AsInput();

      auto fill_const_op_repr =
          PDNodeName(name_scope_, repr_, id_, "fill_op_" + std::to_string(i));
      node_reprs["fill_op_" + std::to_string(i)] = fill_const_op_repr;
      auto fill_const_op = pattern->NewNode(fill_const_op_repr)
                               ->assert_is_op("fill_constant_batch_size_like");

      fused_multi_transformer->LinksFrom({x0, src_mask, cache_kv})
          .LinksTo({out});
      fill_const_op->LinksFrom({x0}).LinksTo({cache_kv});
    }
    x0 = out;
  }
  x0->AsOutput();
  return node_reprs;
}
}  // namespace patterns

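// Concatenates the variable lists of input `input_name` from all fused ops
// (in layer order) and sets the result as that input of `op`, i.e. the first
// fused_multi_transformer op, which survives the fusion.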
inline void MergeInput(OpDesc* op,
                       const std::vector<VariableNameMap>& input_name_maps,
                       const std::string& input_name) {
  std::vector<std::string> tmp = input_name_maps[0].at(input_name);
  for (size_t i = 1; i < input_name_maps.size(); ++i) {
    tmp.insert(tmp.end(),
               input_name_maps[i].at(input_name).begin(),
               input_name_maps[i].at(input_name).end());
  }
  op->SetInput(input_name, tmp);
}

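// Concatenates a std::vector<T> attribute `attr_name` across all given ops
// and stores the merged vector back on the first op.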
template <typename T>
inline void MergeAttrs(const std::vector<OpDesc*>& ops,
                       const std::string& attr_name) {
  std::vector<T> res;
  for (size_t i = 0; i < ops.size(); ++i) {
    auto scale_vec =
        PADDLE_GET_CONST(std::vector<T>, ops[i]->GetAttr(attr_name));
    res.insert(res.end(), scale_vec.begin(), scale_vec.end());
  }
  ops[0]->SetAttr(attr_name, res);
}

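// Matches the chain of per-layer fused_multi_transformer ops built by
// MultiTransformerLayerPattern and collapses it into a single op: the first
// op keeps the merged weight/bias/cache inputs and outputs, while the
// remaining ops and intermediate nodes are removed. Returns the number of
// fused subgraphs.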
int FuseMultiTransformerLayerPass::BuildFusion(Graph* graph,
                                               const std::string& name_scope,
                                               Scope* scope) const {
  GraphPatternDetector gpd;
  auto* pattern = gpd.mutable_pattern();

  // TODO(wufeisheng): Get enable_int8 attr from graph after
  // fused_multi_transformer pass with int8 merged
  bool enable_int8 = false;

  int num_fuse_op = 0;
  bool is_decoder = false;

  if (graph->Has(kFusedMultiTransformerEncoderFusionCount)) {
    num_fuse_op = graph->Get<int>(kFusedMultiTransformerEncoderFusionCount);
    is_decoder = false;
  } else if (graph->Has(kFusedMultiTransformerDecoderFusionCount)) {
    num_fuse_op = graph->Get<int>(kFusedMultiTransformerDecoderFusionCount);
    is_decoder = true;
  }
  if (num_fuse_op == 0) {
    VLOG(4) << "fuse_multi_transformer_layer_pass will be skipped "
               "because num_fuse_op has not been set or is set to 0";
    return 0;
  }
  if (!is_decoder) {
    VLOG(4) << "fuse_multi_transformer_layer_pass will match encoder pattern";
  } else {
    VLOG(4) << "fuse_multi_transformer_layer_pass will match decoder pattern";
  }

  patterns::MultiTransformerLayerPattern multi_layer_pattern(pattern,
                                                             name_scope);
  auto node_reprs = multi_layer_pattern(enable_int8, num_fuse_op, is_decoder);

  int fusion_count{0};
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* graph) {
    ///////////////////
    //// Get nodes ////
    ///////////////////

    GET_IR_NODE_FROM_SUBGRAPH(src_mask, src_mask, multi_layer_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(x0, x0, multi_layer_pattern);

    std::vector<Node*> fuse_op_nodes;
    std::vector<Node*> out_nodes;

    std::vector<std::string> unused_node_prefixes = {
        "shape_", "shape_out_", "slice_", "slice_out_"};
    std::vector<Node*> unused_nodes;

    std::vector<OpDesc*> fuse_op_descs;
    std::vector<VariableNameMap> fuse_op_input_var_name_maps;
    std::vector<VariableNameMap> fuse_op_output_var_name_maps;

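    // Collect the op, output and auxiliary nodes of every layer. For the
    // encoder, the fill_constant_batch_size_like ops of layers > 0 are
    // rewired to take x0 (the input of the whole chain) instead of the
    // previous layer's output; for the decoder, the per-layer shape/slice
    // nodes are recorded so they can be removed later.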
    for (int i = 0; i < num_fuse_op; ++i) {
      PDNode* fuse_op_pdnode =
          multi_layer_pattern.PatternBase::pattern->RetrieveNode(
              node_reprs["fuse_op_" + std::to_string(i)]);
      Node* fuse_op_node = subgraph.at(fuse_op_pdnode);
      fuse_op_nodes.push_back(fuse_op_node);
      fuse_op_descs.push_back(fuse_op_node->Op());
      fuse_op_input_var_name_maps.emplace_back(fuse_op_node->Op()->Inputs());
      fuse_op_output_var_name_maps.emplace_back(fuse_op_node->Op()->Outputs());

      PDNode* out_pdnode =
          multi_layer_pattern.PatternBase::pattern->RetrieveNode(
              node_reprs["out_" + std::to_string(i)]);
      out_nodes.push_back(subgraph.at(out_pdnode));

      // Make the fill_const op use x0 as its input
      if (!is_decoder && i != 0) {
        PDNode* fill_op_pdnode =
            multi_layer_pattern.PatternBase::pattern->RetrieveNode(
                node_reprs["fill_op_" + std::to_string(i)]);
        Node* fill_op_node = subgraph.at(fill_op_pdnode);
        fill_op_node->Op()->SetInput("Input", {x0->Name()});
        IR_NODE_UNLINK(out_nodes[i - 1], fill_op_node);
        IR_NODE_LINK_TO(x0, fill_op_node);
      } else if (is_decoder && i != 0) {
        for (const auto& unused_node_prefix : unused_node_prefixes) {
          PDNode* unused_pdnode =
              multi_layer_pattern.PatternBase::pattern->RetrieveNode(
                  node_reprs[unused_node_prefix + std::to_string(i)]);
          Node* unused_node = subgraph.at(unused_pdnode);
          unused_nodes.push_back(unused_node);
        }
      }
    }

    ///////////////
    //// Merge ////
    ///////////////

    // Merge inputs
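    // The per-layer weight, bias and cache variables of all matched ops are
    // concatenated, in layer order, into the corresponding multi-variable
    // inputs of the first op.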
    std::vector<std::string> inputs_names = {"CacheKV",
                                             "FFN1Bias",
                                             "FFN1Weight",
                                             "FFN2Bias",
                                             "FFN2Weight",
                                             "FFNLnBias",
                                             "FFNLnScale",
                                             "LnBias",
                                             "LnScale",
                                             "OutLinearBias",
                                             "OutLinearW",
                                             "QKVBias",
                                             "QKVW"};

    for (const auto& input_name : inputs_names) {
      MergeInput(fuse_op_descs[0], fuse_op_input_var_name_maps, input_name);
    }

    // Merge outputs
    fuse_op_descs[0]->SetOutput(
        "Out", fuse_op_output_var_name_maps[num_fuse_op - 1]["Out"]);
    auto& merged_cache_kv_out_names =
        fuse_op_output_var_name_maps[0]["CacheKVOut"];
    for (int i = 1; i < num_fuse_op; ++i) {
      const auto& out_var_names = fuse_op_output_var_name_maps[i]["CacheKVOut"];
      merged_cache_kv_out_names.insert(merged_cache_kv_out_names.end(),
                                       out_var_names.begin(),
                                       out_var_names.end());
    }
    fuse_op_descs[0]->SetOutput("CacheKVOut", merged_cache_kv_out_names);

    ////////////////
    //// ReLink ////
    ////////////////
    // Before relinking, all out nodes except the last one should be removed
    std::unordered_set<const Node*> marked_out_nodes(out_nodes.begin(),
                                                     out_nodes.end() - 1);
    GraphSafeRemoveNodes(graph, marked_out_nodes);

    // Relink all input nodes of fused_multi_transformer ops to the first op
    auto& merged_inputs = fuse_op_nodes[0]->inputs;
    for (int i = 1; i < num_fuse_op; ++i) {
      merged_inputs.insert(merged_inputs.end(),
                           fuse_op_nodes[i]->inputs.begin(),
                           fuse_op_nodes[i]->inputs.end());
    }

    // Relink fuse op -> out
    IR_NODE_UNLINK(fuse_op_nodes[num_fuse_op - 1], out_nodes[num_fuse_op - 1]);
    IR_NODE_LINK_TO(fuse_op_nodes[0], out_nodes[num_fuse_op - 1]);

    /////////////////////////////
    //// Delete unused nodes ////
    /////////////////////////////
    // Delete all fused_multi_transformer ops except for the first one
    std::unordered_set<const Node*> marked_fuse_op_nodes(
        fuse_op_nodes.begin() + 1, fuse_op_nodes.end());

    if (is_decoder) {
      marked_fuse_op_nodes.insert(unused_nodes.begin(), unused_nodes.end());
    }

    GraphSafeRemoveNodes(graph, marked_fuse_op_nodes);
    ++fusion_count;
  };

  gpd(graph, handler);
  return fusion_count;
}

void FuseMultiTransformerLayerPass::ApplyImpl(Graph* graph) const {
  FusePassBase::Init(name_scope_, graph);
  auto* scope = param_scope();
  PADDLE_ENFORCE_NOT_NULL(
      scope,
      platform::errors::Fatal("During the fuse_multi_transformer_layer pass, "
                              "the scope should not be null."));
  int fusion_count = BuildFusion(graph, name_scope_, scope);

  AddStatis(fusion_count);
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(fuse_multi_transformer_layer_pass,
              paddle::framework::ir::FuseMultiTransformerLayerPass);