@@ -28,7 +28,7 @@ static void BuildPattern(PDPattern* pattern, const std::string& name_scope,
28
28
auto * fc_out = patterns::FC (pattern, name_scope, x, with_fc_bias);
29
29
fc_out->AsIntermediate (); // fc_out is a tmp var, will be removed after fuse.
30
30
patterns::GRU (pattern, name_scope, fc_out);
31
- VLOG (3 ) << " \n " << pattern->DotString ();
31
+ VLOG (3 ) << " fc_gru pattern \n " << pattern->DotString ();
32
32
}
33
33
34
34
static int BuildFusion (Graph* graph, const std::string& name_scope,
@@ -51,65 +51,72 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
51
51
52
52
OpDesc op_desc;
53
53
op_desc.SetType (" fusion_gru" );
54
+
55
+ #define NEW_NAME (x ) name_scope + " /at." #x " .new"
54
56
#define SET_IN (Key, node__ ) op_desc.SetInput(#Key, {node__##_n->Name ()});
55
57
SET_IN (X, x);
56
58
SET_IN (WeightX, weight_x);
57
59
SET_IN (WeightH, weight_h);
58
- SET_IN (Bias, bias);
60
+ if (with_fc_bias) {
61
+ op_desc.SetInput (" Bias" , {NEW_NAME (bias) + bias_n->Name ()});
62
+ } else {
63
+ SET_IN (Bias, bias);
64
+ }
59
65
#undef SET_IN
66
+ op_desc.SetInput (" H0" , {});
67
+ op_desc.SetOutput (" Hidden" , {hidden_n->Name ()});
68
+ op_desc.SetAttr (" is_reverse" , gru_n->Op ()->GetAttr (" is_reverse" ));
69
+ // TODO(TJ): This should be a option for infer
70
+ op_desc.SetAttr (" use_seq" , true );
71
+
72
+ #define SET_IMTERMEDIATE_OUT (key ) op_desc.SetOutput(#key, {NEW_NAME (key)})
73
+ SET_IMTERMEDIATE_OUT (ReorderedH0);
74
+ SET_IMTERMEDIATE_OUT (XX);
75
+ SET_IMTERMEDIATE_OUT (BatchedInput);
76
+ SET_IMTERMEDIATE_OUT (BatchedOut);
77
+ #undef SET_IMTERMEDIATE_OUT
78
+
79
+ auto * op = graph->CreateOpNode (&op_desc);
80
+ PADDLE_ENFORCE (graph->Has (kParamScopeAttr ));
81
+ auto * scope = graph->Get <Scope*>(kParamScopeAttr );
82
+ PADDLE_ENFORCE (scope);
60
83
if (with_fc_bias) {
61
- // Add FC-bias with LSTM-bias and create a new weight
62
- PADDLE_ENFORCE (scope);
63
- const std::string& new_bias_var = name_scope + " _bias.new" ;
64
- auto * bias_var = scope->Var (new_bias_var);
65
- PADDLE_ENFORCE (bias_var);
66
- auto * bias_tensor = bias_var->GetMutable <framework::LoDTensor>();
84
+ // Fusion GRU bias = fcbias + grubias
85
+ auto * fusion_bias_var = scope->Var (NEW_NAME (bias) + bias_n->Name ());
86
+ auto * out_bias_tensor =
87
+ fusion_bias_var->GetMutable <framework::LoDTensor>();
88
+ PADDLE_ENFORCE (fusion_bias_var);
89
+ GET_NODE (fc_bias);
90
+ PADDLE_ENFORCE (fc_bias_n);
67
91
auto * gru_bias_var = scope->FindVar (bias_n->Name ());
92
+ auto * fc_bias_var = scope->FindVar (fc_bias_n->Name ());
68
93
PADDLE_ENFORCE (gru_bias_var);
94
+ PADDLE_ENFORCE (fc_bias_var);
69
95
const auto & gru_bias_tenosr = gru_bias_var->Get <framework::LoDTensor>();
70
- bias_tensor->Resize (gru_bias_tenosr.dims ());
71
-
72
- GET_NODE (fc_bias);
73
- auto * fc_bias_var = scope->FindVar (fc_bias_n->Name ());
74
96
const auto & fc_bias_tensor = fc_bias_var->Get <framework::LoDTensor>();
75
97
// new bias = fc bias + gru bias
76
- auto * data = bias_tensor->mutable_data <float >(platform::CPUPlace ());
77
- for (int i = 0 ; i < bias_tensor->numel (); i++) {
98
+ out_bias_tensor->Resize (gru_bias_tenosr.dims ());
99
+ auto * data = out_bias_tensor->mutable_data <float >(platform::CPUPlace ());
100
+ for (int i = 0 ; i < out_bias_tensor->numel (); i++) {
78
101
data[i] =
79
102
fc_bias_tensor.data <float >()[i] + gru_bias_tenosr.data <float >()[i];
80
103
}
81
- op_desc.SetInput (" Bias" , {new_bias_var});
82
104
}
83
105
#undef GET_NODE
84
106
85
- op_desc.SetInput (" H0" , {});
86
- op_desc.SetOutput (" Hidden" , {hidden_n->Name ()});
87
- op_desc.SetAttr (" is_reverse" , gru_n->Op ()->GetAttr (" is_reverse" ));
88
- // TODO(TJ): This should be a option for infer
89
- op_desc.SetAttr (" use_seq" , true );
90
-
91
- // Create temp variables.
92
- // TODO(TJ): clean code
93
- scope->Var (name_scope + " /ReorderedH0.new" )
94
- ->GetMutable <framework::LoDTensor>();
95
- scope->Var (name_scope + " /XX.new" )->GetMutable <framework::LoDTensor>();
96
- scope->Var (name_scope + " /BatchedInput.new" )
97
- ->GetMutable <framework::LoDTensor>();
98
- scope->Var (name_scope + " /BatchedOut.new" )
99
- ->GetMutable <framework::LoDTensor>();
100
- op_desc.SetOutput (" ReorderedH0" , {name_scope + " /ReorderedH0.new" });
101
- op_desc.SetOutput (" XX" , {name_scope + " /XX.new" });
102
- op_desc.SetOutput (" BatchedInput" , {name_scope + " /BatchedInput.new" });
103
- op_desc.SetOutput (" BatchedOut" , {name_scope + " /BatchedOut.new" });
104
-
105
- auto * op = graph->CreateOpNode (&op_desc);
106
- PADDLE_ENFORCE (graph->Has (kParamScopeAttr ));
107
- // auto* scope = graph->Get<Scope*>(kParamScopeAttr);
107
+ #define NEW_IMTERMEDIATE_OUT (key ) \
108
+ scope->Var (NEW_NAME (key))->GetMutable <framework::LoDTensor>()
109
+ NEW_IMTERMEDIATE_OUT (ReorderedH0);
110
+ NEW_IMTERMEDIATE_OUT (XX);
111
+ NEW_IMTERMEDIATE_OUT (BatchedInput);
112
+ NEW_IMTERMEDIATE_OUT (BatchedOut);
113
+ #undef NEW_NAME
114
+ #undef NEW_IMTERMEDIATE_OUT
108
115
109
116
IR_NODE_LINK_TO (x_n, op);
110
117
IR_NODE_LINK_TO (weight_x_n, op);
111
118
IR_NODE_LINK_TO (weight_h_n, op);
112
- IR_NODE_LINK_TO (bias_n, op);
119
+ IR_NODE_LINK_TO (bias_n, op); // actually should link to new bias if have
113
120
IR_NODE_LINK_TO (op, hidden_n);
114
121
// h0?
115
122
return op;
@@ -127,26 +134,33 @@ static int BuildFusion(Graph* graph, const std::string& name_scope,
127
134
int name__ __attribute__ ((unused)) = name__##_n->id ();
128
135
129
136
GET_NODE (x);
130
- GET_NODE (w);
137
+ GET_NODE (w); // fc weight
131
138
GET_NODE (mul);
132
139
GET_NODE (fc_out);
133
140
GET_NODE (Weight);
134
141
GET_NODE (gru);
135
142
GET_NODE (Bias);
136
143
GET_NODE (Hidden);
144
+ // nodes need be removed
145
+ GET_NODE (BatchGate);
146
+ GET_NODE (BatchResetHiddenPrev);
147
+ GET_NODE (BatchHidden);
137
148
138
149
if (with_fc_bias) {
150
+ GET_NODE (mul_out);
139
151
GET_NODE (fc_bias);
140
152
GET_NODE (elementwise_add);
141
153
gru_creater (gru, x, w, Weight, Bias, Hidden, fc_bias);
142
154
// Remove unneeded nodes.
143
155
std::unordered_set<const Node*> marked_nodes (
144
- {mul_n, gru_n, elementwise_add_n});
156
+ {mul_n, gru_n, elementwise_add_n, fc_bias_n, fc_out_n, mul_out_n,
157
+ BatchGate_n, BatchResetHiddenPrev_n, BatchHidden_n});
145
158
GraphSafeRemoveNodes (graph, marked_nodes);
146
159
} else {
147
160
gru_creater (gru, x, w, Weight, Bias, Hidden, -1 );
148
161
// Remove unneeded nodes.
149
- std::unordered_set<const Node*> marked_nodes ({mul_n, gru_n});
162
+ std::unordered_set<const Node*> marked_nodes (
163
+ {mul_n, gru_n, BatchGate_n, BatchResetHiddenPrev_n, BatchHidden_n});
150
164
GraphSafeRemoveNodes (graph, marked_nodes);
151
165
}
152
166
#undef GET_NODE
0 commit comments