Skip to content

Commit 597b730

Browse files
authored
refine/fc lstm fusion link (#13158)
1 parent 1e7ccf9 commit 597b730

File tree

7 files changed

+312
-106
lines changed

7 files changed

+312
-106
lines changed

paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc

Lines changed: 105 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -13,37 +13,37 @@
1313
// limitations under the License.
1414

1515
#include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
16+
#include "paddle/fluid/framework/lod_tensor.h"
1617

1718
namespace paddle {
1819
namespace framework {
1920
namespace ir {
2021

21-
std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
22-
std::unique_ptr<ir::Graph> graph) const {
23-
GraphPatternDetector gpd;
24-
auto* pattern = gpd.mutable_pattern();
25-
26-
std::unordered_set<int> fused_ops({// first lstm
27-
13, 15, 16,
28-
// second lstm
29-
23, 25, 26});
30-
31-
pattern->NewNode([&](Node* x) { return fused_ops.count(x->id()); },
32-
"any_node");
22+
std::string GenNodeName(const std::string& prefix, const std::string& name) {
23+
return prefix + "/" + name;
24+
}
3325

34-
std::unordered_set<Node*> marked_nodes;
26+
void BuildPattern(PDPattern* pattern, const std::string& name_scope,
27+
bool with_fc_bias) {
28+
PDNode* x = pattern->NewNode(name_scope, "x")
29+
->assert_is_op_input("mul")
30+
->assert_var_not_persistable();
31+
auto* fc_out = patterns::FC(pattern, name_scope, x, with_fc_bias);
32+
fc_out->AsIntermediate(); // fc_out is a tmp var, will be removed after fuse.
33+
patterns::LSTM(pattern, name_scope, fc_out);
34+
// LOG(INFO) << "\n" << pattern->DotString();
35+
}
3536

36-
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
37-
Graph* g) {
37+
int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
38+
bool with_fc_bias) {
39+
GraphPatternDetector gpd;
40+
auto* pattern = gpd.mutable_pattern();
3841

39-
auto* id = subgraph.at(gpd.pattern().RetrieveNode("any_node"));
40-
marked_nodes.insert(id);
41-
};
42-
gpd(graph.get(), handler);
42+
BuildPattern(pattern, name_scope, with_fc_bias);
4343

4444
// Create New OpDesc
4545
auto lstm_creator = [&](int lstm, int input, int weight_x, int weight_h,
46-
int bias, int hidden, int cell, int xx) {
46+
int bias, int hidden, int cell, int xx, int fc_bias) {
4747
#define GET_NODE(x) auto* x##_n = graph->RetriveNode(x);
4848
GET_NODE(input);
4949
GET_NODE(weight_x);
@@ -61,12 +61,33 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
6161
SET_IN(WeightX, weight_x);
6262
SET_IN(WeightH, weight_h);
6363
SET_IN(Bias, bias);
64-
#undef GET_NODE
6564
#undef SET_IN
65+
if (with_fc_bias) {
66+
// Add FC-bias with LSTM-bias and create a new weight
67+
PADDLE_ENFORCE(scope);
68+
const std::string& new_bias_var = name_scope + "_bias.new";
69+
auto* bias_var = scope->Var(new_bias_var);
70+
PADDLE_ENFORCE(bias_var);
71+
auto* bias_tensor = bias_var->GetMutable<framework::LoDTensor>();
72+
auto* lstm_bias_var = scope->FindVar(bias_n->Name());
73+
PADDLE_ENFORCE(lstm_bias_var);
74+
const auto& lstm_bias_tensor = lstm_bias_var->Get<framework::LoDTensor>();
75+
bias_tensor->Resize(lstm_bias_tensor.dims());
76+
77+
GET_NODE(fc_bias);
78+
auto* fc_bias_var = scope->FindVar(fc_bias_n->Name());
79+
const auto& fc_bias_tensor = fc_bias_var->Get<framework::LoDTensor>();
80+
81+
auto* data = bias_tensor->mutable_data<float>(platform::CPUPlace());
82+
83+
for (int i = 0; i < bias_tensor->numel(); i++) {
84+
data[i] =
85+
fc_bias_tensor.data<float>()[i] + lstm_bias_tensor.data<float>()[i];
86+
}
87+
op_desc.SetInput("Bias", {new_bias_var});
88+
}
6689

67-
VLOG(4) << "hidden_n: " << hidden_n->Name();
68-
VLOG(4) << "cell: " << cell_n->Name();
69-
VLOG(4) << "xx: " << xx_n->Name();
90+
#undef GET_NODE
7091

7192
op_desc.SetInput("H0", {});
7293
op_desc.SetInput("C0", {});
@@ -76,7 +97,7 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
7697
op_desc.SetOutput("BatchedGate", {"blstm_0.tmp_2"});
7798
op_desc.SetOutput("BatchCellPreAct", {"blstm_1.tmp_2"});
7899
op_desc.SetAttr("is_reverse", lstm_n->Op()->GetAttr("is_reverse"));
79-
op_desc.SetAttr("use_peepholes", false);
100+
op_desc.SetAttr("use_peepholes", lstm_n->Op()->GetAttr("use_peepholes"));
80101
auto* op = graph->CreateOpNode(&op_desc);
81102

82103
#define LINK_TO(a, b) \
@@ -89,38 +110,77 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
89110
LINK_TO(op, hidden_n);
90111
#undef LINK_TO
91112
return op;
92-
93113
};
94114

95-
lstm_creator(16, 12, 14, 18, 17, 22, 21, 19);
96-
lstm_creator(26, 12, 24, 28, 27, 32, 31, 29);
115+
int fusion_count{0};
97116

98-
// remove all the nodes
117+
auto fc_no_bias_handler = [&](
118+
const GraphPatternDetector::subgraph_t& subgraph, Graph* g) {
99119

100-
for (auto* node : marked_nodes) {
101-
graph->RemoveNode(const_cast<Node*>(node));
102-
}
120+
#define GET_NODE(name__) \
121+
std::string name__##key = name_scope + "/" + #name__; \
122+
auto* name__##n = pattern->RetrieveNode(name__##key); \
123+
PADDLE_ENFORCE(name__##n); \
124+
PADDLE_ENFORCE(subgraph.count(name__##n)); \
125+
Node* name__##_n = subgraph.at(name__##n); \
126+
int name__ __attribute__((unused)) = name__##_n->id();
103127

104-
for (auto* node : graph->Nodes()) {
105-
for (auto it = node->inputs.begin(); it != node->inputs.end();) {
106-
if (marked_nodes.count(*it)) {
107-
it = const_cast<Node*>(node)->inputs.erase(it);
108-
} else
109-
it++;
110-
}
111-
for (auto it = node->outputs.begin(); it != node->outputs.end();) {
112-
if (marked_nodes.count(*it)) {
113-
it = const_cast<Node*>(node)->outputs.erase(it);
114-
} else
115-
it++;
128+
GET_NODE(x);
129+
GET_NODE(w);
130+
GET_NODE(mul);
131+
GET_NODE(fc_out);
132+
GET_NODE(Weight);
133+
GET_NODE(lstm);
134+
GET_NODE(Bias);
135+
GET_NODE(Hidden);
136+
GET_NODE(Cell);
137+
138+
if (with_fc_bias) {
139+
GET_NODE(fc_bias);
140+
lstm_creator(lstm, x, w, Weight, Bias, Hidden, Cell, fc_out, fc_bias);
141+
} else {
142+
lstm_creator(lstm, x, w, Weight, Bias, Hidden, Cell, fc_out, -1);
116143
}
117-
}
144+
#undef GET_NODE
145+
146+
// Remove unneeded nodes.
147+
std::unordered_set<const Node*> marked_nodes({mul_n, lstm_n});
148+
149+
GraphSafeRemoveNodes(graph, marked_nodes);
150+
151+
++fusion_count;
152+
};
153+
154+
gpd(graph, fc_no_bias_handler);
155+
156+
return fusion_count;
157+
}
158+
159+
std::unique_ptr<ir::Graph> MulLstmFusePass::ApplyImpl(
160+
std::unique_ptr<ir::Graph> graph) const {
161+
FusePassBase::Init(name_scope_, graph.get());
162+
163+
int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope(),
164+
false /*with_fc_bias*/);
165+
166+
AddStatis(fusion_count);
167+
return graph;
168+
}
169+
170+
std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
171+
std::unique_ptr<ir::Graph> graph) const {
172+
FusePassBase::Init(name_scope_, graph.get());
173+
174+
int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope(),
175+
true /*with_fc_bias*/);
118176

177+
AddStatis(fusion_count);
119178
return graph;
120179
}
121180

122181
} // namespace ir
123182
} // namespace framework
124183
} // namespace paddle
125184

