Commit ea5f44b
[cherry-pick] Squeeze2 and transpose2 fuse using oneDNN (#47712)
* squeeze2 + transpose2 fuse onednn cherry-pick 2.4
* format
* fix merge
1 parent 34f67a8 commit ea5f44b
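
In effect, the new pass collapses a squeeze2 -> transpose2 chain into a single oneDNN transpose2 op that carries the squeeze axes as an attribute; the kernel later applies the squeeze logically by reshaping its input memory descriptor (see mkldnn_reuse.h below). Schematically:

Before: X -> squeeze2(axes) -> squeeze2_out -> transpose2 -> Out
After:  X -> transpose2(fused_squeeze2_axes = axes) -> Out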

9 files changed: 217 additions, 7 deletions

paddle/fluid/framework/ir/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -219,6 +219,7 @@ if(WITH_MKLDNN)
   pass_library(matmul_elementwise_add_mkldnn_fuse_pass inference DIR mkldnn)
   pass_library(matmul_activation_mkldnn_fuse_pass inference DIR mkldnn)
   pass_library(operator_scale_onednn_fuse_pass inference DIR mkldnn)
+  pass_library(squeeze2_transpose2_onednn_fuse_pass inference DIR mkldnn)
   pass_library(operator_unsqueeze2_onednn_fuse_pass inference DIR mkldnn)
   pass_library(operator_reshape2_onednn_fuse_pass inference DIR mkldnn)
   pass_library(cpu_quantize_placement_pass base DIR mkldnn)

paddle/fluid/framework/ir/graph_pattern_detector.cc

Lines changed: 19 additions & 0 deletions
@@ -1042,6 +1042,25 @@ PDNode *patterns::SeqConvEltAddRelu::operator()(
   return relu_out_var;
 }
 
+PDNode *patterns::Squeeze2Transpose2::operator()() {
+  auto *squeeze2_op_in = pattern->NewNode(squeeze2_op_in_repr())
+                             ->AsInput()
+                             ->assert_is_op_input("squeeze2", "X");
+  auto *squeeze2_op = pattern->NewNode(squeeze2_op_repr())
+                          ->assert_is_op("squeeze2")
+                          ->assert_has_n_outputs(2);
+  auto *squeeze2_op_out = pattern->NewNode(squeeze2_op_out_repr())
+                              ->AsIntermediate()
+                              ->assert_is_op_output("squeeze2", "Out")
+                              ->assert_is_op_input("transpose2", "X");
+  auto *transpose2_op =
+      pattern->NewNode(transpose2_op_repr())->assert_is_op("transpose2");
+
+  squeeze2_op->LinksFrom({squeeze2_op_in}).LinksTo({squeeze2_op_out});
+  transpose2_op->LinksFrom({squeeze2_op_out});
+  return transpose2_op;
+}
+
 PDNode *patterns::FC::operator()(paddle::framework::ir::PDNode *x,
                                  bool with_bias,
                                  bool with_relu) {

paddle/fluid/framework/ir/graph_pattern_detector.h

File mode changed from 100755 to 100644
Lines changed: 20 additions & 0 deletions
@@ -634,6 +634,20 @@ struct FCMKLDNN : public PatternBase {
   PATTERN_DECL_NODE(output);
 };
 
+// Squeeze2 + Transpose2
+// Forward pass
+struct Squeeze2Transpose2 : public PatternBase {
+  Squeeze2Transpose2(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "squeeze2_transpose2") {}
+
+  PDNode* operator()();
+
+  PATTERN_DECL_NODE(squeeze2_op_in);
+  PATTERN_DECL_NODE(squeeze2_op);
+  PATTERN_DECL_NODE(squeeze2_op_out);
+  PATTERN_DECL_NODE(transpose2_op);
+};
+
 // Embedding
 struct Embedding : public PatternBase {
   Embedding(PDPattern* pattern, const std::string& name_scope)

@@ -2002,6 +2016,12 @@ struct AddSupportInt8 : public PatternBase {
   out_var->inputs.clear();     \
   out_var->inputs.push_back(op);
 
+// Set the in_var as the input of the op
+#define IR_VAR_OP_LINK(in_var, op) \
+  in_var->outputs.clear();         \
+  in_var->outputs.push_back(op);   \
+  op->inputs.push_back(in_var);
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
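
The new IR_VAR_OP_LINK macro rewires a variable node so that it feeds a different operator node. A self-contained sketch of its effect, using a stand-in Node with plain pointer vectors (the real paddle::framework::ir::Node carries much more state; this is illustration only):

#include <cassert>
#include <vector>

struct Node {  // minimal stand-in for ir::Node
  std::vector<Node*> inputs;
  std::vector<Node*> outputs;
};

#define IR_VAR_OP_LINK(in_var, op) \
  in_var->outputs.clear();         \
  in_var->outputs.push_back(op);   \
  op->inputs.push_back(in_var);

int main() {
  Node var, squeeze2, transpose2;
  var.outputs.push_back(&squeeze2);  // var originally feeds squeeze2
  Node* in_var = &var;
  Node* op = &transpose2;
  IR_VAR_OP_LINK(in_var, op);  // var now feeds transpose2 instead
  assert(var.outputs.size() == 1 && var.outputs[0] == &transpose2);
  assert(transpose2.inputs.size() == 1 && transpose2.inputs[0] == &var);
  return 0;
}

This is exactly the rewiring the fuse pass performs when it removes squeeze2 and connects its input directly to transpose2.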
paddle/fluid/framework/ir/mkldnn/squeeze2_transpose2_onednn_fuse_pass.cc (new file)

Lines changed: 86 additions & 0 deletions

@@ -0,0 +1,86 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/mkldnn/squeeze2_transpose2_onednn_fuse_pass.h"
+#include "paddle/fluid/framework/op_version_registry.h"
+#include "paddle/fluid/platform/mkldnn_reuse.h"
+#include "paddle/fluid/string/pretty_log.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+using string::PrettyLogDetail;
+
+void FuseSqueeze2Transpose2OneDNNPass::ApplyImpl(Graph *graph) const {
+  PADDLE_ENFORCE_NOT_NULL(graph,
+                          platform::errors::InvalidArgument(
+                              "Pointer to graph argument should not be NULL."));
+
+  FusePassBase::Init("squeeze2_transpose2_onednn_fuse_pass", graph);
+
+  GraphPatternDetector gpd;
+  patterns::Squeeze2Transpose2 squeeze2_transpose2_pattern(
+      gpd.mutable_pattern(), "squeeze2_transpose2_onednn_fuse_pass");
+  squeeze2_transpose2_pattern();
+
+  int found_count = 0;
+
+  auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
+                     Graph *g) {
+    GET_IR_NODE_FROM_SUBGRAPH(
+        squeeze2_op_in, squeeze2_op_in, squeeze2_transpose2_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        squeeze2_op, squeeze2_op, squeeze2_transpose2_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        squeeze2_op_out, squeeze2_op_out, squeeze2_transpose2_pattern);
+    GET_IR_NODE_FROM_SUBGRAPH(
+        transpose2_op, transpose2_op, squeeze2_transpose2_pattern);
+
+    if (!transpose2_op->Op()->HasAttr("use_mkldnn") ||
+        (transpose2_op->Op()->HasAttr("use_mkldnn") &&
+         !(PADDLE_GET_CONST(bool,
+                            transpose2_op->Op()->GetAttr("use_mkldnn"))))) {
+      VLOG(4) << "Only oneDNN version of transpose2 can be fused after with "
+                 "squeeze2.";
+      return;
+    }
+
+    std::vector<int> squeeze2_axes =
+        PADDLE_GET_CONST(std::vector<int>, squeeze2_op->Op()->GetAttr("axes"));
+    transpose2_op->Op()->SetAttr("fused_squeeze2_axes", squeeze2_axes);
+    transpose2_op->Op()->SetInput("X", {squeeze2_op_in->Name()});
+
+    IR_VAR_OP_LINK(squeeze2_op_in, transpose2_op);
+    GraphSafeRemoveNodes(g, {squeeze2_op, squeeze2_op_out});
+    found_count++;
+  };
+
+  gpd(graph, handler);
+  AddStatis(found_count);
+  if ((!Has("disable_logs") || !Get<bool>("disable_logs"))) {
+    PrettyLogDetail("--- fused %d squeeze2 with transpose2", found_count);
+  }
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(squeeze2_transpose2_onednn_fuse_pass,
+              paddle::framework::ir::FuseSqueeze2Transpose2OneDNNPass);
+REGISTER_PASS_CAPABILITY(squeeze2_transpose2_onednn_fuse_pass)
+    .AddCombination(
+        paddle::framework::compatible::OpVersionComparatorCombination()
+            .GE("squeeze2", 0)
+            .GE("transpose2", 0));
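
The pass is registered under the name squeeze2_transpose2_onednn_fuse_pass, so besides running as part of a CPU pass strategy it can be applied to a graph directly. A minimal sketch in the style of Paddle's pass unit tests (assumes graph is a std::unique_ptr<Graph> already built from a ProgramDesc; not code from this commit):

#include "paddle/fluid/framework/ir/pass.h"

using paddle::framework::ir::PassRegistry;

// std::unique_ptr<paddle::framework::ir::Graph> graph = ...;
auto pass = PassRegistry::Instance().Get("squeeze2_transpose2_onednn_fuse_pass");
graph.reset(pass->Apply(graph.release()));  // rewrites matched subgraphs in place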
paddle/fluid/framework/ir/mkldnn/squeeze2_transpose2_onednn_fuse_pass.h (new file)

Lines changed: 35 additions & 0 deletions

@@ -0,0 +1,35 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/graph.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+class FuseSqueeze2Transpose2OneDNNPass : public FusePassBase {
+ public:
+  virtual ~FuseSqueeze2Transpose2OneDNNPass() {}
+
+ protected:
+  void ApplyImpl(Graph *graph) const override;
+};
+
+}  // namespace ir
+}  // namespace framework
+
+}  // namespace paddle

paddle/fluid/inference/api/paddle_pass_builder.cc

Lines changed: 2 additions & 0 deletions
@@ -307,6 +307,7 @@ void CpuPassStrategy::EnableMKLDNN() {
   passes_.insert(passes_.begin(), "mkldnn_placement_pass");
 
   for (auto &pass : std::vector<std::string>({
+           "squeeze2_transpose2_onednn_fuse_pass",
            "depthwise_conv_mkldnn_pass",    //
            "conv_bn_fuse_pass",             // Execute BN passes again to
            "conv_eltwiseadd_bn_fuse_pass",  // preserve correct pass order

@@ -386,6 +387,7 @@ void CpuPassStrategy::EnableMkldnnInt8() {
     passes_.push_back("mkldnn_placement_pass");
     passes_.push_back("simplify_with_basic_ops_pass");
     passes_.push_back("constant_folding_pass");
+    passes_.push_back("squeeze2_transpose2_onednn_fuse_pass");
     passes_.push_back("layer_norm_fuse_pass");
     passes_.push_back("attention_lstm_fuse_pass");
     passes_.push_back("seqconv_eltadd_relu_fuse_pass");
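
With these two registrations, the fusion runs automatically whenever the CPU oneDNN pass strategies are enabled. A minimal sketch of triggering it from the C++ inference API (standard paddle_infer usage rather than code from this commit; the model directory is a placeholder):

#include "paddle_inference_api.h"  // header name in the inference release package

int main() {
  paddle_infer::Config config("./model_dir");  // placeholder path
  config.EnableMKLDNN();  // installs CpuPassStrategy::EnableMKLDNN() passes,
                          // including squeeze2_transpose2_onednn_fuse_pass
  auto predictor = paddle_infer::CreatePredictor(config);
  return 0;
}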

paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc

Lines changed: 3 additions & 0 deletions
@@ -42,6 +42,9 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 
     auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
 
+    platform::SetInMemDescWithLogicalLayoutFusesSupport(
+        ctx, const_cast<phi::DenseTensor*>(x), x->mem_desc());
+
     if (ndims == 1) {
       framework::TensorCopy(*x, x->place(), out);
       out->set_mem_desc(x->mem_desc());
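
The helper called here, SetInMemDescWithLogicalLayoutFusesSupport, is introduced in paddle/fluid/platform/mkldnn_reuse.h below; when fused_squeeze2_axes is present it rewrites the input's memory descriptor to the squeezed shape before the transpose executes.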

paddle/fluid/operators/transpose_op.cc

Lines changed: 4 additions & 2 deletions
@@ -39,11 +39,13 @@ class TransposeOp : public framework::OperatorWithKernel {
     size_t x_rank = x_dims.size();
     size_t axis_size = axis.size();
 
-    PADDLE_ENFORCE_EQ(x_rank,
+    // Note: x_rank > axis_size when fuse squeeze2 + transpose2, else x_rank ==
+    // axis_size
+    PADDLE_ENFORCE_GE(x_rank,
                       axis_size,
                       platform::errors::InvalidArgument(
                           "The input tensor's dimension "
-                          "should be equal to the axis's size. "
+                          "should be equal to or greater than the axis's size. "
                           "But received input tensor's dimension is %d, "
                           "axis's size is %d",
                           x_rank,
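
The relaxation from EQ to GE is needed because the fused transpose2 consumes the original, pre-squeeze tensor while its axis attribute is still sized for the squeezed shape. For example, squeezing axes {2} of a [N, C, 1, W] input yields a rank-3 tensor whose transpose axis has three entries, yet after fusion the op's input keeps rank 4.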

paddle/fluid/platform/mkldnn_reuse.h

Lines changed: 47 additions & 5 deletions
@@ -120,12 +120,10 @@ static void SetOutMemDescWithUnsqueeze2FuseSupport(
   const std::vector<int64_t>& op_tz = out_md.dims();
   std::vector<int64_t> unsqueezed_op_tz(
       op_tz.size() + fused_unsqueeze2_axes.size(), 0);
-
   for (const auto& axis : fused_unsqueeze2_axes) {
     int positive_axis = axis < 0 ? unsqueezed_op_tz.size() + axis : axis;
     unsqueezed_op_tz[positive_axis] = 1;
   }
-
   int j = 0;
   for (size_t i = 0; i < unsqueezed_op_tz.size(); ++i) {
     if (unsqueezed_op_tz[i] == 0) {

@@ -143,20 +141,17 @@ static void SetOutMemDescWithReshape2FuseSupport(
   std::vector<int64_t> fused_reshape2_shape(
       ctx.Attr<std::vector<int>>("fused_reshape2_shape").begin(),
       ctx.Attr<std::vector<int>>("fused_reshape2_shape").end());
-
   const int out_shape_numel = out->numel();
   const int new_shape_numel = std::accumulate(fused_reshape2_shape.begin(),
                                               fused_reshape2_shape.end(),
                                               1,
                                               std::multiplies<int64_t>());
-
   for (size_t i = 0; i < fused_reshape2_shape.size(); ++i) {
     if (fused_reshape2_shape[i] == -1) {
       fused_reshape2_shape[i] = -out_shape_numel / new_shape_numel;
       break;
     }
   }
-
   out->set_mem_desc(out_md.reshape(fused_reshape2_shape));
   out->Resize(phi::make_ddim(fused_reshape2_shape));
 }

@@ -169,11 +164,58 @@ static void SetOutMemDescWithLogicalLayoutFusesSupport(
     SetOutMemDescWithUnsqueeze2FuseSupport(ctx, out, out_md);
   } else if (ctx.HasAttr("fused_reshape2_shape")) {
     SetOutMemDescWithReshape2FuseSupport(ctx, out, out_md);
+  } else if (ctx.HasAttr("fused_squeeze2_axes")) {
+    out->set_mem_desc(out_md);
+    out->Resize(phi::make_ddim(out_md.dims()));
   } else {
     out->set_mem_desc(out_md);
   }
 }
 
+static void SetInMemDescWithSqueeze2FuseSupport(
+    const framework::ExecutionContext& ctx,
+    phi::DenseTensor* in,
+    const dnnl::memory::desc& in_md) {
+  const std::vector<int> fused_squeeze2_axes =
+      ctx.Attr<std::vector<int>>("fused_squeeze2_axes");
+  const std::set<int64_t> squeeze2_axes_set(fused_squeeze2_axes.begin(),
+                                            fused_squeeze2_axes.end());
+  const std::vector<int64_t>& x_vec_dims = in_md.dims();
+  std::vector<int64_t> squeezed_op_tz(
+      x_vec_dims.size() - fused_squeeze2_axes.size(), 0);
+
+  int j = 0;
+  for (size_t i = 0; i < x_vec_dims.size(); ++i) {
+    if (squeeze2_axes_set.count(i) ||
+        squeeze2_axes_set.count(i - x_vec_dims.size())) {
+      PADDLE_ENFORCE_EQ(
+          x_vec_dims[i],
+          1,
+          platform::errors::InvalidArgument(
+              "Squeeze2 input '%d' dim should be equal to one, but get '%d'.",
+              i,
+              x_vec_dims[i]));
+      continue;
+    }
+    squeezed_op_tz[j++] = x_vec_dims[i];
+  }
+
+  in->set_mem_desc(in_md.reshape(squeezed_op_tz));
+  in->Resize(phi::make_ddim(squeezed_op_tz));
+}
+
+static void SetInMemDescWithLogicalLayoutFusesSupport(
+    const framework::ExecutionContext& ctx,
+    phi::DenseTensor* in,
+    const dnnl::memory::desc& in_md) {
+  if (ctx.HasAttr("fused_squeeze2_axes")) {
+    SetInMemDescWithSqueeze2FuseSupport(ctx, in, in_md);
+  } else {
+    in->set_mem_desc(in_md);
+    in->Resize(phi::make_ddim(in_md.dims()));
+  }
+}
+
 template <typename T>
 constexpr bool IsInt8() {
   return std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
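
The dimension arithmetic in SetInMemDescWithSqueeze2FuseSupport can be exercised in isolation. A self-contained sketch of the same squeeze logic (plain C++ with no Paddle types; SqueezeDims is a hypothetical helper name):

#include <cassert>
#include <cstdint>
#include <iostream>
#include <set>
#include <vector>

// Drop the size-1 dims named in `axes`; negative axes count from the back,
// mirroring squeeze2_axes_set.count(i) || squeeze2_axes_set.count(i - rank).
std::vector<int64_t> SqueezeDims(const std::vector<int64_t>& dims,
                                 const std::vector<int>& axes) {
  const std::set<int64_t> axes_set(axes.begin(), axes.end());
  const int64_t rank = static_cast<int64_t>(dims.size());
  std::vector<int64_t> out;
  for (int64_t i = 0; i < rank; ++i) {
    if (axes_set.count(i) || axes_set.count(i - rank)) {
      assert(dims[i] == 1);  // only size-1 dims may be squeezed
      continue;
    }
    out.push_back(dims[i]);
  }
  return out;
}

int main() {
  // [1, 3, 1, 5] squeezed on axes {0, 2} -> prints "3 5"
  for (int64_t d : SqueezeDims({1, 3, 1, 5}, {0, 2})) std::cout << d << ' ';
  std::cout << '\n';
  return 0;
}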
