Skip to content

Commit fae7981

Browse files
authored
Fusion: seqpool_cvm_concat, test=release/1.5 (#19381)
* add fusion_seqpool_cvm_concat test=develop * simplify pass, test=develop * fix code style, test=develop
1 parent c737116 commit fae7981

File tree

8 files changed

+763
-0
lines changed

8 files changed

+763
-0
lines changed

paddle/fluid/framework/ir/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ pass_library(multi_batch_merge_pass base)
6161
pass_library(conv_bn_fuse_pass inference)
6262
pass_library(seqconv_eltadd_relu_fuse_pass inference)
6363
pass_library(seqpool_concat_fuse_pass inference)
64+
pass_library(seqpool_cvm_concat_fuse_pass inference)
6465
pass_library(repeated_fc_relu_fuse_pass inference)
6566
pass_library(squared_mat_sub_fuse_pass inference)
6667
pass_library(is_test_pass base)
@@ -118,6 +119,7 @@ cc_test(graph_to_program_pass_test SRCS graph_to_program_pass_test.cc DEPS graph
118119
cc_test(test_graph_pattern_detector SRCS graph_pattern_detector_tester.cc DEPS graph_pattern_detector)
119120
cc_test(test_fc_fuse_pass SRCS fc_fuse_pass_tester.cc DEPS fc_fuse_pass framework_proto)
120121
cc_test(test_seqpool_concat_fuse_pass SRCS seqpool_concat_fuse_pass_tester.cc DEPS seqpool_concat_fuse_pass framework_proto)
122+
cc_test(test_seqpool_cvm_concat_fuse_pass SRCS seqpool_cvm_concat_fuse_pass_tester.cc DEPS seqpool_cvm_concat_fuse_pass framework_proto)
121123
cc_test(test_is_test_pass SRCS is_test_pass_tester.cc DEPS is_test_pass)
122124
if(NOT WIN32)
123125
cc_test(test_sync_batch_norm_pass SRCS sync_batch_norm_pass_tester.cc DEPS sync_batch_norm_pass)
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License. */
14+
15+
#include "paddle/fluid/framework/ir/seqpool_cvm_concat_fuse_pass.h"
16+
#include <string>
17+
#include <unordered_map>
18+
#include <unordered_set>
19+
#include <vector>
20+
#include "paddle/fluid/framework/lod_tensor.h"
21+
22+
namespace paddle {
23+
namespace framework {
24+
namespace ir {
25+
26+
namespace {
27+
static PDNode* BuildCVMConcatPattern(PDPattern* pattern) {
28+
auto cvm_behind_x = [](Node* x) -> bool {
29+
Node* adj = x->inputs[0];
30+
Node* alt = x->inputs[0]->inputs[0];
31+
return x && adj && adj->IsVar() && alt->IsOp() &&
32+
alt->Op()->Type() == "cvm";
33+
};
34+
auto* concat_op_node = pattern->NewNode("concat_op")
35+
->assert_is_op("concat")
36+
->assert_op_attr<int>("axis", 1)
37+
->assert_more(cvm_behind_x);
38+
return concat_op_node;
39+
}
40+
41+
static void GetConcatNodes(ir::Graph* graph, std::vector<Node*>* concat_nodes) {
42+
GraphPatternDetector gpd;
43+
auto* pattern = gpd.mutable_pattern();
44+
auto concat_op_node = BuildCVMConcatPattern(pattern);
45+
GraphPatternDetector::handle_t handler = [&](
46+
const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) {
47+
Node* concat_op = subgraph.at(concat_op_node);
48+
concat_nodes->push_back(concat_op);
49+
};
50+
gpd(graph, handler);
51+
}
52+
} // anonymous namespace
53+
54+
void SeqPoolCVMConcatFusePass::ApplyImpl(ir::Graph* graph) const {
55+
FusePassBase::Init("seqpool_cvm_concat_fuse", graph);
56+
std::vector<Node*> concat_nodes;
57+
GetConcatNodes(graph, &concat_nodes);
58+
59+
int count = 0;
60+
for (auto* concat_node : concat_nodes) {
61+
GraphPatternDetector gpd;
62+
auto* pattern = gpd.mutable_pattern();
63+
auto concat_before_x = [=](Node* x) -> bool {
64+
return x && x->outputs[0] == concat_node;
65+
};
66+
PDNode* seqpool_in_var_node =
67+
pattern->NewNode("seqpool_in_var")
68+
->assert_is_only_input_of_op("sequence_pool");
69+
PDNode* seqpool_op_node =
70+
pattern->NewNode("seqpool_op")
71+
->assert_is_op("sequence_pool")
72+
->assert_op_attr<std::string>("pooltype", "SUM");
73+
PDNode* seqpool_out_var_node =
74+
pattern->NewNode("seqpool_out_var")
75+
->assert_is_op_nth_output("sequence_pool", "Out", 0)
76+
->assert_is_op_nth_input("cvm", "X", 0);
77+
PDNode* seqpool_idx_out_var_node =
78+
pattern->NewNode("seqpool_idx_out_var")
79+
->assert_is_op_nth_output("sequence_pool", "MaxIndex", 0);
80+
PDNode* cvm_op_node =
81+
pattern->NewNode("cvm_op")->assert_is_op("cvm")->assert_op_attr<bool>(
82+
"use_cvm", true);
83+
PDNode* cvm_out_var_node = pattern->NewNode("cvm_op_out_var")
84+
->assert_is_op_nth_output("cvm", "Y", 0)
85+
->assert_more(concat_before_x);
86+
PDNode* cvm_cvm_in_var_node = pattern->NewNode("cvm_cvm_in_var")
87+
->assert_is_op_nth_input("cvm", "CVM", 0);
88+
89+
seqpool_op_node->LinksFrom({seqpool_in_var_node})
90+
.LinksTo({seqpool_out_var_node, seqpool_idx_out_var_node});
91+
seqpool_out_var_node->LinksFrom({seqpool_op_node}).LinksTo({cvm_op_node});
92+
cvm_op_node->LinksTo({cvm_out_var_node})
93+
.LinksFrom({cvm_cvm_in_var_node, seqpool_out_var_node});
94+
95+
std::unordered_map<std::string, Node*> ins_to_concat;
96+
std::vector<Node*> subgraph_ins;
97+
std::vector<std::string> subgraph_ins_name;
98+
std::unordered_set<const Node*> marked_nodes;
99+
100+
Node* cvm_input_of_cvm;
101+
Node* concat_out_var = concat_node->outputs[0];
102+
103+
GraphPatternDetector::handle_t handler = [&](
104+
const GraphPatternDetector::subgraph_t& subgraph, Graph* graph) {
105+
Node* seqpool_in_var = subgraph.at(seqpool_in_var_node);
106+
Node* seqpool_op = subgraph.at(seqpool_op_node);
107+
Node* seqpool_out_var = subgraph.at(seqpool_out_var_node);
108+
Node* seqpool_idx_out_var = subgraph.at(seqpool_idx_out_var_node);
109+
Node* cvm_op = subgraph.at(cvm_op_node);
110+
Node* cvm_out_var = subgraph.at(cvm_out_var_node);
111+
cvm_input_of_cvm = subgraph.at(cvm_cvm_in_var_node);
112+
marked_nodes.insert({seqpool_op, seqpool_out_var, seqpool_idx_out_var,
113+
cvm_op, cvm_out_var, concat_node});
114+
ins_to_concat[cvm_out_var->Name()] = seqpool_in_var;
115+
};
116+
gpd(graph, handler);
117+
118+
if (!ins_to_concat.empty()) {
119+
for (const auto* in : concat_node->inputs) {
120+
subgraph_ins.push_back(ins_to_concat.at(in->Name()));
121+
subgraph_ins_name.push_back(ins_to_concat.at(in->Name())->Name());
122+
}
123+
124+
// Create New OpDesc
125+
OpDesc op_desc;
126+
op_desc.SetType("fusion_seqpool_cvm_concat");
127+
op_desc.SetInput("X", subgraph_ins_name);
128+
op_desc.SetInput("CVM", {cvm_input_of_cvm->Name()});
129+
op_desc.SetAttr("pooltype", std::string("SUM"));
130+
op_desc.SetAttr("use_cvm", true);
131+
op_desc.SetAttr("axis", concat_node->Op()->GetAttr("axis"));
132+
op_desc.SetOutput("Out", {concat_out_var->Name()});
133+
auto* op = graph->CreateOpNode(&op_desc);
134+
135+
for (size_t i = 0; i < subgraph_ins.size(); ++i) {
136+
IR_NODE_LINK_TO(subgraph_ins[i], op);
137+
}
138+
IR_NODE_LINK_TO(cvm_input_of_cvm, op);
139+
IR_NODE_LINK_TO(op, concat_out_var);
140+
141+
GraphSafeRemoveNodes(graph, marked_nodes);
142+
count++;
143+
}
144+
}
145+
AddStatis(count);
146+
}
147+
148+
} // namespace ir
149+
} // namespace framework
150+
} // namespace paddle
151+
152+
REGISTER_PASS(seqpool_cvm_concat_fuse_pass,
153+
paddle::framework::ir::SeqPoolCVMConcatFusePass);
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
*
3+
* Licensed under the Apache License, Version 2.0 (the "License");
4+
* you may not use this file except in compliance with the License.
5+
* You may obtain a copy of the License at
6+
*
7+
* http://www.apache.org/licenses/LICENSE-2.0
8+
*
9+
* Unless required by applicable law or agreed to in writing, software
10+
* distributed under the License is distributed on an "AS IS" BASIS,
11+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
* See the License for the specific language governing permissions and
13+
* limitations under the License. */
14+
15+
#pragma once
16+
17+
#include <string>
18+
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
19+
#include "paddle/fluid/framework/ir/graph.h"
20+
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
21+
22+
namespace paddle {
23+
namespace framework {
24+
namespace ir {
25+
26+
/**
27+
* Fuse SequencePool(with sum pooltype yet) and Concat;
28+
*
29+
* Before fuse:
30+
* | | |
31+
* seq_pool, seq_pool, ... seq_pool
32+
* | | |
33+
* cvm cvm cvm
34+
* \ | ... /
35+
* concat
36+
* |
37+
* After fuse:
38+
* \ | /
39+
* FusionSeqPoolCVMConcat
40+
* |
41+
*/
42+
class SeqPoolCVMConcatFusePass : public FusePassBase {
43+
public:
44+
virtual ~SeqPoolCVMConcatFusePass() {}
45+
46+
protected:
47+
void ApplyImpl(ir::Graph* graph) const override;
48+
49+
const std::string name_scope_{"seqpool_cvm_concat_fuse"};
50+
};
51+
52+
} // namespace ir
53+
} // namespace framework
54+
} // namespace paddle

0 commit comments

Comments
 (0)