Skip to content

Commit d0e19af

Browse files
authored
[CHERRY-PICK] Added caching to oneDNN FC and op+unsqueeze2 and op+reshape2 fuse passes (#47690)
* fc cherrypick * other files added * added transpose cherrypick * reverted somebody's fc changes * minor fix * minor fix * cherry-pick of fc+act changes * minor fix * fix
1 parent cf668ab commit d0e19af

15 files changed

+995
-715
lines changed

paddle/fluid/framework/ir/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,8 @@ if(WITH_MKLDNN)
219219
pass_library(matmul_elementwise_add_mkldnn_fuse_pass inference DIR mkldnn)
220220
pass_library(matmul_activation_mkldnn_fuse_pass inference DIR mkldnn)
221221
pass_library(operator_scale_onednn_fuse_pass inference DIR mkldnn)
222+
pass_library(operator_unsqueeze2_onednn_fuse_pass inference DIR mkldnn)
223+
pass_library(operator_reshape2_onednn_fuse_pass inference DIR mkldnn)
222224
pass_library(cpu_quantize_placement_pass base DIR mkldnn)
223225
pass_library(cpu_quantize_pass inference DIR mkldnn)
224226
pass_library(cpu_quantize_squash_pass inference DIR mkldnn)

paddle/fluid/framework/ir/graph_pattern_detector.cc

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -958,6 +958,44 @@ PDNode *patterns::OperatorActivation::operator()(
958958
return activation_out;
959959
}
960960

961+
// Builds the "<operator_type> -> unsqueeze2" sub-graph pattern.
//
// operator_type:         type name of the op that must precede unsqueeze2.
// num_of_operator_outs:  value passed to assert_has_n_outputs() on the
//                        preceding op node.
// Returns the pattern node for the output variable of unsqueeze2.
PDNode *patterns::OperatorUnsqueeze2::operator()(
962+
const std::string &operator_type, const int num_of_operator_outs) {
963+
// Node for the preceding operator, constrained by type and output count.
auto *preceding_op = pattern->NewNode(preceding_op_repr())
964+
->assert_is_op(operator_type)
965+
->assert_has_n_outputs(num_of_operator_outs);
966+
// Variable that is both the "Out" output of the preceding op and an input of
// unsqueeze2. It is marked AsIntermediate(); NOTE(review): presumably this
// flags it as removable once the ops are fused - confirm against PDNode.
auto *preceding_op_out = pattern->NewNode(preceding_op_out_repr())
967+
->AsIntermediate()
968+
->assert_is_op_output(operator_type, "Out")
969+
->assert_is_op_input("unsqueeze2");
970+
auto *unsqueeze2_op =
971+
pattern->NewNode(unsqueeze2_op_repr())->assert_is_op("unsqueeze2");
972+
// Variable produced by unsqueeze2; this is the node handed back to the caller.
auto *unsqueeze2_out = pattern->NewNode(unsqueeze2_out_repr())
973+
->AsOutput()
974+
->assert_is_op_output("unsqueeze2");
975+
// Wire the chain: preceding_op -> preceding_op_out -> unsqueeze2 -> out.
preceding_op->LinksTo({preceding_op_out});
976+
unsqueeze2_op->LinksFrom({preceding_op_out}).LinksTo({unsqueeze2_out});
977+
return unsqueeze2_out;
978+
}
979+
980+
// Builds the "<operator_type> -> reshape2" sub-graph pattern.
//
// operator_type:         type name of the op that must precede reshape2.
// num_of_operator_outs:  value passed to assert_has_n_outputs() on the
//                        preceding op node.
// Returns the pattern node for the output variable of reshape2.
PDNode *patterns::OperatorReshape2::operator()(const std::string &operator_type,
981+
const int num_of_operator_outs) {
982+
// Node for the preceding operator, constrained by type and output count.
auto *preceding_op = pattern->NewNode(preceding_op_repr())
983+
->assert_is_op(operator_type)
984+
->assert_has_n_outputs(num_of_operator_outs);
985+
// Variable that is both the "Out" output of the preceding op and an input of
// reshape2. It is marked AsIntermediate(); NOTE(review): presumably this
// flags it as removable once the ops are fused - confirm against PDNode.
auto *preceding_op_out = pattern->NewNode(preceding_op_out_repr())
986+
->AsIntermediate()
987+
->assert_is_op_output(operator_type, "Out")
988+
->assert_is_op_input("reshape2");
989+
auto *reshape2_op =
990+
pattern->NewNode(reshape2_op_repr())->assert_is_op("reshape2");
991+
// Variable produced by reshape2; this is the node handed back to the caller.
auto *reshape2_out = pattern->NewNode(reshape2_out_repr())
992+
->AsOutput()
993+
->assert_is_op_output("reshape2");
994+
// Wire the chain: preceding_op -> preceding_op_out -> reshape2 -> out.
preceding_op->LinksTo({preceding_op_out});
995+
reshape2_op->LinksFrom({preceding_op_out}).LinksTo({reshape2_out});
996+
return reshape2_out;
997+
}
998+
961999
PDNode *patterns::SeqConvEltAddRelu::operator()(
9621000
paddle::framework::ir::PDNode *seqconv_input) {
9631001
// Create Operators

paddle/fluid/framework/ir/graph_pattern_detector.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -539,6 +539,32 @@ struct OperatorActivation : public PatternBase {
539539
PATTERN_DECL_NODE(activation_out);
540540
};
541541

542+
// Pattern: an arbitrary operator followed by unsqueeze2.
// The pattern is registered with PatternBase under the name
// "operator_unsqueeze2" within the supplied name scope.
struct OperatorUnsqueeze2 : public PatternBase {
543+
OperatorUnsqueeze2(PDPattern* pattern, const std::string& name_scope)
544+
: PatternBase(pattern, name_scope, "operator_unsqueeze2") {}
545+
546+
// Builds the pattern for the given preceding operator type and returns the
// node of the unsqueeze2 output variable.
// NOTE(review): the definition in graph_pattern_detector.cc names the second
// parameter num_of_operator_outs; consider aligning the two names.
PDNode* operator()(const std::string& operator_type,
547+
const int num_of_outputs);
548+
549+
// Named nodes of the pattern; the definition refers to them through the
// matching *_repr() accessors.
PATTERN_DECL_NODE(preceding_op);
550+
PATTERN_DECL_NODE(preceding_op_out);
551+
PATTERN_DECL_NODE(unsqueeze2_op);
552+
PATTERN_DECL_NODE(unsqueeze2_out);
553+
};
554+
555+
// Pattern: an arbitrary operator followed by reshape2.
// The pattern is registered with PatternBase under the name
// "operator_reshape2" within the supplied name scope.
struct OperatorReshape2 : public PatternBase {
556+
OperatorReshape2(PDPattern* pattern, const std::string& name_scope)
557+
: PatternBase(pattern, name_scope, "operator_reshape2") {}
558+
559+
// Builds the pattern for the given preceding operator type and returns the
// node of the reshape2 output variable.
// NOTE(review): the definition in graph_pattern_detector.cc names the second
// parameter num_of_operator_outs; consider aligning the two names.
PDNode* operator()(const std::string& operator_type,
560+
const int num_of_outputs);
561+
562+
// Named nodes of the pattern; the definition refers to them through the
// matching *_repr() accessors.
PATTERN_DECL_NODE(preceding_op);
563+
PATTERN_DECL_NODE(preceding_op_out);
564+
PATTERN_DECL_NODE(reshape2_op);
565+
PATTERN_DECL_NODE(reshape2_out);
566+
};
567+
542568
// SEQCONV with Elementwise_Add ReLU
543569
// op: seqconv + elementwise_add + relu
544570
// named nodes:

paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.cc

Lines changed: 34 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
1+
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
22
//
33
// Licensed under the Apache License, Version 2.0 (the "License");
44
// you may not use this file except in compliance with the License.
@@ -14,9 +14,8 @@
1414

1515
#include "paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h"
1616

17-
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
1817
#include "paddle/fluid/framework/op_version_registry.h"
19-
#include "paddle/fluid/platform/enforce.h"
18+
#include "paddle/fluid/platform/mkldnn_reuse.h"
2019
#include "paddle/fluid/string/pretty_log.h"
2120

2221
namespace paddle {
@@ -26,20 +25,20 @@ namespace ir {
2625
using string::PrettyLogDetail;
2726

2827
void FuseFCActOneDNNPass::ApplyImpl(Graph *graph) const {
29-
std::vector<std::string> act_types = {
30-
"gelu", "tanh", "sigmoid", "mish", "hard_swish"};
28+
auto act_types = paddle::platform::GetSupportedActivations();
3129

32-
for (std::string act_type : act_types) FuseFCAct(graph, act_type);
30+
for (auto act_type : act_types) FuseFCAct(graph, act_type);
3331
}
3432

3533
void FuseFCActOneDNNPass::FuseFCAct(Graph *graph,
3634
const std::string &act_type) const {
3735
PADDLE_ENFORCE_NOT_NULL(
3836
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
39-
FusePassBase::Init("fc_act", graph);
37+
FusePassBase::Init("fc_" + act_type + "_mkldnn_fuse_pass", graph);
4038

4139
GraphPatternDetector gpd;
42-
patterns::OperatorActivation fc_act_pattern(gpd.mutable_pattern(), "fc_act");
40+
patterns::OperatorActivation fc_act_pattern(
41+
gpd.mutable_pattern(), "fc_" + act_type + "_mkldnn_fuse_pass");
4342
fc_act_pattern("fc", act_type);
4443

4544
int found_fc_act_count = 0;
@@ -62,15 +61,23 @@ void FuseFCActOneDNNPass::FuseFCAct(Graph *graph,
6261
"is used."));
6362
}
6463

64+
auto attr_map = paddle::platform::GetAttributeMap(act_type);
65+
for (const auto &attr : attr_map) {
66+
if (act_op->HasAttr(attr.first)) {
67+
fc_op->SetAttr(attr.second, act_op->GetAttr(attr.first));
68+
}
69+
}
70+
6571
if (act_type == "gelu" && act_op->HasAttr("approximate")) {
66-
bool approximate = PADDLE_GET_CONST(bool, act_op->GetAttr("approximate"));
67-
std::string type = approximate ? "_tanh" : "_erf";
68-
fc_op->SetAttr("activation_type", act_type + type);
72+
std::string gelu_act_type =
73+
PADDLE_GET_CONST(bool, act_op->GetAttr("approximate")) ? "gelu_tanh"
74+
: "gelu_erf";
75+
fc_op->SetAttr("fuse_activation", gelu_act_type);
6976
} else {
70-
fc_op->SetAttr("activation_type", act_type);
77+
fc_op->SetAttr("fuse_activation", act_type);
7178
}
72-
fc_op->SetAttr("use_mkldnn", true);
7379

80+
fc_op->SetAttr("use_mkldnn", true);
7481
fc_op->SetOutput("Out", {act_out->Name()});
7582

7683
IR_OP_VAR_LINK(fc, act_out);
@@ -80,7 +87,8 @@ void FuseFCActOneDNNPass::FuseFCAct(Graph *graph,
8087

8188
gpd(graph, handler);
8289
AddStatis(found_fc_act_count);
83-
if (!Has("disable_logs") || !Get<bool>("disable_logs"))
90+
if ((!Has("disable_logs") || !Get<bool>("disable_logs")) &&
91+
found_fc_act_count > 0)
8492
PrettyLogDetail(
8593
"--- fused %d fc with %s activation", found_fc_act_count, act_type);
8694
}
@@ -95,8 +103,16 @@ REGISTER_PASS_CAPABILITY(fc_act_mkldnn_fuse_pass)
95103
.AddCombination(
96104
paddle::framework::compatible::OpVersionComparatorCombination()
97105
.LE("fc", 0)
98-
.LE("gelu", 0)
99-
.LE("sigmoid", 0)
100-
.LE("mish", 1)
106+
.EQ("abs", 0)
107+
.LE("clip", 1)
108+
.EQ("gelu", 0)
109+
.EQ("hard_sigmoid", 0)
101110
.LE("hard_swish", 0)
102-
.LE("tanh", 0));
111+
.LE("leaky_relu", 1)
112+
.LE("mish", 1)
113+
.EQ("relu", 0)
114+
.EQ("relu6", 0)
115+
.EQ("sigmoid", 0)
116+
.EQ("sqrt", 0)
117+
.EQ("swish", 0)
118+
.EQ("tanh", 0));

paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,21 +23,14 @@ namespace paddle {
2323
namespace framework {
2424
namespace ir {
2525

26-
/*
27-
* \brief Fuse the FC and activation operators into single OneDNN's
28-
* FC with post-op.
29-
*
30-
* \note Currently only GeLU, hardswish, sigmoid, mish and tanh are supported
31-
* as an activation function.
32-
*/
3326
class FuseFCActOneDNNPass : public FusePassBase {
3427
public:
3528
virtual ~FuseFCActOneDNNPass() {}
3629

3730
protected:
38-
void ApplyImpl(ir::Graph *graph) const override;
31+
void ApplyImpl(Graph *graph) const override;
3932

40-
void FuseFCAct(ir::Graph *graph, const std::string &act_types) const;
33+
void FuseFCAct(Graph *graph, const std::string &act_types) const;
4134
};
4235

4336
} // namespace ir

paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass_tester.cc

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -78,9 +78,9 @@ TEST(FuseFCActOneDNNPass, FuseWithGeluTanh) {
7878
const auto* op = node->Op();
7979
ASSERT_TRUE(op->HasAttr("use_mkldnn"));
8080
EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
81-
ASSERT_TRUE(op->HasAttr("activation_type"));
81+
ASSERT_TRUE(op->HasAttr("fuse_activation"));
8282
auto act_type =
83-
PADDLE_GET_CONST(std::string, op->GetAttr("activation_type"));
83+
PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
8484
EXPECT_EQ(act_type.compare("gelu_tanh"), 0);
8585
}
8686
}
@@ -113,9 +113,9 @@ TEST(FuseFCActOneDNNPass, FuseWithGeluErf) {
113113
const auto* op = node->Op();
114114
ASSERT_TRUE(op->HasAttr("use_mkldnn"));
115115
EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
116-
ASSERT_TRUE(op->HasAttr("activation_type"));
116+
ASSERT_TRUE(op->HasAttr("fuse_activation"));
117117
auto act_type =
118-
PADDLE_GET_CONST(std::string, op->GetAttr("activation_type"));
118+
PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
119119
EXPECT_EQ(act_type.compare("gelu_erf"), 0);
120120
}
121121
}
@@ -146,9 +146,9 @@ TEST(FuseFCActOneDNNPass, FuseWithGeluAuto) {
146146
const auto* op = node->Op();
147147
ASSERT_TRUE(op->HasAttr("use_mkldnn"));
148148
EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
149-
ASSERT_TRUE(op->HasAttr("activation_type"));
149+
ASSERT_TRUE(op->HasAttr("fuse_activation"));
150150
auto act_type =
151-
PADDLE_GET_CONST(std::string, op->GetAttr("activation_type"));
151+
PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
152152
EXPECT_EQ(act_type.compare("gelu"), 0);
153153
}
154154
}
@@ -179,9 +179,9 @@ TEST(FuseFCActOneDNNPass, FuseWithTanh) {
179179
const auto* op = node->Op();
180180
ASSERT_TRUE(op->HasAttr("use_mkldnn"));
181181
EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
182-
ASSERT_TRUE(op->HasAttr("activation_type"));
182+
ASSERT_TRUE(op->HasAttr("fuse_activation"));
183183
auto act_type =
184-
PADDLE_GET_CONST(std::string, op->GetAttr("activation_type"));
184+
PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
185185
EXPECT_EQ(act_type.compare("tanh"), 0);
186186
}
187187
}
@@ -213,9 +213,9 @@ TEST(FuseFCActOneDNNPass, FuseWithSigmoid) {
213213
const auto* op = node->Op();
214214
ASSERT_TRUE(op->HasAttr("use_mkldnn"));
215215
EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
216-
ASSERT_TRUE(op->HasAttr("activation_type"));
216+
ASSERT_TRUE(op->HasAttr("fuse_activation"));
217217
auto act_type =
218-
PADDLE_GET_CONST(std::string, op->GetAttr("activation_type"));
218+
PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
219219
EXPECT_EQ(act_type.compare("sigmoid"), 0);
220220
}
221221
}
@@ -246,9 +246,9 @@ TEST(FuseFCActOneDNNPass, FuseWithMish) {
246246
const auto* op = node->Op();
247247
ASSERT_TRUE(op->HasAttr("use_mkldnn"));
248248
EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
249-
ASSERT_TRUE(op->HasAttr("activation_type"));
249+
ASSERT_TRUE(op->HasAttr("fuse_activation"));
250250
auto act_type =
251-
PADDLE_GET_CONST(std::string, op->GetAttr("activation_type"));
251+
PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
252252
EXPECT_EQ(act_type.compare("mish"), 0);
253253
}
254254
}
@@ -280,9 +280,9 @@ TEST(FuseFCActOneDNNPass, FuseWithHardSwish) {
280280
const auto* op = node->Op();
281281
ASSERT_TRUE(op->HasAttr("use_mkldnn"));
282282
EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
283-
ASSERT_TRUE(op->HasAttr("activation_type"));
283+
ASSERT_TRUE(op->HasAttr("fuse_activation"));
284284
auto act_type =
285-
PADDLE_GET_CONST(std::string, op->GetAttr("activation_type"));
285+
PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
286286
EXPECT_EQ(act_type.compare("hard_swish"), 0);
287287
}
288288
}

0 commit comments

Comments
 (0)