13
13
// limitations under the License.
14
14
15
15
#include " paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
16
+ #include " paddle/fluid/framework/ir/fuse_pass_base.h"
17
+ #include " paddle/fluid/framework/lod_tensor.h"
16
18
17
19
namespace paddle {
18
20
namespace framework {
@@ -35,7 +37,6 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
35
37
36
38
auto handler = [&](const GraphPatternDetector::subgraph_t & subgraph,
37
39
Graph* g) {
38
-
39
40
auto * id = subgraph.at (gpd.pattern ().RetrieveNode (" any_node" ));
40
41
marked_nodes.insert (id);
41
42
};
@@ -73,12 +74,31 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
73
74
op_desc.SetOutput (" Hidden" , {hidden_n->Name ()});
74
75
op_desc.SetOutput (" Cell" , {cell_n->Name ()});
75
76
op_desc.SetOutput (" XX" , {xx_n->Name ()});
76
- op_desc.SetOutput (" BatchedGate" , {" blstm_0.tmp_2" });
77
- op_desc.SetOutput (" BatchCellPreAct" , {" blstm_1.tmp_2" });
77
+ op_desc.SetOutput (" BatchedInput" , {" blstm_0.tmp_2" });
78
78
op_desc.SetAttr (" is_reverse" , lstm_n->Op ()->GetAttr (" is_reverse" ));
79
79
op_desc.SetAttr (" use_peepholes" , false );
80
+
81
+ #define TMP_NAME (x ) " at.new.tmp." #x
82
+ #define OP_SET_OUT (x ) op_desc.SetOutput(#x, {TMP_NAME (x)})
83
+ OP_SET_OUT (BatchedCell);
84
+ OP_SET_OUT (BatchedHidden);
85
+ OP_SET_OUT (ReorderedH0);
86
+ OP_SET_OUT (ReorderedC0);
87
+ #undef OP_SET_OUT
80
88
auto * op = graph->CreateOpNode (&op_desc);
81
89
90
+ PADDLE_ENFORCE (graph->Has (kParamScopeAttr ));
91
+ auto * scope = graph->Get <Scope*>(kParamScopeAttr );
92
+
93
+ #define TMP_NEW (x ) scope->Var (TMP_NAME(x))->GetMutable<LoDTensor>()
94
+ TMP_NEW(BatchedCell);
95
+ TMP_NEW (BatchedHidden);
96
+ TMP_NEW (ReorderedH0);
97
+ TMP_NEW (ReorderedC0);
98
+
99
+ #undef TMP_NEW
100
+ #undef TMP_NAME
101
+
82
102
#define LINK_TO (a, b ) \
83
103
a->outputs .push_back (b); \
84
104
b->inputs .push_back (a);
@@ -89,7 +109,6 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
89
109
LINK_TO (op, hidden_n);
90
110
#undef LINK_TO
91
111
return op;
92
-
93
112
};
94
113
95
114
lstm_creator (16 , 12 , 14 , 18 , 17 , 22 , 21 , 19 );
@@ -105,14 +124,16 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
105
124
for (auto it = node->inputs .begin (); it != node->inputs .end ();) {
106
125
if (marked_nodes.count (*it)) {
107
126
it = const_cast <Node*>(node)->inputs .erase (it);
108
- } else
127
+ } else {
109
128
it++;
129
+ }
110
130
}
111
131
for (auto it = node->outputs .begin (); it != node->outputs .end ();) {
112
132
if (marked_nodes.count (*it)) {
113
133
it = const_cast <Node*>(node)->outputs .erase (it);
114
- } else
134
+ } else {
115
135
it++;
136
+ }
116
137
}
117
138
}
118
139
0 commit comments