@@ -20,52 +20,43 @@ namespace paddle {
20
20
namespace framework {
21
21
namespace ir {
22
22
23
- static void BuildPattern (PDPattern* pattern, const std::string& name_scope,
24
- bool with_fc_bias) {
25
- PDNode* x = pattern->NewNode (name_scope, " x" )
26
- ->assert_is_op_input (" mul" )
27
- ->assert_var_not_persistable ();
28
- auto * fc_out = patterns::FC (pattern, name_scope, x, with_fc_bias);
29
- fc_out->AsIntermediate (); // fc_out is a tmp var, will be removed after fuse.
30
- patterns::GRU (pattern, name_scope, fc_out);
31
- VLOG (3 ) << " fc_gru pattern \n " << pattern->DotString ();
32
- }
33
-
34
23
static int BuildFusion (Graph* graph, const std::string& name_scope,
35
24
Scope* scope, bool with_fc_bias) {
36
25
GraphPatternDetector gpd;
37
26
auto * pattern = gpd.mutable_pattern ();
38
27
39
- BuildPattern (pattern, name_scope, with_fc_bias);
28
+ // Create pattern.
29
+ patterns::FC fc_pattern (pattern, name_scope);
30
+ patterns::GRU gru_pattern (pattern, name_scope);
31
+
32
+ PDNode* x =
33
+ pattern->NewNode (patterns::UniqueKey (" x" ))->assert_var_not_persistable ();
34
+
35
+ auto * fc_out = fc_pattern (x, with_fc_bias);
36
+ fc_out->AsIntermediate (); // fc_out is a tmp var, will be removed after fuse.
37
+ gru_pattern (fc_out);
40
38
41
39
// Create New OpDesc
42
- auto gru_creater = [&](int gru, int x, int weight_x, int weight_h, int bias,
43
- int hidden, int fc_bias) {
44
- #define GET_NODE (x ) auto * x##_n = graph->RetriveNode (x);
45
- GET_NODE (x);
46
- GET_NODE (weight_x);
47
- GET_NODE (weight_h);
48
- GET_NODE (bias);
49
- GET_NODE (hidden);
50
- GET_NODE (gru);
40
+ auto gru_creater = [&](Node* gru, Node* x, Node* weight_x, Node* weight_h,
41
+ Node* bias, Node* hidden, Node* fc_bias) {
51
42
52
43
OpDesc op_desc;
53
44
op_desc.SetType (" fusion_gru" );
54
45
55
46
#define NEW_NAME (x ) name_scope + " /at." #x " .new"
56
- #define SET_IN (Key, node__ ) op_desc.SetInput(#Key, {node__##_n ->Name ()});
47
+ #define SET_IN (Key, node__ ) op_desc.SetInput(#Key, {node__->Name ()});
57
48
SET_IN (X, x);
58
49
SET_IN (WeightX, weight_x);
59
50
SET_IN (WeightH, weight_h);
60
51
if (with_fc_bias) {
61
- op_desc.SetInput (" Bias" , {NEW_NAME (bias) + bias_n ->Name ()});
52
+ op_desc.SetInput (" Bias" , {NEW_NAME (bias) + bias ->Name ()});
62
53
} else {
63
54
SET_IN (Bias, bias);
64
55
}
65
56
#undef SET_IN
66
57
op_desc.SetInput (" H0" , {});
67
- op_desc.SetOutput (" Hidden" , {hidden_n ->Name ()});
68
- op_desc.SetAttr (" is_reverse" , gru_n ->Op ()->GetAttr (" is_reverse" ));
58
+ op_desc.SetOutput (" Hidden" , {hidden ->Name ()});
59
+ op_desc.SetAttr (" is_reverse" , gru ->Op ()->GetAttr (" is_reverse" ));
69
60
// TODO(TJ): This should be a option for infer
70
61
op_desc.SetAttr (" use_seq" , true );
71
62
@@ -82,14 +73,12 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
82
73
PADDLE_ENFORCE (scope);
83
74
if (with_fc_bias) {
84
75
// Fusion GRU bias = fcbias + grubias
85
- auto * fusion_bias_var = scope->Var (NEW_NAME (bias) + bias_n ->Name ());
76
+ auto * fusion_bias_var = scope->Var (NEW_NAME (bias) + bias ->Name ());
86
77
auto * out_bias_tensor =
87
78
fusion_bias_var->GetMutable <framework::LoDTensor>();
88
79
PADDLE_ENFORCE (fusion_bias_var);
89
- GET_NODE (fc_bias);
90
- PADDLE_ENFORCE (fc_bias_n);
91
- auto * gru_bias_var = scope->FindVar (bias_n->Name ());
92
- auto * fc_bias_var = scope->FindVar (fc_bias_n->Name ());
80
+ auto * gru_bias_var = scope->FindVar (bias->Name ());
81
+ auto * fc_bias_var = scope->FindVar (fc_bias->Name ());
93
82
PADDLE_ENFORCE (gru_bias_var);
94
83
PADDLE_ENFORCE (fc_bias_var);
95
84
const auto & gru_bias_tenosr = gru_bias_var->Get <framework::LoDTensor>();
@@ -113,54 +102,47 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
113
102
#undef NEW_NAME
114
103
#undef NEW_IMTERMEDIATE_OUT
115
104
116
- IR_NODE_LINK_TO (x_n , op);
117
- IR_NODE_LINK_TO (weight_x_n , op);
118
- IR_NODE_LINK_TO (weight_h_n , op);
119
- IR_NODE_LINK_TO (bias_n , op); // actually should link to new bias if have
120
- IR_NODE_LINK_TO (op, hidden_n );
105
+ IR_NODE_LINK_TO (x , op);
106
+ IR_NODE_LINK_TO (weight_x , op);
107
+ IR_NODE_LINK_TO (weight_h , op);
108
+ IR_NODE_LINK_TO (bias , op); // actually should link to new bias if have
109
+ IR_NODE_LINK_TO (op, hidden );
121
110
// h0?
122
111
return op;
123
112
};
124
113
125
114
int fusion_count{0 };
126
115
auto handler = [&](const GraphPatternDetector::subgraph_t & subgraph,
127
116
Graph* g) {
128
- #define GET_NODE (name__ ) \
129
- std::string name__##key = name_scope + " /" + #name__; \
130
- auto * name__##n = pattern->RetrieveNode (name__##key); \
131
- PADDLE_ENFORCE (name__##n); \
132
- PADDLE_ENFORCE (subgraph.count (name__##n)); \
133
- Node* name__##_n = subgraph.at (name__##n); \
134
- int name__ __attribute__ ((unused)) = name__##_n->id ();
135
-
136
- GET_NODE (x);
137
- GET_NODE (w); // fc weight
138
- GET_NODE (mul);
139
- GET_NODE (fc_out);
140
- GET_NODE (Weight);
141
- GET_NODE (gru);
142
- GET_NODE (Bias);
143
- GET_NODE (Hidden);
117
+ auto * x_n = subgraph.at (x);
118
+ GET_IR_NODE_FROM_SUBGRAPH (w, w, fc_pattern);
119
+ GET_IR_NODE_FROM_SUBGRAPH (mul, mul, fc_pattern);
120
+ GET_IR_NODE_FROM_SUBGRAPH (fc_out, Out, fc_pattern);
121
+ GET_IR_NODE_FROM_SUBGRAPH (Weight, Weight, gru_pattern);
122
+ GET_IR_NODE_FROM_SUBGRAPH (gru, gru, gru_pattern);
123
+ GET_IR_NODE_FROM_SUBGRAPH (Bias, Bias, gru_pattern);
124
+ GET_IR_NODE_FROM_SUBGRAPH (Hidden, Hidden, gru_pattern);
144
125
// nodes need be removed
145
- GET_NODE (BatchGate);
146
- GET_NODE (BatchResetHiddenPrev);
147
- GET_NODE (BatchHidden);
126
+ GET_IR_NODE_FROM_SUBGRAPH (BatchGate, BatchGate, gru_pattern );
127
+ GET_IR_NODE_FROM_SUBGRAPH (BatchResetHiddenPrev, BatchGate, gru_pattern );
128
+ GET_IR_NODE_FROM_SUBGRAPH (BatchHidden, BatchGate, gru_pattern );
148
129
149
130
if (with_fc_bias) {
150
- GET_NODE (mul_out);
151
- GET_NODE (fc_bias);
152
- GET_NODE (elementwise_add);
153
- gru_creater (gru, x, w, Weight, Bias, Hidden, fc_bias);
131
+ GET_IR_NODE_FROM_SUBGRAPH (mul_out, mul_out, fc_pattern);
132
+ GET_IR_NODE_FROM_SUBGRAPH (fc_bias, bias, fc_pattern);
133
+ GET_IR_NODE_FROM_SUBGRAPH (elementwise_add, elementwise_add, fc_pattern);
134
+
135
+ gru_creater (gru, x_n, w, Weight, Bias, Hidden, fc_bias);
154
136
// Remove unneeded nodes.
155
137
std::unordered_set<const Node*> marked_nodes (
156
- {mul_n, gru_n, elementwise_add_n, fc_bias_n, fc_out_n, mul_out_n ,
157
- BatchGate_n, BatchResetHiddenPrev_n, BatchHidden_n });
138
+ {mul, gru, elementwise_add, fc_bias, fc_out, mul_out, BatchGate ,
139
+ BatchResetHiddenPrev, BatchHidden });
158
140
GraphSafeRemoveNodes (graph, marked_nodes);
159
141
} else {
160
- gru_creater (gru, x , w, Weight, Bias, Hidden, - 1 );
142
+ gru_creater (gru, x_n , w, Weight, Bias, Hidden, nullptr );
161
143
// Remove unneeded nodes.
162
144
std::unordered_set<const Node*> marked_nodes (
163
- {mul_n, gru_n, BatchGate_n, BatchResetHiddenPrev_n, BatchHidden_n });
145
+ {mul, gru, BatchGate, BatchResetHiddenPrev, BatchHidden });
164
146
GraphSafeRemoveNodes (graph, marked_nodes);
165
147
}
166
148
#undef GET_NODE
0 commit comments