Commit ede4b23

Merge pull request #13553 from jczaja/prv-fused_embedding_fc_lstm_op
Adding fused_embedding_fc_lstm op
2 parents: 618b329 + e202f33

10 files changed: +987 -9 lines changed

paddle/fluid/framework/ir/CMakeLists.txt

Lines changed: 1 addition & 0 deletions

@@ -34,6 +34,7 @@ endif ()
 pass_library(attention_lstm_fuse_pass inference)
 pass_library(infer_clean_graph_pass inference)
 pass_library(fc_lstm_fuse_pass inference)
+pass_library(embedding_fc_lstm_fuse_pass inference)
 pass_library(fc_gru_fuse_pass inference)
 pass_library(seq_concat_fc_fuse_pass inference)

paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.cc

Lines changed: 243 additions & 0 deletions

@@ -0,0 +1,243 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.h"
#include <algorithm>
#include <string>
#include "paddle/fluid/framework/lod_tensor.h"

#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/cpu_vec.h"
#include "paddle/fluid/operators/math/fc_compute.h"
#include "paddle/fluid/platform/cpu_info.h"

namespace paddle {
namespace framework {
namespace ir {

static int BuildFusion(Graph* graph, const std::string& name_scope,
                       Scope* scope, bool with_fc_bias) {
  GraphPatternDetector gpd;
  auto* pattern = gpd.mutable_pattern();

  // Build pattern
  PDNode* x = pattern->NewNode(patterns::PDNodeName(name_scope, "x"))
                  ->assert_is_op_input("lookup_table")
                  ->assert_var_not_persistable();
  patterns::Embedding embedding_pattern(pattern, name_scope);
  // TODO(jczaja): Intermediate can only be used for vals that are not used
  //               anywhere else, but the lookup table output may go into
  //               another LSTM (for the reverse direction).
  auto* embedding_out = embedding_pattern(x);
  patterns::FC fc_pattern(pattern, name_scope);

  // fc_out is a tmp var, will be removed after fuse, so marked as intermediate.
  auto* fc_out = fc_pattern(embedding_out, with_fc_bias)->AsIntermediate();
  patterns::LSTM lstm_pattern(pattern, name_scope);
  lstm_pattern(fc_out);

  // Create New OpDesc
  auto embedding_lstm_creator = [&](Node* embedding, Node* W, Node* lstm,
                                    Node* input, Node* weight_x, Node* weight_h,
                                    Node* bias, Node* hidden, Node* cell,
                                    Node* xx, Node* fc_bias) {
    OpDesc op_desc;
    op_desc.SetType("fused_embedding_fc_lstm");
#define SET_IN(Key, node__) op_desc.SetInput(#Key, {node__->Name()});
    SET_IN(Ids, input);
    SET_IN(WeightH, weight_h);
    // Need to have this passed in, as we need the Wc data for peephole
    // connections
    SET_IN(Bias, bias);
#undef SET_IN

    // Multiply embeddings with Weights
    PADDLE_ENFORCE(scope);
    const std::string& embeddings = patterns::UniqueKey("Embeddings");
    auto* embeddings_var = scope->Var(embeddings);
    PADDLE_ENFORCE(embeddings_var);
    auto* embeddings_tensor =
        embeddings_var->GetMutable<framework::LoDTensor>();
    // Get WeightX size: [single_embedding, fc_size]
    // and embedding size: [dict_size, single_embedding]
    // and create the new embeddings size, e.g. [dict_size, hidden_size]
    auto* embedding_var = scope->FindVar(W->Name());
    PADDLE_ENFORCE(embedding_var);
    const auto& embedding_tensor = embedding_var->Get<framework::LoDTensor>();

    const auto& weightx_tensor =
        scope->FindVar(weight_x->Name())->Get<framework::LoDTensor>();
    embeddings_tensor->Resize(
        {embedding_tensor.dims()[0], weightx_tensor.dims()[1]});

    // Multiply embeddings by WeightX and add bias
    auto embedding_data = embedding_tensor.data<float>();
    auto weightx_data = weightx_tensor.data<float>();
    auto embeddings_data =
        embeddings_tensor->mutable_data<float>(platform::CPUPlace());

    // Add biases into the upcoming GEMM result
    auto* lstm_bias_var = scope->FindVar(bias->Name());
    PADDLE_ENFORCE(lstm_bias_var);
    const auto& lstm_bias_tensor = lstm_bias_var->Get<framework::LoDTensor>();

    auto alpha = 1.0f;
    auto beta = 1.0f;
    int m = embedding_tensor.dims()[0];
    int n = weightx_tensor.dims()[1];
    int k = embedding_tensor.dims()[1];

    // Copy only the gate bias values (only actual bias data, not peephole
    // weights)
    std::vector<float> combined_biases;
    combined_biases.reserve(n);
    std::copy_n(lstm_bias_tensor.data<float>(), n,
                std::back_inserter(combined_biases));

    if (with_fc_bias) {
      // Add FC-bias with LSTM-bias (into GEMM result to be)
      auto* fc_bias_var = scope->FindVar(fc_bias->Name());
      const auto& fc_bias_tensor = fc_bias_var->Get<framework::LoDTensor>();
      for (int i = 0; i < fc_bias_tensor.numel(); i++) {
        combined_biases[i] += fc_bias_tensor.data<float>()[i];
      }
    }

    // Broadcast biases
    std::vector<float> ones(m, 1.0f);
    paddle::operators::math::CBlas<float>::GEMM(
        CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, 1, alpha, &ones[0], 1,
        &combined_biases[0], n, 0.0f, embeddings_data, n);

    // Wx*embeddings + biases
    paddle::operators::math::CBlas<float>::GEMM(
        CblasRowMajor, CblasNoTrans, CblasNoTrans, m, n, k, alpha,
        embedding_data, k, weightx_data, n, beta, embeddings_data, n);
    op_desc.SetInput("Embeddings", {embeddings});

    // Create temp variables.
    const std::string BatchedInput = patterns::UniqueKey("BatchedInput");
    const std::string BatchedCellPreAct =
        patterns::UniqueKey("BatchedCellPreAct");
    const std::string BatchedGate = patterns::UniqueKey("BatchedGate");

    scope->Var(BatchedInput)->GetMutable<framework::LoDTensor>();
    scope->Var(BatchedCellPreAct)->GetMutable<framework::LoDTensor>();
    scope->Var(BatchedGate)->GetMutable<framework::LoDTensor>();

    op_desc.SetInput("H0", {});
    op_desc.SetInput("C0", {});
    op_desc.SetOutput("Hidden", {hidden->Name()});
    op_desc.SetOutput("Cell", {cell->Name()});
    op_desc.SetOutput("XX", {xx->Name()});
    op_desc.SetOutput("BatchedGate", {BatchedGate});
    op_desc.SetOutput("BatchCellPreAct", {BatchedCellPreAct});
    op_desc.SetOutput("BatchedInput", {BatchedInput});
    op_desc.SetAttr("is_reverse", lstm->Op()->GetAttr("is_reverse"));
    op_desc.SetAttr("use_peepholes", lstm->Op()->GetAttr("use_peepholes"));
    // TODO(TJ): get from attr
    op_desc.SetAttr("use_seq", true);

    PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
    auto* scope = graph->Get<Scope*>(kParamScopeAttr);
#define OP_SET_OUT(x)                            \
  const std::string x = patterns::UniqueKey(#x); \
  op_desc.SetOutput(#x, {x});                    \
  scope->Var(x)->GetMutable<LoDTensor>()
    OP_SET_OUT(BatchedCell);
    OP_SET_OUT(BatchedHidden);
    OP_SET_OUT(ReorderedH0);
    OP_SET_OUT(ReorderedC0);
#undef OP_SET_OUT

    auto* op = graph->CreateOpNode(&op_desc);
    IR_NODE_LINK_TO(input, op);
    IR_NODE_LINK_TO(weight_x, op);
    IR_NODE_LINK_TO(weight_h, op);
    IR_NODE_LINK_TO(bias, op);
    IR_NODE_LINK_TO(op, hidden);
    return op;
  };

  int fusion_count{0};

  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    GET_IR_NODE_FROM_SUBGRAPH(lstm, lstm, lstm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(Weight, Weight, lstm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(Bias, Bias, lstm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(Cell, Cell, lstm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(Hidden, Hidden, lstm_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(lookup_table, lookup_table, embedding_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(W, W, embedding_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(w, w, fc_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(mul, mul, fc_pattern);

    // TODO(jczaja): Add support for is_sparse / is_distributed
    auto is_sparse = boost::get<bool>(lookup_table->Op()->GetAttr("is_sparse"));
    auto is_distributed =
        boost::get<bool>(lookup_table->Op()->GetAttr("is_distributed"));

    if (is_sparse == true || is_distributed == true) {
      return;
    }

    if (with_fc_bias) {
      GET_IR_NODE_FROM_SUBGRAPH(fc_out, Out, fc_pattern);
      GET_IR_NODE_FROM_SUBGRAPH(fc_bias, bias, fc_pattern);
      GET_IR_NODE_FROM_SUBGRAPH(elementwise_add, elementwise_add, fc_pattern);
      embedding_lstm_creator(lookup_table, W, lstm, subgraph.at(x), w, Weight,
                             Bias, Hidden, Cell, fc_out, fc_bias);
      // Remove unneeded nodes.
      // TODO(jczaja): Proper removing of lookup table
      std::unordered_set<const Node*> marked_nodes(
          //{lookup_table, mul, lstm, elementwise_add, fc_bias, W});
          {mul, lstm, elementwise_add, fc_bias});
      GraphSafeRemoveNodes(graph, marked_nodes);
    } else {
      GET_IR_NODE_FROM_SUBGRAPH(fc_out, mul_out, fc_pattern);
      embedding_lstm_creator(lookup_table, W, lstm, subgraph.at(x), w, Weight,
                             Bias, Hidden, Cell, fc_out, nullptr);
      // Remove unneeded nodes.
      // TODO(jczaja): Proper removing of lookup table
      // std::unordered_set<const Node*> marked_nodes({lookup_table, W, mul,
      // lstm});
      std::unordered_set<const Node*> marked_nodes({mul, lstm});
      GraphSafeRemoveNodes(graph, marked_nodes);
    }

    ++fusion_count;
  };

  gpd(graph, handler);

  return fusion_count;
}

std::unique_ptr<ir::Graph> EmbeddingFCLSTMFusePass::ApplyImpl(
    std::unique_ptr<ir::Graph> graph) const {
  FusePassBase::Init(name_scope_, graph.get());

  int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope(),
                                 true /*with_fc_bias*/);

  AddStatis(fusion_count);
  return graph;
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(embedding_fc_lstm_fuse_pass,
              paddle::framework::ir::EmbeddingFCLSTMFusePass);
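
The essence of the new pass is the constant folding shown above: the lookup_table weights, the FC WeightX, the FC bias and the LSTM gate biases are collapsed into a single precomputed "Embeddings" table, so at inference time the fused op only performs a lookup followed by the LSTM. A standalone sketch of that arithmetic (hypothetical function name, plain loops in place of the CBlas GEMM calls used above; an illustration of the math, not the pass itself):

#include <vector>

// Sketch only: fold an [m x k] embedding table and a [k x n] WeightX matrix,
// plus a length-n combined (FC + LSTM gate) bias, into a new [m x n] table:
//   new_embeddings[i][j] = bias[j] + sum_p embedding[i][p] * weight_x[p][j]
std::vector<float> FoldEmbeddingFC(const std::vector<float>& embedding,  // m*k
                                   const std::vector<float>& weight_x,   // k*n
                                   const std::vector<float>& bias,       // n
                                   int m, int k, int n) {
  std::vector<float> out(static_cast<size_t>(m) * n);
  for (int i = 0; i < m; ++i) {
    for (int j = 0; j < n; ++j) {
      float acc = bias[j];  // bias broadcast over every dictionary row
      for (int p = 0; p < k; ++p) {
        acc += embedding[i * k + p] * weight_x[p * n + j];
      }
      out[i * n + j] = acc;
    }
  }
  return out;
}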

paddle/fluid/framework/ir/embedding_fc_lstm_fuse_pass.h

Lines changed: 40 additions & 0 deletions

@@ -0,0 +1,40 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"

namespace paddle {
namespace framework {
namespace ir {

// Fusing of the Embedding, FC and LSTM ops.

// Just FC without bias
class EmbeddingFCLSTMFusePass : public FusePassBase {
 public:
  virtual ~EmbeddingFCLSTMFusePass() {}

 protected:
  std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;

  const std::string name_scope_{"embedding_fc_lstm_fuse"};
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle

paddle/fluid/framework/ir/graph_pattern_detector.cc

Lines changed: 18 additions & 0 deletions

@@ -692,6 +692,24 @@ PDNode *patterns::FC::operator()(paddle::framework::ir::PDNode *x,
   }
 }
 
+PDNode *patterns::Embedding::operator()(PDNode *x) {
+  x->assert_is_op_input("lookup_table", "Ids");
+  auto *lookup_table_op =
+      pattern->NewNode(lookup_table_repr())->assert_is_op("lookup_table");
+#define NEW_NODE(arg__, io__)                    \
+  auto *arg__ = pattern->NewNode(arg__##_repr()) \
+                    ->assert_is_op_##io__("lookup_table", #arg__);
+
+  NEW_NODE(W, input);
+
+  NEW_NODE(Out, output);
+#undef NEW_NODE
+
+  lookup_table_op->LinksFrom({x, W});
+  lookup_table_op->LinksTo({Out});
+  return Out;
+}
+
 PDNode *patterns::LSTM::operator()(PDNode *x) {
   x->assert_is_op_input("lstm", "Input");
   auto *lstm_op = pattern->NewNode(lstm_repr())->assert_is_op("lstm");

paddle/fluid/framework/ir/graph_pattern_detector.h

Lines changed: 17 additions & 0 deletions

@@ -418,6 +418,23 @@ struct FC : public PatternBase {
   PATTERN_DECL_NODE(Out);
 };
 
+// Embedding
+struct Embedding : public PatternBase {
+  Embedding(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "embedding") {}
+
+  PDNode* operator()(PDNode* x);
+
+  // declare operator node's name
+  PATTERN_DECL_NODE(lookup_table);
+  // Inputs
+  PATTERN_DECL_NODE(Ids);
+  PATTERN_DECL_NODE(W);  // embeddings
+  // Outputs
+  PATTERN_DECL_NODE(Out);
+};
+
 struct LSTM : public PatternBase {
   LSTM(PDPattern* pattern, const std::string& name_scope)
       : PatternBase(pattern, name_scope, "lstm") {}

paddle/fluid/inference/analysis/analyzer.h

Lines changed: 9 additions & 8 deletions

@@ -64,14 +64,15 @@ class Analyzer : public OrderedRegistry<PassManager> {
   // larger fusion.
   const std::vector<std::string> all_ir_passes_{{
       // Manual update the passes here.
-      "infer_clean_graph_pass",      //
-      "attention_lstm_fuse_pass",    //
-      "fc_lstm_fuse_pass",           //
-      "mul_lstm_fuse_pass",          //
-      "fc_gru_fuse_pass",            //
-      "mul_gru_fuse_pass",           //
-      "seq_concat_fc_fuse_pass",     //
-      "fc_fuse_pass",                //
+      "infer_clean_graph_pass",       //
+      "attention_lstm_fuse_pass",     //
+      "embedding_fc_lstm_fuse_pass",  //
+      "fc_lstm_fuse_pass",            //
+      "mul_lstm_fuse_pass",           //
+      "fc_gru_fuse_pass",             //
+      "mul_gru_fuse_pass",            //
+      "seq_concat_fc_fuse_pass",      //
+      "fc_fuse_pass",                 //
 #ifdef PADDLE_WITH_MKLDNN
       "conv_relu_mkldnn_fuse_pass",  //
 #endif

paddle/fluid/inference/api/paddle_inference_api.h

Lines changed: 1 addition & 1 deletion

@@ -263,7 +263,7 @@ struct AnalysisConfig : public NativeConfig {
   bool enable_ir_optim = true;
   // Manually determine the IR passes to run.
   IrPassMode ir_mode{IrPassMode::kExclude};
-  std::vector<std::string> ir_passes;
+  std::vector<std::string> ir_passes{"embedding_fc_lstm_fuse_pass"};
 
   // NOT stable yet.
   bool use_feed_fetch_ops{true};

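Note: because ir_mode above defaults to IrPassMode::kExclude, seeding ir_passes with the new pass name keeps embedding_fc_lstm_fuse_pass disabled unless a caller erases it from the exclusion list. A minimal opt-in sketch (hypothetical helper name, assuming the AnalysisConfig declared above), mirroring what the tester change below does:

#include <algorithm>
#include <string>

#include "paddle/fluid/inference/api/paddle_inference_api.h"

// Hypothetical helper: removing the name from the kExclude list re-enables
// embedding_fc_lstm_fuse_pass for this config.
void EnableEmbeddingFcLstmFuse(paddle::AnalysisConfig* cfg) {
  auto it = std::find(cfg->ir_passes.begin(), cfg->ir_passes.end(),
                      "embedding_fc_lstm_fuse_pass");
  if (it != cfg->ir_passes.end()) cfg->ir_passes.erase(it);
}
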
paddle/fluid/inference/tests/api/analyzer_text_classification_tester.cc

Lines changed: 13 additions & 0 deletions

@@ -104,5 +104,18 @@ TEST(Analyzer_Text_Classification, compare) {
   CompareNativeAndAnalysis(cfg, input_slots_all);
 }
 
+TEST(Analyzer_Text_Classification, compare_against_embedding_fc_lstm_fused) {
+  AnalysisConfig cfg;
+  SetConfig(&cfg);
+  // Enable embedding_fc_lstm_fuse_pass (disabled by default)
+  auto it = std::find(cfg.ir_passes.begin(), cfg.ir_passes.end(),
+                      "embedding_fc_lstm_fuse_pass");
+  if (it != cfg.ir_passes.end()) cfg.ir_passes.erase(it);
+
+  std::vector<std::vector<PaddleTensor>> input_slots_all;
+  SetInput(&input_slots_all);
+  CompareNativeAndAnalysis(cfg, input_slots_all);
+}
+
 }  // namespace inference
 }  // namespace paddle