185+
REGISTER_PASS(mul_lstm_fuse_pass, paddle::framework::ir::MulLstmFusePass);
126186
REGISTER_PASS(fc_lstm_fuse_pass, paddle::framework::ir::FCLstmFusePass);

paddle/fluid/framework/ir/fc_lstm_fuse_pass.h

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,34 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
1516
#include "paddle/fluid/framework/ir/graph.h"
1617
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
17-
#include "paddle/fluid/framework/ir/pass.h"
1818

1919
namespace paddle {
2020
namespace framework {
2121
namespace ir {
2222

23-
class FCLstmFusePass : public Pass {
23+
// The MulLstmFusePass and MulLstmFusePass will fuse to the same FusionLstm op.
24+
25+
// Just FC without bias
26+
class FCLstmFusePass : public FusePassBase {
2427
public:
2528
virtual ~FCLstmFusePass() {}
2629

2730
protected:
2831
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
32+
33+
const std::string name_scope_{"fc_lstm_fuse"};
34+
};
35+
36+
class MulLstmFusePass : public FusePassBase {
37+
public:
38+
virtual ~MulLstmFusePass() {}
39+
40+
protected:
41+
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
42+
const std::string name_scope_{"fc_nobias_lstm_fuse"};
2943
};
3044

3145
} // namespace ir

0 commit comments

Comments
 (0)