Skip to content

Commit 796c87d

Browse files
authored
bugfix/fusion lstm (#13185)
1 parent 9557cc2 commit 796c87d

File tree

6 files changed

+37
-19
lines changed

6 files changed

+37
-19
lines changed

paddle/fluid/framework/ir/fc_fuse_pass.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ bool LinksReplace(std::vector<Node*>* links, Node* from, Node* to) {
7777
std::unique_ptr<ir::Graph> FCFusePass::ApplyImpl(
7878
std::unique_ptr<ir::Graph> graph) const {
7979
PADDLE_ENFORCE(graph.get());
80-
FusePassBase::Init("fc", graph.get());
80+
FusePassBase::Init("fc_fuse", graph.get());
8181

8282
std::unordered_set<Node*> nodes2delete;
8383

paddle/fluid/framework/ir/graph_pattern_detector.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,11 @@ bool GraphPatternDetector::MarkPDNodesInGraph(const ir::Graph& graph) {
111111
return false;
112112
}
113113
}
114+
for (auto& item : pdnodes2nodes_) {
115+
for (auto& n : item.second) {
116+
GetMarkedNodes(const_cast<Graph*>(&graph)).insert(n);
117+
}
118+
}
114119
VLOG(3) << pdnodes2nodes_.size() << " nodes marked";
115120

116121
return !pdnodes2nodes_.empty();
@@ -278,7 +283,7 @@ void GraphPatternDetector::RemoveOverlappedMatch(
278283
for (const auto& subgraph : *subgraphs) {
279284
bool valid = true;
280285
for (auto& item : subgraph) {
281-
if (node_set.count(item.second)) {
286+
if (item.first->IsIntermediate() && node_set.count(item.second)) {
282287
valid = false;
283288
break;
284289
}

paddle/fluid/framework/ir/graph_pattern_detector.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,8 @@ class GraphPatternDetector {
245245
void UniquePatterns(std::vector<subgraph_t>* subgraphs);
246246

247247
// Remove overlapped match subgraphs, when overlapped, keep the previous one.
248+
// The intermediate PDNodes will be removed, so can't shared by multiple
249+
// patterns.
248250
void RemoveOverlappedMatch(std::vector<subgraph_t>* subgraphs);
249251

250252
// Validate whether the intermediate nodes are linked by external nodes.

paddle/fluid/framework/ir/graph_pattern_detector_tester.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,9 @@ TEST(GraphPatternDetecter, MultiSubgraph) {
140140
return node->IsOp() && (node->Name() == "op2" || node->Name() == "op3");
141141
},
142142
"OP0");
143-
auto* any_var = x.mutable_pattern()->NewNode(
144-
[](Node* node) { return node->IsVar(); }, "VAR");
143+
auto* any_var = x.mutable_pattern()
144+
->NewNode([](Node* node) { return node->IsVar(); }, "VAR")
145+
->AsIntermediate();
145146
auto* any_op1 = x.mutable_pattern()->NewNode(
146147
[](Node* node) { return node->IsOp(); }, "OP1");
147148

paddle/fluid/framework/ir/infer_clean_graph_pass.cc

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,42 +13,41 @@
1313
// limitations under the License.
1414

1515
#include <algorithm>
16+
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
1617
#include "paddle/fluid/framework/ir/graph.h"
17-
#include "paddle/fluid/framework/ir/pass.h"
18+
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
1819

1920
namespace paddle {
2021
namespace framework {
2122
namespace ir {
2223

23-
class InferCleanGraphPass : public Pass {
24+
class InferCleanGraphPass : public FusePassBase {
2425
public:
2526
virtual ~InferCleanGraphPass() {}
2627

2728
protected:
2829
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const {
30+
FusePassBase::Init("original_graph", graph.get());
2931
PADDLE_ENFORCE(graph.get());
3032

3133
auto is_valid_node = [](Node* x) {
3234
return x && IsControlDepVar(*x) && x->IsVar() && !x->Var();
3335
};
3436

35-
std::unordered_set<Node*> invalid_nodes;
37+
std::unordered_set<const Node*> invalid_nodes;
38+
int valid_op = 0;
3639
for (auto* node : graph->Nodes()) {
3740
if (is_valid_node(node)) {
3841
invalid_nodes.insert(node);
42+
} else if (node->IsOp()) {
43+
// Collect all the operators to help tracking number of operators.
44+
++valid_op;
3945
}
4046
}
4147

42-
// remove nodes from the graph.
43-
for (auto* node : invalid_nodes) {
44-
graph->RemoveNode(node);
45-
}
48+
GraphSafeRemoveNodes(graph.get(), invalid_nodes);
4649

47-
// clean edges.
48-
for (auto* node : graph->Nodes()) {
49-
CleanEdges(&node->inputs, invalid_nodes);
50-
CleanEdges(&node->outputs, invalid_nodes);
51-
}
50+
AddStatis(valid_op);
5251

5352
return graph;
5453
}

paddle/fluid/inference/analysis/analyzer_tester.cc

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -327,9 +327,20 @@ void TestDituRNNPrediction(const std::string &model_path,
327327
LOG(INFO) << "fused " << item.first << " " << item.second;
328328
}
329329

330-
ASSERT_TRUE(fuse_statis.count("fc"));
331-
EXPECT_EQ(fuse_statis.at("fc"), 1);
332-
EXPECT_EQ(fuse_statis.at("fc_nobias_lstm_fuse"), 1);
330+
int num_ops = 0;
331+
for (auto &node :
332+
analysis_predictor->analysis_argument().main_dfg->nodes.nodes()) {
333+
if (node->IsFunction()) {
334+
++num_ops;
335+
}
336+
}
337+
LOG(INFO) << "has num ops: " << num_ops;
338+
339+
ASSERT_TRUE(fuse_statis.count("fc_fuse"));
340+
EXPECT_EQ(fuse_statis.at("fc_fuse"), 1);
341+
EXPECT_EQ(fuse_statis.at("fc_nobias_lstm_fuse"), 2); // bi-directional LSTM
342+
EXPECT_EQ(num_ops,
343+
13); // After graph optimization, only 13 operators exists.
333344
}
334345
}
335346

0 commit comments

Comments
 (0)