Skip to content

Commit 664159a

Browse files
authored
Merge pull request #13998 from tensor-tang/fea/fusion_seqconv_add
Fea/fusion seqconv eltadd relu
2 parents 765085d + 40f8456 commit 664159a

16 files changed

+819
-72
lines changed

paddle/fluid/framework/ir/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ pass_library(embedding_fc_lstm_fuse_pass inference)
3737
pass_library(fc_gru_fuse_pass inference)
3838
pass_library(seq_concat_fc_fuse_pass inference)
3939
pass_library(conv_bn_fuse_pass inference)
40+
pass_library(seqconv_eltadd_relu_fuse_pass inference)
4041
if(WITH_MKLDNN)
4142
pass_library(mkldnn_placement_pass base)
4243
pass_library(conv_bias_mkldnn_fuse_pass inference)

paddle/fluid/framework/ir/graph_pattern_detector.cc

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -761,6 +761,51 @@ PDNode *patterns::ConvReLU::operator()(
761761
return relu_out_var;
762762
}
763763

764+
PDNode *patterns::SeqConvEltAddRelu::operator()(
    paddle::framework::ir::PDNode *seqconv_input) {
  // The entry variable must be consumed by sequence_conv as its "X" input.
  seqconv_input->assert_is_op_input("sequence_conv", "X");

  // Operator nodes. Only a restricted sequence_conv configuration is fused:
  // no trainable padding and a unit context stride.
  auto *conv = pattern->NewNode(seqconv_repr())
                   ->assert_is_op("sequence_conv")
                   ->assert_op_attr<bool>("paddingTrainable", false)
                   ->assert_op_attr<int>("contextStride", 1);
  auto *add = pattern->NewNode(eltadd_repr())->assert_is_op("elementwise_add");
  auto *act = pattern->NewNode(relu_repr())->assert_is_op("relu");

  // Persistable convolution filter fed into sequence_conv.
  auto *filter = pattern->NewNode(seqconv_weight_repr())
                     ->AsInput()
                     ->assert_is_persistable_var()
                     ->assert_is_op_input("sequence_conv", "Filter");
  // Bias consumed by elementwise_add.
  auto *bias = pattern->NewNode(eltadd_bias_repr())
                   ->AsInput()
                   ->assert_is_op_input("elementwise_add");
  // Interior variables: they disappear from the graph once the fuse fires.
  auto *conv_out = pattern->NewNode(seqconv_out_repr())
                       ->AsIntermediate()
                       ->assert_is_only_output_of_op("sequence_conv")
                       ->assert_is_op_input("elementwise_add");
  auto *add_out = pattern->NewNode(eltadd_out_repr())
                      ->AsIntermediate()
                      ->assert_is_only_output_of_op("elementwise_add")
                      ->assert_is_only_input_of_op("relu");
  // Sole output of the fused subgraph.
  auto *act_out = pattern->NewNode(relu_out_repr())
                      ->AsOutput()
                      ->assert_is_op_output("relu");

  // Wire the chain: sequence_conv -> elementwise_add -> relu.
  conv->LinksFrom({seqconv_input, filter}).LinksTo({conv_out});
  add->LinksFrom({conv_out, bias}).LinksTo({add_out});
  act->LinksFrom({add_out}).LinksTo({act_out});
  return act_out;
}
808+
764809
PDNode *patterns::FC::operator()(paddle::framework::ir::PDNode *x,
765810
bool with_bias) {
766811
// Create shared nodes.

paddle/fluid/framework/ir/graph_pattern_detector.h

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,15 @@ struct PDNode {
128128
const std::unordered_set<std::string>& op_types,
129129
const std::string& argument, int nth);
130130

131+
template <typename T>
132+
PDNode* assert_op_attr(const std::string& attr_name, const T& attr) {
133+
asserts_.emplace_back([=](Node* x) {
134+
return x && x->IsOp() && x->Op()->HasAttr(attr_name) &&
135+
boost::get<T>(x->Op()->GetAttr(attr_name)) == attr;
136+
});
137+
return this;
138+
}
139+
131140
private:
132141
PDNode(PDPattern* pattern, const std::string& name = "",
133142
Type type = Type::kVar)
@@ -434,6 +443,31 @@ struct ConvReLU : public PatternBase {
434443
PATTERN_DECL_NODE(relu_out);
435444
};
436445

446+
// SEQCONV with Elementwise_Add ReLU
// Pattern: sequence_conv -> elementwise_add -> relu
// op: seqconv + elementwise_add + relu
// named nodes:
// seqconv_input, seqconv_weight,
// seqconv_out, seqconv,
// elementwise_add_bias, elementwise_add_out, elementwise_add
// relu_out, relu
struct SeqConvEltAddRelu : public PatternBase {
  SeqConvEltAddRelu(PDPattern* pattern, const std::string& name_scope)
      : PatternBase(pattern, name_scope, "seqconv_eltadd_relu") {}

  // Builds the pattern starting from the sequence_conv input variable and
  // returns the relu output node (the sole surviving output after fusion).
  PDNode* operator()(PDNode* seqconv_input);

  // declare operator node's name
  PATTERN_DECL_NODE(seqconv);
  PATTERN_DECL_NODE(eltadd);
  PATTERN_DECL_NODE(relu);
  // declare variable node's name
  PATTERN_DECL_NODE(seqconv_weight);
  PATTERN_DECL_NODE(seqconv_out);
  PATTERN_DECL_NODE(eltadd_bias);
  PATTERN_DECL_NODE(eltadd_out);
  PATTERN_DECL_NODE(relu_out);
};
470+
437471
// FC with bias
438472
// op: mul + elementwise_add
439473
// named nodes:
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.h"
16+
#include <string>
17+
#include "paddle/fluid/framework/lod_tensor.h"
18+
19+
namespace paddle {
20+
namespace framework {
21+
namespace ir {
22+
23+
// Detects every sequence_conv -> elementwise_add -> relu subgraph and
// replaces it with a single fusion_seqconv_eltadd_relu op.
//
// @param graph       graph being rewritten; must carry kParamScopeAttr.
// @param name_scope  unique prefix used to name the pattern's nodes.
// @param scope       parameter scope; used to create the fused op's ColMat
//                    workspace variable.
// @return number of subgraphs fused.
int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope) {
  GraphPatternDetector gpd;
  auto* pattern = gpd.mutable_pattern();

  // Pattern entry point: a non-persistable input of sequence_conv.
  PDNode* x = pattern->NewNode(patterns::PDNodeName(name_scope, "X"))
                  ->assert_is_op_input("sequence_conv")
                  ->assert_var_not_persistable();
  patterns::SeqConvEltAddRelu fuse_pattern(pattern, name_scope);
  fuse_pattern(x);

  // Builds the fused op desc, creates the op node and links it to the
  // surviving variable nodes.
  auto fuse_creator = [&](Node* seqconv, Node* input, Node* seqconv_weight,
                          Node* eltadd_bias, Node* relu_out) {
    OpDesc op_desc;
    op_desc.SetType("fusion_seqconv_eltadd_relu");
    op_desc.SetInput("X", {input->Name()});
    op_desc.SetInput("Filter", {seqconv_weight->Name()});
    op_desc.SetInput("Bias", {eltadd_bias->Name()});
    // Forward the original sequence_conv attributes unchanged.
    op_desc.SetAttr("contextLength", seqconv->Op()->GetAttr("contextLength"));
    op_desc.SetAttr("contextStart", seqconv->Op()->GetAttr("contextStart"));
    op_desc.SetAttr("contextStride", seqconv->Op()->GetAttr("contextStride"));
    PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
    // Use the scope passed by the caller instead of re-fetching it from the
    // graph attribute; the original fetched graph->Get<Scope*>(kParamScopeAttr)
    // here, shadowing the (otherwise unused) `scope` parameter. ApplyImpl
    // passes param_scope(), which is presumably the same scope stored under
    // kParamScopeAttr — verified by the enforce above.
    PADDLE_ENFORCE(scope);
    // ColMat is an im2col-style workspace output, unique per fused instance.
    const std::string ColMat = patterns::UniqueKey("SeqConvColMat");
    op_desc.SetOutput("ColMat", {ColMat});
    op_desc.SetOutput("Out", {relu_out->Name()});
    scope->Var(ColMat)->GetMutable<LoDTensor>();

    auto* op = graph->CreateOpNode(&op_desc);
    IR_NODE_LINK_TO(input, op);
    IR_NODE_LINK_TO(seqconv_weight, op);
    IR_NODE_LINK_TO(eltadd_bias, op);
    IR_NODE_LINK_TO(op, relu_out);
    return op;
  };

  int fusion_count{0};

  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "handle SeqConv EltAdd Relu fuse";
    GET_IR_NODE_FROM_SUBGRAPH(seqconv, seqconv, fuse_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(seqconv_weight, seqconv_weight, fuse_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(seqconv_out, seqconv_out, fuse_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd, eltadd, fuse_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd_bias, eltadd_bias, fuse_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd_out, eltadd_out, fuse_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(relu, relu, fuse_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(relu_out, relu_out, fuse_pattern);

    fuse_creator(seqconv, subgraph.at(x), seqconv_weight, eltadd_bias,
                 relu_out);
    // Remove only the interior nodes; inputs, weights and relu_out survive
    // and are now connected to the fused op.
    std::unordered_set<const Node*> marked_nodes(
        {seqconv, seqconv_out, eltadd, eltadd_out, relu});
    GraphSafeRemoveNodes(graph, marked_nodes);
    ++fusion_count;
  };

  gpd(graph, handler);

  return fusion_count;
}
85+
86+
std::unique_ptr<ir::Graph> SeqConvEltAddReluFusePass::ApplyImpl(
87+
std::unique_ptr<ir::Graph> graph) const {
88+
FusePassBase::Init(name_scope_, graph.get());
89+
90+
int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope());
91+
AddStatis(fusion_count);
92+
93+
return graph;
94+
}
95+
96+
} // namespace ir
97+
} // namespace framework
98+
} // namespace paddle
99+
100+
REGISTER_PASS(seqconv_eltadd_relu_fuse_pass,
101+
paddle::framework::ir::SeqConvEltAddReluFusePass);
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include <string>
18+
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
19+
#include "paddle/fluid/framework/ir/graph.h"
20+
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
21+
22+
namespace paddle {
23+
namespace framework {
24+
namespace ir {
25+
26+
class SeqConvEltAddReluFusePass : public FusePassBase {
27+
public:
28+
virtual ~SeqConvEltAddReluFusePass() {}
29+
30+
protected:
31+
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
32+
33+
const std::string name_scope_{"seqconv_eltadd_relu_fuse"};
34+
};
35+
36+
} // namespace ir
37+
} // namespace framework
38+
} // namespace paddle

paddle/fluid/inference/analysis/analyzer.h

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -67,17 +67,18 @@ class Analyzer : public OrderedRegistry<PassManager> {
6767
// larger fusion.
6868
const std::vector<std::string> all_ir_passes_{{
6969
// Manual update the passes here.
70-
"infer_clean_graph_pass", //
71-
"attention_lstm_fuse_pass", //
72-
"embedding_fc_lstm_fuse_pass", //
73-
"fc_lstm_fuse_pass", //
74-
"mul_lstm_fuse_pass", //
75-
"fc_gru_fuse_pass", //
76-
"mul_gru_fuse_pass", //
77-
"seq_concat_fc_fuse_pass", //
78-
"fc_fuse_pass", //
79-
"conv_bn_fuse_pass", //
80-
"conv_eltwiseadd_bn_fuse_pass", //
70+
"infer_clean_graph_pass", //
71+
"attention_lstm_fuse_pass", //
72+
"seqconv_eltadd_relu_fuse_pass", //
73+
"embedding_fc_lstm_fuse_pass", //
74+
"fc_lstm_fuse_pass", //
75+
"mul_lstm_fuse_pass", //
76+
"fc_gru_fuse_pass", //
77+
"mul_gru_fuse_pass", //
78+
"seq_concat_fc_fuse_pass", //
79+
"fc_fuse_pass", //
80+
"conv_bn_fuse_pass", //
81+
"conv_eltwiseadd_bn_fuse_pass", //
8182
#ifdef PADDLE_WITH_MKLDNN
8283
"conv_bias_mkldnn_fuse_pass", //
8384
"conv_relu_mkldnn_fuse_pass", //

paddle/fluid/inference/tests/api/analyzer_seq_conv1_tester.cc

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,13 @@ TEST(Analyzer_seq_conv1, fuse_statis) {
183183
SetConfig(&cfg);
184184
int num_ops;
185185
auto predictor = CreatePaddlePredictor<AnalysisConfig>(cfg);
186-
GetFuseStatis(predictor.get(), &num_ops);
186+
187+
auto fuse_statis = GetFuseStatis(predictor.get(), &num_ops);
188+
ASSERT_TRUE(fuse_statis.count("fc_fuse"));
189+
ASSERT_TRUE(fuse_statis.count("seqconv_eltadd_relu_fuse"));
190+
EXPECT_EQ(fuse_statis.at("fc_fuse"), 2);
191+
EXPECT_EQ(fuse_statis.at("seqconv_eltadd_relu_fuse"), 6);
192+
EXPECT_EQ(num_ops, 32);
187193
}
188194

189195
// Compare result of NativeConfig and AnalysisConfig

paddle/fluid/operators/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ function(op_library TARGET)
8686
# remove windows unsupported op, because windows has no nccl, no warpctc such ops.
8787
foreach(windows_unsupport_op "nccl_op" "gen_nccl_id_op" "warpctc_op" "hierarchical_sigmoid_op"
8888
"crf_decoding_op" "select_op" "lstmp_op" "gru_op" "fusion_gru_op" "lstm_op" "fusion_lstm_op" "cumsum_op"
89-
"channel_send_op" "channel_create_op" "channel_close_op" "channel_recv_op")
89+
"fusion_seqconv_eltadd_relu_op" "channel_send_op" "channel_create_op" "channel_close_op" "channel_recv_op")
9090
if ("${TARGET}" STREQUAL "${windows_unsupport_op}")
9191
return()
9292
endif()

0 commit comments

Comments
 (0)