
Commit 7ab5626
- Added initial pass for embedding-fc-lstm
- Added draft of new operator
- Added fused embedding fc lstm files
- First time embedding_fc_lstm_fuse_pass was invoked in test_text_classification
- Added Embedding pattern
- Not crashing
- Enabled draft of embedding_fc_lstm pass (does its job)
- First working (SeqCompute only) version
- Removed diagnostic comment
- First enabling of BatchCompute
- Disabling pass for embedding with is_sparse and is_distributed
- Cosmetics
- Style
- Style
1 parent 4e81e22 commit 7ab5626

8 files changed: +976 -8 lines changed

paddle/fluid/framework/ir/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -34,6 +34,7 @@ endif()
 pass_library(attention_lstm_fuse_pass inference)
 pass_library(infer_clean_graph_pass inference)
 pass_library(fc_lstm_fuse_pass inference)
+pass_library(embedding_fc_lstm_fuse_pass inference)
 pass_library(fc_gru_fuse_pass inference)
 pass_library(seq_concat_fc_fuse_pass inference)

paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc

Lines changed: 242 additions & 0 deletions
@@ -0,0 +1,242 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.h"
#include <string>
#include "paddle/fluid/framework/lod_tensor.h"

#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/platform/cpu_info.h"

namespace paddle {
namespace framework {
namespace ir {

static int BuildFusion(Graph* graph, const std::string& name_scope,
                       Scope* scope, bool with_fc_bias) {
  GraphPatternDetector gpd;
  auto* pattern = gpd.mutable_pattern();

  // Build pattern
  PDNode* x = pattern->NewNode(patterns::PDNodeName(name_scope, "x"))
                  ->assert_is_op_input("lookup_table")
                  ->assert_var_not_persistable();
  patterns::Embedding embedding_pattern(pattern, name_scope);
  // TODO(jczaja): Intermediate can only be used for vars that are not used
  // anywhere else, but the lookup table output may go into another LSTM
  // (for the reverse direction).
  auto* embedding_out = embedding_pattern(x);
  patterns::FC fc_pattern(pattern, name_scope);

  // fc_out is a tmp var, will be removed after fuse, so marked as intermediate.
  auto* fc_out = fc_pattern(embedding_out, with_fc_bias)->AsIntermediate();
  patterns::LSTM lstm_pattern(pattern, name_scope);
  lstm_pattern(fc_out);

  // Create new OpDesc
  auto embedding_lstm_creator = [&](Node* embedding, Node* W, Node* lstm,
                                    Node* input, Node* weight_x, Node* weight_h,
                                    Node* bias, Node* hidden, Node* cell,
                                    Node* xx, Node* fc_bias) {
    OpDesc op_desc;
    op_desc.SetType("fused_embedding_fc_lstm");
#define SET_IN(Key, node__) op_desc.SetInput(#Key, {node__->Name()});
    SET_IN(Ids, input);
    SET_IN(WeightH, weight_h);
    // Bias needs to be passed as well, since we need its Wc data for the
    // peephole connections.
    SET_IN(Bias, bias);
#undef SET_IN

    // Multiply embeddings with Weights
    PADDLE_ENFORCE(scope);
    const std::string& embeddings = patterns::UniqueKey("Embeddings");
    auto* embeddings_var = scope->Var(embeddings);
    PADDLE_ENFORCE(embeddings_var);
    auto* embeddings_tensor =
        embeddings_var->GetMutable<framework::LoDTensor>();
    // Get the WeightX size [single_embedding, fc_size] and the embedding size
    // [dict_size, single_embedding], then size the new embeddings as
    // [dict_size, hidden_size].
    auto* embedding_var = scope->FindVar(W->Name());
    PADDLE_ENFORCE(embedding_var);
    const auto& embedding_tensor = embedding_var->Get<framework::LoDTensor>();

    const auto& weightx_tensor =
        scope->FindVar(weight_x->Name())->Get<framework::LoDTensor>();
    embeddings_tensor->Resize(
        {embedding_tensor.dims()[0], weightx_tensor.dims()[1]});

    // Multiply embeddings by WeightX and add the bias
    auto embedding_data = embedding_tensor.data<float>();
    auto weightx_data = weightx_tensor.data<float>();
    auto embeddings_data =
        embeddings_tensor->mutable_data<float>(platform::CPUPlace());

    // Biases to be folded into the upcoming GEMM result
    auto* lstm_bias_var = scope->FindVar(bias->Name());
    PADDLE_ENFORCE(lstm_bias_var);
    const auto& lstm_bias_tensor = lstm_bias_var->Get<framework::LoDTensor>();

    auto alpha = 1.0f;
    auto beta = 1.0f;
    int m = embedding_tensor.dims()[0];
    int n = weightx_tensor.dims()[1];
    int k = embedding_tensor.dims()[1];

    // Copy only the gate bias values (the actual bias data, not the peephole
    // weights)
    std::vector<float> combined_biases(n, 0.0f);
    memcpy(&combined_biases[0], lstm_bias_tensor.data<float>(),
           n * sizeof(float));

    if (with_fc_bias) {
      // Add the FC bias to the LSTM bias (into the GEMM result to be)
      auto* fc_bias_var = scope->FindVar(fc_bias->Name());
      const auto& fc_bias_tensor = fc_bias_var->Get<framework::LoDTensor>();
      for (int i = 0; i < fc_bias_tensor.numel(); i++) {
        combined_biases[i] =
            lstm_bias_tensor.data<float>()[i] + fc_bias_tensor.data<float>()[i];
      }
    }

    // Broadcast the biases
    std::vector<float> ones(m, 1.0f);
    paddle::operators::math::CBlas<float>::GEMM(
        CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, 1, alpha, &ones[0], 1,
        &combined_biases[0], n, 0.0f, embeddings_data, n);

    // Wx*embeddings
    paddle::operators::math::CBlas<float>::GEMM(
        CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, alpha,
        embedding_data, k, weightx_data, n, beta, embeddings_data, n);
    op_desc.SetInput("Embeddings", {embeddings});

    // Create temp variables.
    const std::string BatchedInput = patterns::UniqueKey("BatchedInput");
    const std::string BatchedCellPreAct =
        patterns::UniqueKey("BatchedCellPreAct");
    const std::string BatchedGate = patterns::UniqueKey("BatchedGate");

    scope->Var(BatchedInput)->GetMutable<framework::LoDTensor>();
    scope->Var(BatchedCellPreAct)->GetMutable<framework::LoDTensor>();
    scope->Var(BatchedGate)->GetMutable<framework::LoDTensor>();

    op_desc.SetInput("H0", {});
    op_desc.SetInput("C0", {});
    op_desc.SetOutput("Hidden", {hidden->Name()});
    op_desc.SetOutput("Cell", {cell->Name()});
    op_desc.SetOutput("XX", {xx->Name()});
    op_desc.SetOutput("BatchedGate", {BatchedGate});
    op_desc.SetOutput("BatchCellPreAct", {BatchedCellPreAct});
    op_desc.SetOutput("BatchedInput", {BatchedInput});
    op_desc.SetAttr("is_reverse", lstm->Op()->GetAttr("is_reverse"));
    op_desc.SetAttr("use_peepholes", lstm->Op()->GetAttr("use_peepholes"));
    // TODO(TJ): get from attr
    op_desc.SetAttr("use_seq", true);

    PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
    auto* scope = graph->Get<Scope*>(kParamScopeAttr);
#define OP_SET_OUT(x)                            \
  const std::string x = patterns::UniqueKey(#x); \
  op_desc.SetOutput(#x, {x});                    \
  scope->Var(x)->GetMutable<LoDTensor>()
    OP_SET_OUT(BatchedCell);
    OP_SET_OUT(BatchedHidden);
    OP_SET_OUT(ReorderedH0);
    OP_SET_OUT(ReorderedC0);
#undef OP_SET_OUT

    auto* op = graph->CreateOpNode(&op_desc);
    IR_NODE_LINK_TO(input, op);
    IR_NODE_LINK_TO(weight_x, op);
    IR_NODE_LINK_TO(weight_h, op);
    IR_NODE_LINK_TO(bias, op);
    IR_NODE_LINK_TO(op, hidden);
    return op;
  };

  int fusion_count{0};

  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    GET_IR_NODE_FROM_SUBGRAPH(lstm, lstm, lstm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(Weight, Weight, lstm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(Bias, Bias, lstm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(Cell, Cell, lstm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(Hidden, Hidden, lstm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(lookup_table, lookup_table, embedding_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(W, W, embedding_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul, mul, fc_pattern);

    // TODO(jczaja): Add support for is_sparse / is_distributed
    auto is_sparse = boost::get<bool>(lookup_table->Op()->GetAttr("is_sparse"));
    auto is_distributed =
        boost::get<bool>(lookup_table->Op()->GetAttr("is_distributed"));

    if (is_sparse == true || is_distributed == true) {
      return;
    }

    if (with_fc_bias) {
      GET_IR_NODE_FROM_SUBGRAPH(fc_out, Out, fc_pattern);
      GET_IR_NODE_FROM_SUBGRAPH(fc_bias, bias, fc_pattern);
      GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern);
      embedding_lstm_creator(lookup_table, W, lstm, subgraph.at(x), w, Weight,
                             Bias, Hidden, Cell, fc_out, fc_bias);
      // Remove unneeded nodes.
      // TODO(jczaja): Proper removal of the lookup table
      std::unordered_set<const Node*> marked_nodes(
          // {lookup_table, mul, lstm, elementwise_add, fc_bias, W});
          {mul, lstm, elementwise_add, fc_bias});
      GraphSafeRemoveNodes(graph, marked_nodes);
    } else {
      GET_IR_NODE_FROM_SUBGRAPH(fc_out, mul_out, fc_pattern);
      embedding_lstm_creator(lookup_table, W, lstm, subgraph.at(x), w, Weight,
                             Bias, Hidden, Cell, fc_out, nullptr);
      // Remove unneeded nodes.
      // TODO(jczaja): Proper removal of the lookup table
      // std::unordered_set<const Node*> marked_nodes({lookup_table, W, mul,
      // lstm});
      std::unordered_set<const Node*> marked_nodes({mul, lstm});
      GraphSafeRemoveNodes(graph, marked_nodes);
    }

    ++fusion_count;
  };

  gpd(graph, handler);

  return fusion_count;
}

std::unique_ptr<ir::Graph> EmbeddingFCLSTMFusePass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  FusePassBase::Init(name_scope_, graph.get());

  int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope(),
                                 true /*with_fc_bias*/);

  AddStatis(fusion_count);
  return graph;
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(embedding_fc_lstm_fuse_pass,
              paddle::framework::ir::EmbeddingFCLSTMFusePass);
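For reference, the weight folding done inside embedding_lstm_creator above amounts to a single precomputation over the whole lookup table. A sketch of the math implied by the two CBlas GEMM calls (the symbol names below are mine, not from the code):

\[ \mathrm{Embeddings} = E\,W_x + \mathbf{1}_m\, b^{\top}, \qquad E \in \mathbb{R}^{m \times k},\; W_x \in \mathbb{R}^{k \times n},\; b \in \mathbb{R}^{n} \]

Here E is the lookup_table weight (m = dict_size, k = single_embedding), W_x is the FC/LSTM input projection, and b is the combined gate bias (the LSTM gate bias plus the FC bias when with_fc_bias is set). The first GEMM writes the rank-one broadcast of b into embeddings_data with beta = 0.0f; the second accumulates E * W_x on top of it with beta = 1.0f. The resulting Embeddings tensor lets the fused_embedding_fc_lstm op replace lookup_table + mul (+ elementwise_add) with a plain row lookup.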
paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.h

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"

namespace paddle {
namespace framework {
namespace ir {

// Fuses the Embedding (lookup_table), FC and LSTM operators into a single
// fused_embedding_fc_lstm operator.
class EmbeddingFCLSTMFusePass : public FusePassBase {
 public:
  virtual ~EmbeddingFCLSTMFusePass() {}

 protected:
  std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;

  const std::string name_scope_{"embedding_fc_lstm_fuse"};
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
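As a usage note, the class above is only ever driven through the pass registry. Below is a minimal sketch of how this pass would be exercised, in the style of the existing fuse-pass testers in this tree; the PassRegistry/Apply calls are assumptions about the surrounding codebase at this revision, not part of this commit:

#include "paddle/fluid/framework/ir/pass.h"

USE_PASS(embedding_fc_lstm_fuse_pass);

// graph: std::unique_ptr<paddle::framework::ir::Graph> built from a program
// containing the lookup_table -> mul (-> elementwise_add) -> lstm chain.
// The graph must carry the parameter scope (kParamScopeAttr), since the pass
// reads the lookup_table and FC/LSTM weights through param_scope().
auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
    "embedding_fc_lstm_fuse_pass");
graph = pass->Apply(std::move(graph));  // runs ApplyImpl, which calls BuildFusion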

paddle/fluid/framework/ir/graph_pattern_detector.cc

Lines changed: 18 additions & 0 deletions
@@ -692,6 +692,24 @@ PDNode *patterns::FC::operator()(paddle::framework::ir::PDNode *x,
   }
 }
 
+PDNode *patterns::Embedding::operator()(PDNode *x) {
+  x->assert_is_op_input("lookup_table", "Ids");
+  auto *lookup_table_op =
+      pattern->NewNode(lookup_table_repr())->assert_is_op("lookup_table");
+#define NEW_NODE(arg__, io__)                    \
+  auto *arg__ = pattern->NewNode(arg__##_repr()) \
+                    ->assert_is_op_##io__("lookup_table", #arg__);
+
+  NEW_NODE(W, input);
+
+  NEW_NODE(Out, output);
+#undef NEW_NODE
+
+  lookup_table_op->LinksFrom({x, W});
+  lookup_table_op->LinksTo({Out});
+  return Out;
+}
+
 PDNode *patterns::LSTM::operator()(PDNode *x) {
   x->assert_is_op_input("lstm", "Input");
   auto *lstm_op = pattern->NewNode(lstm_repr())->assert_is_op("lstm");
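For clarity, the two NEW_NODE invocations above expand to roughly the following (a sketch of the macro expansion; it uses only PDNode asserts already present in this file, and the *_repr() accessors come from the PATTERN_DECL_NODE declarations in the header below):

auto *W = pattern->NewNode(W_repr())
              ->assert_is_op_input("lookup_table", "W");
auto *Out = pattern->NewNode(Out_repr())
                ->assert_is_op_output("lookup_table", "Out");

So the pattern matches a lookup_table op whose Ids input is the non-persistable x node, whose W input is the embedding table, and whose Out output is returned so it can be chained into the FC pattern.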

paddle/fluid/framework/ir/graph_pattern_detector.h

Lines changed: 17 additions & 0 deletions
@@ -418,6 +418,23 @@ struct FC : public PatternBase {
   PATTERN_DECL_NODE(Out);
 };
 
+// Embedding
+struct Embedding : public PatternBase {
+  Embedding(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "embedding") {}
+
+  PDNode* operator()(PDNode* x);
+
+  // declare operator node's name
+  PATTERN_DECL_NODE(lookup_table);
+  // Inputs
+  PATTERN_DECL_NODE(Ids);
+  PATTERN_DECL_NODE(W);  // embeddings
+  // Outputs
+  PATTERN_DECL_NODE(Out);
+};
+
 struct LSTM : public PatternBase {
   LSTM(PDPattern* pattern, const std::string& name_scope)
       : PatternBase(pattern, name_scope, "lstm") {}
paddle/fluid/inference/analysis/analyzer.h

Lines changed: 9 additions & 8 deletions
@@ -64,14 +64,15 @@ class Analyzer : public OrderedRegistry<PassManager> {
   // larger fusion.
   const std::vector<std::string> all_ir_passes_{{
       // Manual update the passes here.
-      "infer_clean_graph_pass",       //
-      "attention_lstm_fuse_pass",     //
-      "fc_lstm_fuse_pass",            //
-      "mul_lstm_fuse_pass",           //
-      "fc_gru_fuse_pass",             //
-      "mul_gru_fuse_pass",            //
-      "seq_concat_fc_fuse_pass",      //
-      "fc_fuse_pass",                 //
+      "infer_clean_graph_pass",         //
+      "attention_lstm_fuse_pass",       //
+      "embedding_fc_lstm_fuse_pass",    //
+      "fc_lstm_fuse_pass",              //
+      "mul_lstm_fuse_pass",             //
+      "fc_gru_fuse_pass",               //
+      "mul_gru_fuse_pass",              //
+      "seq_concat_fc_fuse_pass",        //
+      "fc_fuse_pass",                   //
 #ifdef PADDLE_WITH_MKLDNN
       "conv_relu_mkldnn_fuse_pass",  //
 #endif

Note that embedding_fc_lstm_fuse_pass is deliberately inserted before fc_lstm_fuse_pass: the analyzer applies these IR passes in the order listed (larger fusions first), so the embedding + FC + LSTM subgraph can be matched before fc_lstm_fuse_pass consumes its FC + LSTM part.