// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/attention_lstm_fuse_pass.h"

#include <array>
#include <string>
#include <unordered_set>

#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/inference/api/helper.h"
namespace paddle {
namespace framework {
namespace ir {

// Names of the variables the fused attention_lstm op reads and writes: the
// pre-trained parameters already present in the scope, plus the new
// intermediate tensors created by this pass.
struct Param {
  std::string X = "concat_0.tmp_0";
  std::string C0 = "cell_init";
  std::string H0 = "hidden_init";
  std::string AttentionWeight = "attention_fc.w_0";
  std::string AttentionBias = "attention_fc.b_0";
  std::string AttentionScalar = "attention_output.w_0";
  std::string AttentionScalarBias = "attention_output.b_0";
  std::string LSTMWeight = "attention_w.new";
  std::string LSTMBias = "attention_b.new";
  std::string Hidden = "array_to_lod_tensor_0.tmp_0";
  std::string Cell = "at.cell.new";
  std::string AttentionedX = "at.x.new";
  std::string AttentionFCOut = "at.fc.new";
  std::string LSTMX = "at.lstmx.new";
  std::string LSTMOUT = "at.lstmout.new";
};

void PrepareParameters(Graph* graph, const Param& param);
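// Locate the while op that implements the attention LSTM loop (matched by a
// set of node ids hard-coded for the captured model graph), build a single
// fused attention_lstm op in its place, and mark the original nodes for
// removal.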
void FindWhileOp(Graph* graph) {
  GraphPatternDetector gpd;
  std::unordered_set<int> fused_external_ops(
      {35, 36, 37, 38, 43, 44, 49, 45, 46, 47, 41, 42, 53, 54, 48,
       57, 55, 56, 52, 74, 80, 77, 78, 79, 50, 77, 39, 40, 51});

  gpd.mutable_pattern()->NewNode(
      [&](Node* n) { return fused_external_ops.count(n->id()); }, "while");

  if (!graph->Has(kGraphvizMarkedNodeAttr)) {
    graph->Set(kGraphvizMarkedNodeAttr, new GraphVizPass::marked_nodes_t);
  }
  auto& marked_nodes =
      graph->Get<GraphVizPass::marked_nodes_t>(kGraphvizMarkedNodeAttr);

  auto handle = [&](const GraphPatternDetector::subgraph_t& subgraph,
                    Graph* g) {
    auto* while_pat_node = gpd.pattern().RetriveNode("while");
    auto* while_node = subgraph.at(while_pat_node);
    marked_nodes.insert(while_node);
  };
  gpd(graph, handle);

  Param param;
  // Add AttentionLSTM node
  OpDesc op_desc;
  op_desc.SetType("attention_lstm");
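// Helper macros: bind each Param field to the attention_lstm op's input or
// output slot of the same name.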
#define OP_SET_IN(x) op_desc.SetInput(#x, {param.x});
#define OP_SET_OUT(x) op_desc.SetOutput(#x, {param.x});
  OP_SET_IN(X);
  OP_SET_IN(C0);
  OP_SET_IN(H0);
  OP_SET_IN(AttentionWeight);
  OP_SET_IN(AttentionBias);
  OP_SET_IN(AttentionScalar);
  OP_SET_IN(AttentionScalarBias);
  OP_SET_IN(LSTMWeight);
  OP_SET_IN(LSTMBias);

  OP_SET_OUT(Hidden);
  OP_SET_OUT(Cell);
  OP_SET_OUT(AttentionedX);
  OP_SET_OUT(AttentionFCOut);
  OP_SET_OUT(LSTMX);
  OP_SET_OUT(LSTMOUT);
#undef OP_SET_IN
#undef OP_SET_OUT
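  // These numeric ids refer to concrete variable nodes in the captured program
  // (hard-coded for the specific model this pass targets): the concatenated
  // input X, the final LSTM output, and the cell / hidden initial states.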
  auto* X = graph->RetriveNode(34);
  auto* LSTMOUT = graph->RetriveNode(81);
  auto* cell_init = graph->RetriveNode(6);
  auto* hidden_init = graph->RetriveNode(8);

// Link node0 -> node1 by updating both adjacency lists.
#define LINK_TO(node0, node1)          \
  do {                                 \
    node0->outputs.push_back(node1);   \
    node1->inputs.push_back(node0);    \
  } while (0)

  auto* lstm_op = graph->CreateOpNode(&op_desc);
  PrepareParameters(graph, param);

  LINK_TO(X, lstm_op);
  LINK_TO(cell_init, lstm_op);
  LINK_TO(hidden_init, lstm_op);
  LINK_TO(lstm_op, LSTMOUT);

  GraphSafeRemoveNodes(graph, marked_nodes);
}
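// Null-pointer checks for one to five pointers at a time.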
#define CHECK_P1(x) PADDLE_ENFORCE_NOT_NULL(x);
#define CHECK_P2(x0, x1) \
  CHECK_P1(x0);          \
  CHECK_P1(x1);
#define CHECK_P3(x0, x1, x2) \
  CHECK_P2(x0, x1);          \
  CHECK_P1(x2);
#define CHECK_P4(x0, x1, x2, x3) \
  CHECK_P3(x0, x1, x2);          \
  CHECK_P1(x3);
#define CHECK_P5(x0, x1, x2, x3, x4) \
  CHECK_P4(x0, x1, x2, x3);          \
  CHECK_P1(x4);

void PrepareLSTMWeight(const LoDTensor& W_forget_w0,
                       const LoDTensor& W_forget_w1,
                       const LoDTensor& W_input_w0, const LoDTensor& W_input_w1,
                       const LoDTensor& W_output_w0,
                       const LoDTensor& W_output_w1, const LoDTensor& W_cell_w0,
                       const LoDTensor& W_cell_w1, LoDTensor* out);

void PrepareLSTMBias(const LoDTensor& B_forget, const LoDTensor& B_input,
                     const LoDTensor& B_output, const LoDTensor& B_cell,
                     LoDTensor* out);
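// Create the fused op's new variables in the parameter scope, reshape the
// attention biases into row vectors, and pack the four per-gate weights and
// biases into the single LSTMWeight / LSTMBias tensors read by attention_lstm.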
void PrepareParameters(Graph* graph, const Param& param) {
  // Check parameters
  PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
  auto* scope = graph->Get<Scope*>(kParamScopeAttr);

  // Create new parameters.
  scope->Var(param.LSTMWeight)->GetMutable<LoDTensor>();
  scope->Var(param.LSTMBias)->GetMutable<LoDTensor>();
  scope->Var(param.Hidden)->GetMutable<LoDTensor>();
  scope->Var(param.Cell)->GetMutable<LoDTensor>();
  scope->Var(param.AttentionedX)->GetMutable<LoDTensor>();
  scope->Var(param.AttentionFCOut)->GetMutable<LoDTensor>();
  scope->Var(param.LSTMX)->GetMutable<LoDTensor>();
  scope->Var(param.LSTMOUT)->GetMutable<LoDTensor>();
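// For gate <name__>, look up the pre-trained parameters <name__>.w_0,
// <name__>.w_1 and <name__>.b_0 in the scope, check that they exist, and bind
// const references to their tensors for the packing calls below.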
#define GATE_W(name__)                                                 \
  auto* W_##name__##_w0 = scope->FindVar(#name__ ".w_0");              \
  auto* W_##name__##_w1 = scope->FindVar(#name__ ".w_1");              \
  auto* W_##name__##_b0 = scope->FindVar(#name__ ".b_0");              \
  CHECK_P3(W_##name__##_w0, W_##name__##_w1, W_##name__##_b0);         \
  VLOG(4) << #name__ "_w0"                                             \
          << " shape: " << W_##name__##_w0->Get<LoDTensor>().dims();   \
  VLOG(4) << #name__ "_w1"                                             \
          << " shape: " << W_##name__##_w1->Get<LoDTensor>().dims();   \
  VLOG(4) << #name__ "_b0"                                             \
          << " shape: " << W_##name__##_b0->Get<LoDTensor>().dims();   \
  auto& W_##name__##_w0_t = W_##name__##_w0->Get<LoDTensor>();         \
  auto& W_##name__##_w1_t = W_##name__##_w1->Get<LoDTensor>();         \
  auto& W_##name__##_b0_t = W_##name__##_b0->Get<LoDTensor>();

  GATE_W(forget);
  GATE_W(input);
  GATE_W(output);
  GATE_W(c);
#undef GATE_W

  auto* attention_fc_w = scope->FindVar("attention_fc.w_0");
  auto* attention_fc_b = scope->FindVar("attention_fc.b_0");
  auto* attention_output_w = scope->FindVar("attention_output.w_0");
  auto* attention_output_b = scope->FindVar("attention_output.b_0");
  CHECK_P4(attention_fc_w, attention_fc_b, attention_output_w,
           attention_output_b);

  auto* lstm_weight = scope->Var(param.LSTMWeight);
  auto* lstm_weight_t = lstm_weight->GetMutable<LoDTensor>();
  auto* lstm_bias = scope->Var(param.LSTMBias);
  auto* lstm_bias_t = lstm_bias->GetMutable<LoDTensor>();

  // Reshape the 1-D attention biases into 1 x N row vectors.
  auto* attention_bias_t =
      scope->FindVar(param.AttentionBias)->GetMutable<LoDTensor>();
  PADDLE_ENFORCE_EQ(attention_bias_t->dims().size(), 1);
  attention_bias_t->Resize(make_ddim({1, attention_bias_t->dims()[0]}));

  auto* attention_scalar_bias_t =
      scope->FindVar(param.AttentionScalarBias)->GetMutable<LoDTensor>();
  attention_scalar_bias_t->Resize(
      make_ddim({1, attention_scalar_bias_t->dims()[0]}));

  PrepareLSTMWeight(W_forget_w0_t, W_forget_w1_t, W_input_w0_t, W_input_w1_t,
                    W_output_w0_t, W_output_w1_t, W_c_w0_t, W_c_w1_t,
                    lstm_weight_t);
  PrepareLSTMBias(W_forget_b0_t, W_input_b0_t, W_output_b0_t, W_c_b0_t,
                  lstm_bias_t);
}
// Pack the per-gate weight matrices (w0 and w1 for each of the four gates)
// into a single (D + M) x (4 * D) LSTM weight tensor.
void PrepareLSTMWeight(const LoDTensor& W_forget_w0,
                       const LoDTensor& W_forget_w1,
                       const LoDTensor& W_input_w0, const LoDTensor& W_input_w1,
                       const LoDTensor& W_output_w0,
                       const LoDTensor& W_output_w1, const LoDTensor& W_cell_w0,
                       const LoDTensor& W_cell_w1, LoDTensor* out) {
  int D = W_forget_w0.dims()[0];
  int M = W_forget_w1.dims()[0];
  out->Resize(make_ddim({D + M, 4 * D}));
  VLOG(3) << "LSTMWeight resized to " << out->dims();

  float* out_data = out->mutable_data<float>(platform::CPUPlace());
  std::array<const float*, 4> tensors{
      {W_forget_w0.data<float>(), W_input_w0.data<float>(),
       W_output_w0.data<float>(), W_cell_w0.data<float>()}};
  std::array<const float*, 4> tensors1{
      {W_forget_w1.data<float>(), W_input_w1.data<float>(),
       W_output_w1.data<float>(), W_cell_w1.data<float>()}};
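  // Row layout of `out`: rows [0, D) are filled from the *_w0 matrices and
  // rows [D, D + M) from the *_w1 matrices; within each row, four D-wide
  // column blocks hold the forget, input, output and cell gate weights, in
  // that order.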
  for (int row = 0; row < D; row++) {
    for (int col = 0; col < 4; col++) {
      float* dst = out_data + 4 * D * row + D * col;
      const float* src = tensors[col] + D * row;
      memcpy(dst, src, D * sizeof(float));
    }
  }

  for (int row = 0; row < M; row++) {
    for (int col = 0; col < 4; col++) {
      float* dst = out_data + 4 * D * (D + row) + D * col;
      const float* src = tensors1[col] + D * row;
      memcpy(dst, src, D * sizeof(float));
    }
  }
}
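// Concatenate the four 1-D gate biases into a single 1 x (4 * D) row vector,
// using the same forget / input / output / cell order as the packed weights.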
void PrepareLSTMBias(const LoDTensor& B_forget, const LoDTensor& B_input,
                     const LoDTensor& B_output, const LoDTensor& B_cell,
                     LoDTensor* out) {
  std::array<const float*, 4> tensors{
      {B_forget.data<float>(), B_input.data<float>(), B_output.data<float>(),
       B_cell.data<float>()}};

  PADDLE_ENFORCE_EQ(B_forget.dims().size(), 1);
  int D = B_forget.dims()[0];
  out->Resize(make_ddim({1, 4 * D}));
  auto* out_data = out->mutable_data<float>(platform::CPUPlace());
  for (size_t i = 0; i < tensors.size(); i++) {
    memcpy(out_data + D * i, tensors[i], D * sizeof(float));
  }
}
// Entry point of the pass: locate the while-based attention LSTM subgraph and
// replace it with a single fused attention_lstm op.
std::unique_ptr<ir::Graph> AttentionLSTMFusePass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  FindWhileOp(graph.get());
  return graph;
}
}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(attention_lstm_fuse_pass,
              paddle::framework::ir::AttentionLSTMFusePass);
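// The pass is registered under the name "attention_lstm_fuse_pass"; an
// inference pipeline can enable the fusion by adding that pass name to its
// pass list (the exact wiring depends on how the analysis passes are
// configured).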