Skip to content

Commit 603ba5e

Browse files
committed
add seqconv eltadd relu pass
1 parent 23fc896 commit 603ba5e

File tree

6 files changed

+227
-11
lines changed

6 files changed

+227
-11
lines changed

paddle/fluid/framework/ir/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ pass_library(embedding_fc_lstm_fuse_pass inference)
3737
pass_library(fc_gru_fuse_pass inference)
3838
pass_library(seq_concat_fc_fuse_pass inference)
3939
pass_library(conv_bn_fuse_pass inference)
40+
pass_library(seqconv_eltadd_relu_fuse_pass inference)
4041
if(WITH_MKLDNN)
4142
pass_library(mkldnn_placement_pass base)
4243
pass_library(conv_relu_mkldnn_fuse_pass inference)

paddle/fluid/framework/ir/graph_pattern_detector.cc

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,11 @@ PDNode *PDNode::assert_is_op() {
349349
return this;
350350
}
351351

352+
// PDNode *PDNode::assert_op_attr() {
353+
// asserts_.emplace_back([](Node *x) { return x && x->IsOp(); });
354+
// return this;
355+
// }
356+
352357
PDNode *PDNode::assert_is_op(const std::string &op_type) {
353358
asserts_.emplace_back([op_type](Node *x) {
354359
return x && x->IsOp() && x->Op()->Type() == op_type;
@@ -761,6 +766,51 @@ PDNode *patterns::ConvReLU::operator()(
761766
return relu_out_var;
762767
}
763768

769+
PDNode *patterns::SeqConvEltAddRelu::operator()(
    paddle::framework::ir::PDNode *seqconv_input) {
  // Builds the pattern: sequence_conv -> elementwise_add -> relu, rooted at
  // seqconv_input. Returns the relu output node so callers can chain on it.

  // Create operators.
  seqconv_input->assert_is_op_input("sequence_conv", "X");
  auto *seqconv_op =
      pattern->NewNode(seqconv_repr())->assert_is_op("sequence_conv");
  // TODO(review): the fused kernel presumably requires
  // paddingTrainable == false and contextStride == 1, but no attribute
  // assertion helper exists yet, so those constraints are not enforced here —
  // confirm and add an assert_op_attr-style check when available.

  auto *eltadd_op =
      pattern->NewNode(eltadd_repr())->assert_is_op("elementwise_add");
  auto *relu_op = pattern->NewNode(relu_repr())->assert_is_op("relu");

  // Create variables.
  // Filter: persistable weight fed into sequence_conv.
  auto *seqconv_weight_var =
      pattern->NewNode(seqconv_weight_repr())
          ->AsInput()
          ->assert_is_persistable_var()
          ->assert_is_op_input("sequence_conv", "Filter");
  // Bias consumed by elementwise_add.
  auto *eltadd_bias_var = pattern->NewNode(eltadd_bias_repr())
                              ->AsInput()
                              ->assert_is_op_input("elementwise_add");
  // Intermediate variables; removed from the IR after the fuse.
  auto *seqconv_out_var = pattern->NewNode(seqconv_out_repr())
                              ->AsIntermediate()
                              ->assert_is_only_output_of_op("sequence_conv")
                              ->assert_is_op_input("elementwise_add");
  auto *eltadd_out_var = pattern->NewNode(eltadd_out_repr())
                             ->AsIntermediate()
                             ->assert_is_only_output_of_op("elementwise_add")
                             ->assert_is_only_input_of_op("relu");
  // Final output of the fused chain.
  auto *relu_out_var = pattern->NewNode(relu_out_repr())
                           ->AsOutput()
                           ->assert_is_op_output("relu");

  seqconv_op->LinksFrom({seqconv_input, seqconv_weight_var})
      .LinksTo({seqconv_out_var});
  eltadd_op->LinksFrom({seqconv_out_var, eltadd_bias_var})
      .LinksTo({eltadd_out_var});
  relu_op->LinksFrom({eltadd_out_var}).LinksTo({relu_out_var});
  return relu_out_var;
}
813+
764814
PDNode *patterns::FC::operator()(paddle::framework::ir::PDNode *x,
765815
bool with_bias) {
766816
// Create shared nodes.

paddle/fluid/framework/ir/graph_pattern_detector.h

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,31 @@ struct ConvReLU : public PatternBase {
434434
PATTERN_DECL_NODE(relu_out);
435435
};
436436

437+
// SeqConv fused with Elementwise_Add and ReLU.
// Matched op chain: sequence_conv -> elementwise_add -> relu.
// Named nodes (as declared below):
//   operators: seqconv, eltadd, relu
//   variables: seqconv_weight, seqconv_out,
//              eltadd_bias, eltadd_out, relu_out
struct SeqConvEltAddRelu : public PatternBase {
  SeqConvEltAddRelu(PDPattern* pattern, const std::string& name_scope)
      : PatternBase(pattern, name_scope, "seqconv_eltadd_relu") {}

  // Builds the pattern rooted at seqconv_input (the "X" input of
  // sequence_conv) and returns the relu output node.
  PDNode* operator()(PDNode* seqconv_input);

  // declare operator node's name
  PATTERN_DECL_NODE(seqconv);
  PATTERN_DECL_NODE(eltadd);
  PATTERN_DECL_NODE(relu);
  // declare variable node's name
  PATTERN_DECL_NODE(seqconv_weight);
  PATTERN_DECL_NODE(seqconv_out);
  PATTERN_DECL_NODE(eltadd_bias);
  PATTERN_DECL_NODE(eltadd_out);
  PATTERN_DECL_NODE(relu_out);
};
461+
437462
// FC with bias
438463
// op: mul + elementwise_add
439464
// named nodes:
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/framework/ir/seqconv_eltadd_relu_fuse_pass.h"
#include <string>
#include <unordered_set>
#include "paddle/fluid/framework/lod_tensor.h"
18+
19+
namespace paddle {
20+
namespace framework {
21+
namespace ir {
22+
23+
// Detects every sequence_conv -> elementwise_add -> relu chain in `graph`
// and replaces it with a single fusion_seqconv_eltadd_relu op. `scope` is
// the parameter scope used to create the fused op's scratch variable.
// Returns the number of subgraphs fused.
int BuildFusion(Graph* graph, const std::string& name_scope, Scope* scope) {
  GraphPatternDetector gpd;
  auto* pattern = gpd.mutable_pattern();

  // Pattern root: the non-persistable variable fed to sequence_conv as "X".
  PDNode* x = pattern->NewNode(patterns::PDNodeName(name_scope, "X"))
                  ->assert_is_op_input("sequence_conv")
                  ->assert_var_not_persistable();
  patterns::SeqConvEltAddRelu fuse_pattern(pattern, name_scope);
  fuse_pattern(x);

  // Builds the fused OpDesc for one matched subgraph and links its IO nodes.
  auto fuse_creator = [&](Node* seqconv, Node* input, Node* seqconv_weight,
                          Node* eltadd_bias, Node* relu_out) {
    OpDesc op_desc;
    op_desc.SetType("fusion_seqconv_eltadd_relu");
    op_desc.SetInput("X", {input->Name()});
    op_desc.SetInput("Filter", {seqconv_weight->Name()});
    op_desc.SetInput("Bias", {eltadd_bias->Name()});
    // The fused op carries over sequence_conv's context attributes unchanged.
    op_desc.SetAttr("contextLength", seqconv->Op()->GetAttr("contextLength"));
    op_desc.SetAttr("contextStart", seqconv->Op()->GetAttr("contextStart"));
    op_desc.SetAttr("contextStride", seqconv->Op()->GetAttr("contextStride"));
    // Use the scope passed in by the caller (previously this parameter was
    // shadowed by a local re-fetch of graph->Get<Scope*>(kParamScopeAttr),
    // which refers to the same parameter scope).
    PADDLE_ENFORCE(graph->Has(kParamScopeAttr));
    PADDLE_ENFORCE(scope);
    // Scratch output required by the fused kernel; name must be unique.
    const std::string ColMat = patterns::UniqueKey("SeqConvColMat");
    op_desc.SetOutput("ColMat", {ColMat});
    op_desc.SetOutput("Out", {relu_out->Name()});
    scope->Var(ColMat)->GetMutable<LoDTensor>();

    auto* op = graph->CreateOpNode(&op_desc);
    IR_NODE_LINK_TO(input, op);
    IR_NODE_LINK_TO(seqconv_weight, op);
    IR_NODE_LINK_TO(eltadd_bias, op);
    IR_NODE_LINK_TO(op, relu_out);
    return op;
  };

  int fusion_count{0};

  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    VLOG(4) << "handle SeqConv EltAdd Relu fuse";
    GET_IR_NODE_FROM_SUBGRAPH(seqconv, seqconv, fuse_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(seqconv_weight, seqconv_weight, fuse_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(seqconv_out, seqconv_out, fuse_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd, eltadd, fuse_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd_bias, eltadd_bias, fuse_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(eltadd_out, eltadd_out, fuse_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(relu, relu, fuse_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(relu_out, relu_out, fuse_pattern);

    fuse_creator(seqconv, subgraph.at(x), seqconv_weight, eltadd_bias,
                 relu_out);
    // Drop the replaced ops and the now-dead intermediate variables; the
    // weight, bias, input and final output nodes are kept and relinked.
    std::unordered_set<const Node*> marked_nodes(
        {seqconv, seqconv_out, eltadd, eltadd_out, relu});
    GraphSafeRemoveNodes(graph, marked_nodes);
    ++fusion_count;
  };

  gpd(graph, handler);

  return fusion_count;
}
85+
86+
std::unique_ptr<ir::Graph> SeqConvEltAddReluFusePass::ApplyImpl(
87+
std::unique_ptr<ir::Graph> graph) const {
88+
FusePassBase::Init(name_scope_, graph.get());
89+
90+
int fusion_count = BuildFusion(graph.get(), name_scope_, param_scope());
91+
AddStatis(fusion_count);
92+
93+
return graph;
94+
}
95+
96+
} // namespace ir
97+
} // namespace framework
98+
} // namespace paddle
99+
100+
REGISTER_PASS(seqconv_eltadd_relu_fuse_pass,
101+
paddle::framework::ir::SeqConvEltAddReluFusePass);
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include <string>
18+
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
19+
#include "paddle/fluid/framework/ir/graph.h"
20+
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
21+
22+
namespace paddle {
23+
namespace framework {
24+
namespace ir {
25+
26+
class SeqConvEltAddReluFusePass : public FusePassBase {
27+
public:
28+
virtual ~SeqConvEltAddReluFusePass() {}
29+
30+
protected:
31+
std::unique_ptr<ir::Graph> ApplyImpl(std::unique_ptr<ir::Graph> graph) const;
32+
33+
const std::string name_scope_{"seqconv_eltadd_relu_fuse"};
34+
};
35+
36+
} // namespace ir
37+
} // namespace framework
38+
} // namespace paddle

paddle/fluid/inference/analysis/analyzer.h

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -67,17 +67,18 @@ class Analyzer : public OrderedRegistry<PassManager> {
6767
// larger fusion.
6868
const std::vector<std::string> all_ir_passes_{{
6969
// Manual update the passes here.
70-
"infer_clean_graph_pass", //
71-
"attention_lstm_fuse_pass", //
72-
"embedding_fc_lstm_fuse_pass", //
73-
"fc_lstm_fuse_pass", //
74-
"mul_lstm_fuse_pass", //
75-
"fc_gru_fuse_pass", //
76-
"mul_gru_fuse_pass", //
77-
"seq_concat_fc_fuse_pass", //
78-
"fc_fuse_pass", //
79-
"conv_bn_fuse_pass", //
80-
"conv_eltwiseadd_bn_fuse_pass", //
70+
"infer_clean_graph_pass", //
71+
"attention_lstm_fuse_pass", //
72+
"seqconv_eltadd_relu_fuse_pass", //
73+
"embedding_fc_lstm_fuse_pass", //
74+
"fc_lstm_fuse_pass", //
75+
"mul_lstm_fuse_pass", //
76+
"fc_gru_fuse_pass", //
77+
"mul_gru_fuse_pass", //
78+
"seq_concat_fc_fuse_pass", //
79+
"fc_fuse_pass", //
80+
"conv_bn_fuse_pass", //
81+
"conv_eltwiseadd_bn_fuse_pass", //
8182
#ifdef PADDLE_WITH_MKLDNN
8283
"conv_relu_mkldnn_fuse_pass", //
8384
#endif

0 commit comments

Comments
 (0)