Skip to content

Commit 80edd7e

Browse files
committed
enable run with fuse pass
1 parent a79a77e commit 80edd7e

File tree

2 files changed

+28
-7
lines changed

2 files changed

+28
-7
lines changed

paddle/fluid/framework/ir/fc_lstm_fuse_pass.cc

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
// limitations under the License.
1414

1515
#include "paddle/fluid/framework/ir/fc_lstm_fuse_pass.h"
16+
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
17+
#include "paddle/fluid/framework/lod_tensor.h"
1618

1719
namespace paddle {
1820
namespace framework {
@@ -35,7 +37,6 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
3537

3638
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
3739
Graph* g) {
38-
3940
auto* id = subgraph.at(gpd.pattern().RetrieveNode("any_node"));
4041
marked_nodes.insert(id);
4142
};
@@ -73,12 +74,31 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
7374
op_desc.SetOutput("Hidden", {hidden_n->Name()});
7475
op_desc.SetOutput("Cell", {cell_n->Name()});
7576
op_desc.SetOutput("XX", {xx_n->Name()});
76-
op_desc.SetOutput("BatchedGate", {"blstm_0.tmp_2"});
77-
op_desc.SetOutput("BatchCellPreAct", {"blstm_1.tmp_2"});
77+
op_desc.SetOutput("BatchedInput", {"blstm_0.tmp_2"});
7878
op_desc.SetAttr("is_reverse", lstm_n->Op()->GetAttr("is_reverse"));
7979
op_desc.SetAttr("use_peepholes", false);
80+
81+
#define TMP_NAME(x) "at.new.tmp." #x
82+
#define OP_SET_OUT(x) op_desc.SetOutput(#x, {TMP_NAME(x)})
83+
OP_SET_OUT(BatchedCell);
84+
OP_SET_OUT(BatchedHidden);
85+
OP_SET_OUT(ReorderedH0);
86+
OP_SET_OUT(ReorderedC0);
87+
#undef OP_SET_OUT
8088
auto* op = graph->CreateOpNode(&op_desc);
8189

90+
PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
91+
auto* scope = graph->Get<Scope*>(kParamScopeAttr);
92+
93+
#define TMP_NEW(x) scope->Var(TMP_NAME(x))->GetMutable<LoDTensor>()
94+
TMP_NEW(BatchedCell);
95+
TMP_NEW(BatchedHidden);
96+
TMP_NEW(ReorderedH0);
97+
TMP_NEW(ReorderedC0);
98+
99+
#undef TMP_NEW
100+
#undef TMP_NAME
101+
82102
#define LINK_TO(a, b) \
83103
a->outputs.push_back(b); \
84104
b->inputs.push_back(a);
@@ -89,7 +109,6 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
89109
LINK_TO(op, hidden_n);
90110
#undef LINK_TO
91111
return op;
92-
93112
};
94113

95114
lstm_creator(16, 12, 14, 18, 17, 22, 21, 19);
@@ -105,14 +124,16 @@ std::unique_ptr<ir::Graph> FCLstmFusePass::ApplyImpl(
105124
for (auto it = node->inputs.begin(); it != node->inputs.end();) {
106125
if (marked_nodes.count(*it)) {
107126
it = const_cast<Node*>(node)->inputs.erase(it);
108-
} else
127+
} else {
109128
it++;
129+
}
110130
}
111131
for (auto it = node->outputs.begin(); it != node->outputs.end();) {
112132
if (marked_nodes.count(*it)) {
113133
it = const_cast<Node*>(node)->outputs.erase(it);
114-
} else
134+
} else {
115135
it++;
136+
}
116137
}
117138
}
118139

paddle/fluid/operators/fusion_lstm_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ limitations under the License. */
2222
#include "paddle/fluid/operators/math/sequence2batch.h"
2323
#include "paddle/fluid/platform/cpu_info.h"
2424

25-
DEFINE_bool(seq_mode, false, "Use sequence mode");
25+
DEFINE_bool(seq_mode, true, "Use sequence mode");
2626

2727
namespace paddle {
2828
namespace operators {

0 commit comments

Comments
 (0)