Skip to content

Commit 7da838b

Browse files
authored
[pass][OpenCL]add fuse flatten_contiguous_range and fc pass (#6040) (#6057)
* add fuse flatten_contiguous_range and fc pass * modify Copyright 2021 test=develop
1 parent 8fa948e commit 7da838b

File tree

8 files changed

+201
-0
lines changed

8 files changed

+201
-0
lines changed

lite/api/paddle_use_passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ USE_MIR_PASS(control_flow_op_shared_inputs_and_outputs_place_sync_pass);
7878
USE_MIR_PASS(lite_scale_activation_fuse_pass);
7979
USE_MIR_PASS(lite_instance_norm_activation_fuse_pass);
8080
USE_MIR_PASS(ssd_boxes_calc_offline_pass);
81+
USE_MIR_PASS(lite_flatten_fc_fuse_pass);
8182
USE_MIR_PASS(lite_fc_prelu_fuse_pass);
8283
USE_MIR_PASS(__xpu__graph_dedup_pass);
8384
USE_MIR_PASS(__xpu__resnet_fuse_pass);

lite/core/mir/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ lite_cc_library(mir_passes
5858
fusion/sequence_reverse_embedding_fuse_pass.cc
5959
fusion/instance_norm_activation_fuse_pass.cc
6060
fusion/elementwise_add_scale_fuse_pass.cc
61+
fusion/flatten_fc_fuse_pass.cc
6162
fusion/fc_prelu_fuse_pass.cc
6263
elimination/identity_scale_eliminate_pass.cc
6364
elimination/identity_dropout_eliminate_pass.cc

lite/core/mir/fusion/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,9 @@ lite_cc_library(fuse_instance_norm_activation
7070
lite_cc_library(fuse_elementwise_add_scale
7171
SRCS elementwise_add_scale_fuser.cc
7272
DEPS pattern_matcher_high_api)
73+
lite_cc_library(fuse_flatten_fc
74+
SRCS flatten_fc_fuser.cc
75+
DEPS pattern_matcher_high_api)
7376
lite_cc_library(fuse_fc_prelu
7477
SRCS fc_prelu_fuser.cc
7578
DEPS pattern_matcher_high_api)
@@ -98,6 +101,7 @@ set(mir_fusers
98101
fuse_sequence_reverse_embedding
99102
fuse_instance_norm_activation
100103
fuse_elementwise_add_scale
104+
fuse_flatten_fc
101105
fuse_fc_prelu
102106
fuse_conv_scale
103107
CACHE INTERNAL "fusers")
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "lite/core/mir/fusion/flatten_fc_fuse_pass.h"
16+
#include <memory>
17+
#include <vector>
18+
#include "lite/core/mir/fusion/flatten_fc_fuser.h"
19+
#include "lite/core/mir/pass_registry.h"
20+
21+
namespace paddle {
22+
namespace lite {
23+
namespace mir {
24+
25+
void FlattenFcFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
26+
fusion::FlattenFcFuser flatten_fuser(" ");
27+
flatten_fuser(graph.get());
28+
}
29+
30+
} // namespace mir
31+
} // namespace lite
32+
} // namespace paddle
33+
34+
REGISTER_MIR_PASS(lite_flatten_fc_fuse_pass,
35+
paddle::lite::mir::FlattenFcFusePass)
36+
.BindTargets({TARGET(kOpenCL)})
37+
.BindKernel("fc");
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include <memory>
18+
#include <string>
19+
#include "lite/core/mir/pass.h"
20+
21+
namespace paddle {
22+
namespace lite {
23+
namespace mir {
24+
25+
class FlattenFcFusePass : public ProgramPass {
26+
public:
27+
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
28+
};
29+
30+
} // namespace mir
31+
} // namespace lite
32+
} // namespace paddle
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "lite/core/mir/fusion/flatten_fc_fuser.h"
16+
#include <memory>
17+
#include <vector>
18+
19+
namespace paddle {
20+
namespace lite {
21+
namespace mir {
22+
namespace fusion {
23+
24+
void FlattenFcFuser::BuildPattern() {
25+
// flatten_contiguous_range
26+
PMNode* x = VarNode("x")
27+
->assert_is_op_input("flatten_contiguous_range", "X")
28+
->AsInput();
29+
PMNode* flatten_contiguous_range =
30+
OpNode("flatten_contiguous_range", "flatten_contiguous_range")
31+
->AsIntermediate();
32+
PMNode* out = VarNode("output")
33+
->assert_is_op_output("flatten_contiguous_range", "Out")
34+
->AsIntermediate();
35+
PMNode* xshape =
36+
VarNode("xshape")
37+
->assert_is_op_output("flatten_contiguous_range", "XShape")
38+
->AsIntermediate();
39+
40+
// fc
41+
// PMNode* input = VarNode("input")->assert_is_op_input("fc",
42+
// "Input")->AsIntermediate();
43+
PMNode* weights =
44+
VarNode("weights")->assert_is_op_input("fc", "W")->AsInput();
45+
PMNode* bias = VarNode("bias")->assert_is_op_input("fc", "Bias")->AsInput();
46+
PMNode* fc = OpNode("fc", "fc")->AsIntermediate();
47+
PMNode* fc_out =
48+
VarNode("fc_out")->assert_is_op_output("fc", "Out")->AsOutput();
49+
50+
// create topology.
51+
std::vector<PMNode*> fc_inputs{bias, weights, out};
52+
*x >> *flatten_contiguous_range >> *out;
53+
*flatten_contiguous_range >> *xshape;
54+
fc_inputs >> *fc >> *fc_out;
55+
}
56+
57+
void FlattenFcFuser::InsertNewNode(SSAGraph* graph,
58+
const key2nodes_t& matched) {
59+
auto op_desc = GenOpDesc(matched);
60+
auto fc_op = LiteOpRegistry::Global().Create("fc");
61+
auto fc_old = matched.at("fc")->stmt()->op();
62+
auto* scope = fc_old->scope();
63+
auto& valid_places = fc_old->valid_places();
64+
fc_op->Attach(op_desc, scope);
65+
66+
auto* new_op_node = graph->GraphCreateInstructNode(fc_op, valid_places);
67+
68+
IR_NODE_LINK_TO(matched.at("x"), new_op_node);
69+
IR_NODE_LINK_TO(matched.at("weights"), new_op_node);
70+
IR_NODE_LINK_TO(matched.at("bias"), new_op_node);
71+
IR_NODE_LINK_TO(new_op_node, matched.at("fc_out"));
72+
}
73+
74+
cpp::OpDesc FlattenFcFuser::GenOpDesc(const key2nodes_t& matched) {
75+
cpp::OpDesc op_desc = *matched.at("fc")->stmt()->op_info();
76+
op_desc.SetInput("Input", {matched.at("x")->arg()->name});
77+
op_desc.SetOutput("Out", {matched.at("fc_out")->arg()->name});
78+
int in_num_col_dim = 1;
79+
op_desc.SetAttr("in_num_col_dims", in_num_col_dim);
80+
return op_desc;
81+
}
82+
83+
} // namespace fusion
84+
} // namespace mir
85+
} // namespace lite
86+
} // namespace paddle
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include <memory>
18+
#include <string>
19+
#include "lite/core/mir/pattern_matcher_high_api.h"
20+
21+
namespace paddle {
22+
namespace lite {
23+
namespace mir {
24+
namespace fusion {
25+
26+
class FlattenFcFuser : public FuseBase {
27+
public:
28+
explicit FlattenFcFuser(const std::string& type) {}
29+
void BuildPattern() override;
30+
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
31+
32+
private:
33+
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
34+
};
35+
36+
} // namespace fusion
37+
} // namespace mir
38+
} // namespace lite
39+
} // namespace paddle

lite/core/optimizer.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ class Optimizer {
113113
"lite_scale_activation_fuse_pass", //
114114
"lite_elementwise_scale_fuse_pass", //
115115
"lite_instance_norm_activation_fuse_pass", //
116+
"lite_flatten_fc_fuse_pass", //
116117
"lite_fc_prelu_fuse_pass", //
117118
"lite_elementwise_activation_fuse_pass",
118119
"lite_conv_scale_fuse_pass",

0 commit comments

Comments
 (0)