13
13
// limitations under the License.
14
14
15
15
#include " paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
16
+ #include < string>
16
17
#include " paddle/fluid/framework/lod_tensor.h"
17
18
18
19
namespace paddle {
@@ -94,11 +95,31 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
94
95
op_desc.SetOutput (" Hidden" , {hidden_n->Name ()});
95
96
op_desc.SetOutput (" Cell" , {cell_n->Name ()});
96
97
op_desc.SetOutput (" XX" , {xx_n->Name ()});
97
- op_desc.SetOutput (" BatchedGate" , {" blstm_0.tmp_2" });
98
- op_desc.SetOutput (" BatchCellPreAct" , {" blstm_1.tmp_2" });
98
+ op_desc.SetOutput (" BatchedInput" , {" blstm_0.tmp_2" });
99
99
op_desc.SetAttr (" is_reverse" , lstm_n->Op ()->GetAttr (" is_reverse" ));
100
100
op_desc.SetAttr (" use_peepholes" , lstm_n->Op ()->GetAttr (" use_peepholes" ));
101
+ // TODO(TJ): get from attr
102
+ op_desc.SetAttr (" use_seq" , true );
103
+
104
+ #define TMP_NAME (x ) " at.new.tmp." #x
105
+ #define OP_SET_OUT (x ) op_desc.SetOutput(#x, {TMP_NAME (x)})
106
+ OP_SET_OUT (BatchedCell);
107
+ OP_SET_OUT (BatchedHidden);
108
+ OP_SET_OUT (ReorderedH0);
109
+ OP_SET_OUT (ReorderedC0);
110
+ #undef OP_SET_OUT
111
+
101
112
auto * op = graph->CreateOpNode (&op_desc);
113
+ PADDLE_ENFORCE (graph->Has (kParamScopeAttr ));
114
+ auto * scope = graph->Get <Scope*>(kParamScopeAttr );
115
+
116
+ #define TMP_NEW (x ) scope->Var (TMP_NAME(x))->GetMutable<LoDTensor>()
117
+ TMP_NEW(BatchedCell);
118
+ TMP_NEW (BatchedHidden);
119
+ TMP_NEW (ReorderedH0);
120
+ TMP_NEW (ReorderedC0);
121
+ #undef TMP_NEW
122
+ #undef TMP_NAME
102
123
103
124
#define LINK_TO (a, b ) \
104
125
a->outputs .push_back (b); \
@@ -116,7 +137,6 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
116
137
117
138
auto fc_no_bias_handler = [&](
118
139
const GraphPatternDetector::subgraph_t & subgraph, Graph* g) {
119
-
120
140
#define GET_NODE (name__ ) \
121
141
std::string name__##key = name_scope + " /" + #name__; \
122
142
auto * name__##n = pattern->RetrieveNode (name__##key); \
0 commit comments