Skip to content

Commit 1cc35f3

Browse files
authored
Merge pull request #13118 from tensor-tang/optimize/op/fusion_lstm
Optimize fusion lstm batch mode
2 parents 6fb2879 + 93c034e commit 1cc35f3

File tree

2 files changed

+221
-185
lines changed

2 files changed

+221
-185
lines changed

paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
// limitations under the License.
1414

1515
#include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
16+
#include <string>
1617
#include "paddle/fluid/framework/lod_tensor.h"
1718

1819
namespace paddle {
@@ -94,11 +95,31 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
9495
op_desc.SetOutput("Hidden", {hidden_n->Name()});
9596
op_desc.SetOutput("Cell", {cell_n->Name()});
9697
op_desc.SetOutput("XX", {xx_n->Name()});
97-
op_desc.SetOutput("BatchedGate", {"blstm_0.tmp_2"});
98-
op_desc.SetOutput("BatchCellPreAct", {"blstm_1.tmp_2"});
98+
op_desc.SetOutput("BatchedInput", {"blstm_0.tmp_2"});
9999
op_desc.SetAttr("is_reverse", lstm_n->Op()->GetAttr("is_reverse"));
100100
op_desc.SetAttr("use_peepholes", lstm_n->Op()->GetAttr("use_peepholes"));
101+
// TODO(TJ): get from attr
102+
op_desc.SetAttr("use_seq", true);
103+
104+
#define TMP_NAME(x) "at.new.tmp." #x
105+
#define OP_SET_OUT(x) op_desc.SetOutput(#x, {TMP_NAME(x)})
106+
OP_SET_OUT(BatchedCell);
107+
OP_SET_OUT(BatchedHidden);
108+
OP_SET_OUT(ReorderedH0);
109+
OP_SET_OUT(ReorderedC0);
110+
#undef OP_SET_OUT
111+
101112
auto* op = graph->CreateOpNode(&op_desc);
113+
PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
114+
auto* scope = graph->Get<Scope*>(kParamScopeAttr);
115+
116+
#define TMP_NEW(x) scope->Var(TMP_NAME(x))->GetMutable<LoDTensor>()
117+
TMP_NEW(BatchedCell);
118+
TMP_NEW(BatchedHidden);
119+
TMP_NEW(ReorderedH0);
120+
TMP_NEW(ReorderedC0);
121+
#undef TMP_NEW
122+
#undef TMP_NAME
102123

103124
#define LINK_TO(a, b) \
104125
a->outputs.push_back(b); \
@@ -116,7 +137,6 @@ int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope,
116137

117138
auto fc_no_bias_handler = [&](
118139
const GraphPatternDetector::subgraph_t& subgraph, Graph* g) {
119-
120140
#define GET_NODE(name__) \
121141
std::string name__##key = name_scope + "/" + #name__; \
122142
auto* name__##n = pattern->RetrieveNode(name__##key); \

0 commit comments

Comments
 (0)