Skip to content

Commit af15f6f

Browse files
authored
fea/refine fuse (#13076)
1 parent 819af27 commit af15f6f

20 files changed

+545
-226
lines changed

paddle/fluid/framework/ir/attention_lstm_fuse_pass.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ void FindWhileOp(Graph* graph) {
5959

6060
auto handle = [&](const GraphPatternDetector::subgraph_t& subgraph,
6161
Graph* g) {
62-
auto* while_pat_node = gpd.pattern().RetriveNode("while");
62+
auto* while_pat_node = gpd.pattern().RetrieveNode("while");
6363
auto* while_node = subgraph.at(while_pat_node);
6464
marked_nodes.insert(while_node);
6565
};

paddle/fluid/framework/ir/fc_fuse_pass.cc

Lines changed: 37 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -31,77 +31,34 @@ bool VarOutLinksToOp(Node* node, const std::string& op_type) {
3131
}
3232

3333
void BuildFCPattern(PDPattern* pattern) {
34-
// make sure the selected MUL op has one input argument is a parameter.
35-
auto* mul_parameter_var = pattern->NewNode(
36-
[](Node* node) {
37-
return node->IsVar() && node->outputs.size() == 1UL &&
38-
node->outputs.front()->Op()->Type() == "mul" && node->Var() &&
39-
node->Var()->Persistable(); // check is a parameter
40-
},
41-
"mul_weight" /*name*/);
42-
43-
auto* mul_tmp_input_var = pattern->NewNode(
44-
[](Node* node) {
45-
bool result =
46-
node->IsVar() && node->outputs.size() >= 1UL && node->Var() &&
47-
!node->Var()->Persistable(); // this input is not an parameter.
48-
if (!result) return false;
49-
// check whether one output is MUL op.
50-
for (auto* op : node->outputs) {
51-
if (op->IsOp() && op->Op()->Type() == "mul") return true;
52-
}
53-
return false;
54-
},
55-
"mul_tmp_var" /*name*/);
56-
57-
// select a MUL op
58-
auto* mul_op = pattern->NewNode(
59-
[](Node* node) {
60-
return node->IsOp() && // start from an Op
61-
node->Op()->Type() == "mul"; // type is mul
62-
// the output should be consumed only by one element_add, that check
63-
// leaves in a Var PDNode.
64-
},
65-
"mul" /*name*/);
66-
67-
// make sure the MUL op's output has only one consumer and links to an
68-
// ELEMENTWISE_ADD op.
69-
auto* mul_out_var = pattern->NewNode(
70-
[](Node* node) {
71-
return node->IsVar() && // starts from a Var
72-
node->outputs.size() == 1UL && // only has one consumer
73-
node->outputs.front()->IsOp() && // check basic logic
74-
node->Var() && // not a ControlDepVar
75-
node->outputs.front()->Op()->Type() ==
76-
"elementwise_add"; // a very strong validation
77-
},
78-
"mul_out");
79-
// this check is not essential, just to make the corresponding variable Node
80-
// retrival easier.
81-
auto* elementwise_add_tmp_var = pattern->NewNode(
82-
[](Node* node) {
83-
return node->IsVar() && node->outputs.size() >= 1UL && node->Var() &&
84-
VarOutLinksToOp(node, "elementwise_add");
85-
},
86-
"elementwise_add_tmpvar");
87-
88-
// select an ELEMENTWISE_ADD op
89-
auto* elementwise_add_op = pattern->NewNode(
90-
[](Node* node) {
91-
return node->IsOp() && node->Op()->Type() == "elementwise_add";
92-
},
93-
"elementwise_add" /*name*/);
94-
95-
// get the ELEMENTWISE_ADD op's output
96-
auto* elementwise_add_out_var = pattern->NewNode(
97-
[](Node* node) {
98-
return node->IsVar() && node->inputs.size() == 1UL && node->Var() &&
99-
node->inputs.front()->Op()->Type() == "elementwise_add";
100-
},
101-
"elementwise_add_out");
102-
103-
mul_op->LinksFrom({mul_parameter_var, mul_tmp_input_var})
104-
.LinksTo({mul_out_var});
34+
// Create Operators
35+
auto* mul_op = pattern->NewNode("mul")->assert_is_op("mul");
36+
auto* elementwise_add_op =
37+
pattern->NewNode("elementwise_add")->assert_is_op("elementwise_add");
38+
// Create variables
39+
// w
40+
auto* mul_weight_var = pattern->NewNode("mul_weight")
41+
->AsInput()
42+
->assert_is_op_nth_input("mul", "Y", 0);
43+
// x
44+
auto* mul_tmp_var = pattern->NewNode("mul_tmp_var")
45+
->AsInput()
46+
->assert_is_op_nth_input("mul", "X", 0);
47+
// intermediate variable, will be removed in the IR after fuse.
48+
auto* mul_out_var = pattern->NewNode("mul_out")
49+
->AsIntermediate()
50+
->assert_is_only_output_of_op("mul")
51+
->assert_is_op_input("elementwise_add");
52+
// bias
53+
auto* elementwise_add_tmp_var = pattern->NewNode("elementwise_add_tmpvar")
54+
->assert_is_op_input("elementwise_add")
55+
->AsInput();
56+
// output
57+
auto* elementwise_add_out_var = pattern->NewNode("elementwise_add_out")
58+
->AsOutput()
59+
->assert_is_op_output("elementwise_add");
60+
61+
mul_op->LinksFrom({mul_weight_var, mul_tmp_var}).LinksTo({mul_out_var});
10562
elementwise_add_op->LinksFrom({mul_out_var, elementwise_add_tmp_var})
10663
.LinksTo({elementwise_add_out_var});
10764
}
@@ -120,18 +77,20 @@ bool LinksReplace(std::vector<Node*>* links, Node* from, Node* to) {
12077
std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
12178
std::unique_ptr<ir::Graph> graph) const {
12279
PADDLE_ENFORCE(graph.get());
80+
FusePassBase::Init("fc", graph.get());
12381

12482
std::unordered_set<Node*> nodes2delete;
12583

12684
GraphPatternDetector gpd;
12785
BuildFCPattern(gpd.mutable_pattern());
12886

129-
#define GET_NODE(id) \
130-
PADDLE_ENFORCE(subgraph.count(gpd.pattern().RetriveNode(#id)), \
131-
"pattern has no Node called %s", #id); \
132-
auto* id = subgraph.at(gpd.pattern().RetriveNode(#id)); \
87+
#define GET_NODE(id) \
88+
PADDLE_ENFORCE(subgraph.count(gpd.pattern().RetrieveNode(#id)), \
89+
"pattern has no Node called %s", #id); \
90+
auto* id = subgraph.at(gpd.pattern().RetrieveNode(#id)); \
13391
PADDLE_ENFORCE_NOT_NULL(id, "subgraph has no node %s", #id);
13492

93+
int found_fc_count = 0;
13594
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
13695
Graph* g) {
13796
VLOG(4) << "handle FC fuse";
@@ -176,10 +135,13 @@ std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
176135
graph->RemoveNode(mul);
177136
graph->RemoveNode(elementwise_add);
178137
graph->RemoveNode(mul_out); // tmp variable
138+
139+
found_fc_count++;
179140
};
180141

181142
gpd(graph.get(), handler);
182143

144+
AddStatis(found_fc_count);
183145
return graph;
184146
}
185147

paddle/fluid/framework/ir/fc_fuse_pass.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
1516
#include "paddle/fluid/framework/ir/graph.h"
1617
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
1718
#include "paddle/fluid/framework/ir/pass.h"
@@ -23,7 +24,7 @@ namespace ir {
2324
/*
2425
* Fuse the MUL and ELEMENTWISE_ADD to a FCOp.
2526
*/
26-
class FCFusePass : public Pass {
27+
class FCFusePass : public FusePassBase {
2728
public:
2829
virtual ~FCFusePass() {}
2930

paddle/fluid/framework/ir/fc_fuse_pass_tester.cc

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,13 @@ void SetOp(ProgramDesc* prog, const std::string& type,
2525
const std::vector<std::string>& outputs) {
2626
auto* op = prog->MutableBlock(0)->AppendOp();
2727
op->SetType(type);
28-
op->SetInput("Xs", inputs);
29-
op->SetOutput("Ys", outputs);
28+
if (type == "mul") {
29+
op->SetInput("X", {inputs[0]});
30+
op->SetInput("Y", {inputs[1]});
31+
} else if (type == "elementwise_add") {
32+
op->SetInput("X", inputs);
33+
}
34+
op->SetOutput("Out", outputs);
3035
}
3136

3237
// a->OP0->b

paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
3636
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
3737
Graph* g) {
3838

39-
auto* id = subgraph.at(gpd.pattern().RetriveNode("any_node"));
39+
auto* id = subgraph.at(gpd.pattern().RetrieveNode("any_node"));
4040
marked_nodes.insert(id);
4141
};
4242
gpd(graph.get(), handler);
@@ -64,9 +64,9 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
6464
#undef GET_NODE
6565
#undef SET_IN
6666

67-
LOG(INFO) << "hidden_n: " << hidden_n->Name();
68-
LOG(INFO) << "cell: " << cell_n->Name();
69-
LOG(INFO) << "xx: " << xx_n->Name();
67+
VLOG(4) << "hidden_n: " << hidden_n->Name();
68+
VLOG(4) << "cell: " << cell_n->Name();
69+
VLOG(4) << "xx: " << xx_n->Name();
7070

7171
op_desc.SetInput("H0", {});
7272
op_desc.SetInput("C0", {});

paddle/fluid/framework/ir/fuse_pass_base.h

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,21 +22,37 @@ namespace paddle {
2222
namespace framework {
2323
namespace ir {
2424

25-
static const char kParamScopeAttr[] = "param_scope";
25+
static const char kParamScopeAttr[] = "__param_scope__";
26+
static const char kFuseStatisAttr[] = "__fuse_statis__";
2627

2728
class FusePassBase : public Pass {
2829
public:
29-
void Init(Graph* graph) const { graph_ = graph; }
30+
void Init(const std::string& repr, Graph* graph) const {
31+
repr_ = repr;
32+
graph_ = graph;
33+
}
3034

3135
Scope* param_scope() const {
3236
PADDLE_ENFORCE(graph_->Has(kParamScopeAttr));
3337
return graph_->Get<framework::Scope*>(kParamScopeAttr);
3438
}
3539

40+
void AddStatis(int count_of_fused) const {
41+
PADDLE_ENFORCE(graph_);
42+
PADDLE_ENFORCE(!repr_.empty());
43+
if (!graph_->Has(kFuseStatisAttr)) {
44+
graph_->Set(kFuseStatisAttr, new std::unordered_map<std::string, int>);
45+
}
46+
auto& info =
47+
graph_->Get<std::unordered_map<std::string, int>>(kFuseStatisAttr);
48+
info[repr_] = count_of_fused;
49+
}
50+
3651
virtual ~FusePassBase() {}
3752

3853
protected:
3954
mutable Graph* graph_;
55+
mutable std::string repr_;
4056
};
4157

4258
} // namespace ir

0 commit comments

Comments
 (0)