
Commit 99c872f

FC/matmul(v2) + scale fuse pass (#47420)
1 parent 559b975 commit 99c872f

File tree

8 files changed: +318 -89 lines changed

paddle/fluid/framework/ir/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -218,6 +218,7 @@ if(WITH_MKLDNN)
   pass_library(elt_act_mkldnn_fuse_pass inference DIR mkldnn)
   pass_library(matmul_elementwise_add_mkldnn_fuse_pass inference DIR mkldnn)
   pass_library(matmul_activation_mkldnn_fuse_pass inference DIR mkldnn)
+  pass_library(operator_scale_onednn_fuse_pass inference DIR mkldnn)
   pass_library(cpu_quantize_placement_pass base DIR mkldnn)
   pass_library(cpu_quantize_pass inference DIR mkldnn)
   pass_library(cpu_quantize_squash_pass inference DIR mkldnn)

paddle/fluid/framework/ir/mkldnn/operator_scale_onednn_fuse_pass.cc

Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/mkldnn/operator_scale_onednn_fuse_pass.h"

#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
#include "paddle/fluid/string/pretty_log.h"

namespace paddle {
namespace framework {
namespace ir {

using string::PrettyLogDetail;

void FuseOperatorScaleOneDNNPass::ApplyImpl(Graph *graph) const {
  const std::vector<std::string> fusable_ops{"fc", "matmul", "matmul_v2"};
  for (const auto &op : fusable_ops) FuseScale(graph, op);
}

void FuseOperatorScaleOneDNNPass::FuseScale(Graph *graph,
                                            const std::string &op_type) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
  FusePassBase::Init(op_type + "_scale_onednn_fuse_pass", graph);

  GraphPatternDetector gpd;
  patterns::OperatorActivation op_scale_pattern(
      gpd.mutable_pattern(), op_type + "_scale_onednn_fuse_pass");
  op_scale_pattern(op_type, "scale");

  int found_operator_scale_count = 0;

  auto handler = [&](const GraphPatternDetector::subgraph_t &subgraph,
                     Graph *g) {
    GET_IR_NODE_FROM_SUBGRAPH(operator_op, preceding_op, op_scale_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(operator_out, preceding_op_out, op_scale_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(scale_op, activation, op_scale_pattern);
    GET_IR_NODE_FROM_SUBGRAPH(scale_out, activation_out, op_scale_pattern);

    if (operator_op->Op()->HasAttr("use_mkldnn") &&
        !(PADDLE_GET_CONST(bool, operator_op->Op()->GetAttr("use_mkldnn")))) {
      VLOG(4) << "Only oneDNN version of " << op_type
              << " can be fused with scale.";
      return;
    }

    if (scale_op->Op()->GetAttrIfExists<float>("bias") != 0.0) {
      VLOG(4) << op_type << " can be fused only with unbiased scale.";
      return;
    }

    float scale = PADDLE_GET_CONST(float, scale_op->Op()->GetAttr("scale"));

    auto *scope = param_scope();
    auto const &names = scale_op->Op()->InputNames();
    bool has_scale_tensor =
        std::find(names.begin(), names.end(), "ScaleTensor") != names.end();

    if (has_scale_tensor && scale_op->Op()->Input("ScaleTensor").size() > 0) {
      std::string scale_var_name = scale_op->Op()->Input("ScaleTensor").front();
      auto *scale_var = scope->FindVar(scale_var_name);
      // ScaleTensor must be weight
      if (scale_var == nullptr) return;
      auto *scale_tensor = scale_var->GetMutable<LoDTensor>();
      scale = *(scale_tensor->data<float>());
    }

    operator_op->Op()->SetAttr("fused_output_scale", scale);
    operator_op->Op()->SetOutput("Out", {scale_out->Name()});

    IR_OP_VAR_LINK(operator_op, scale_out);
    GraphSafeRemoveNodes(g, {scale_op, operator_out});
    found_operator_scale_count++;
  };

  gpd(graph, handler);
  AddStatis(found_operator_scale_count);
  if ((!Has("disable_logs") || !Get<bool>("disable_logs")) &&
      found_operator_scale_count > 0)
    PrettyLogDetail(
        "--- fused %d %s with scale", found_operator_scale_count, op_type);
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(operator_scale_onednn_fuse_pass,
              paddle::framework::ir::FuseOperatorScaleOneDNNPass);
REGISTER_PASS_CAPABILITY(operator_scale_onednn_fuse_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination()
            .EQ("fc", 0)
            .LE("matmul", 1)
            .EQ("matmul_v2", 0)
            .EQ("scale", 0));

paddle/fluid/framework/ir/mkldnn/operator_scale_onednn_fuse_pass.h

Lines changed: 36 additions & 0 deletions
@@ -0,0 +1,36 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"

namespace paddle {
namespace framework {
namespace ir {

class FuseOperatorScaleOneDNNPass : public FusePassBase {
 public:
  virtual ~FuseOperatorScaleOneDNNPass() {}

 protected:
  void ApplyImpl(Graph *graph) const override;

  void FuseScale(Graph *graph, const std::string &op_type) const;
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle

paddle/fluid/inference/api/paddle_pass_builder.cc

Lines changed: 2 additions & 0 deletions
@@ -326,6 +326,7 @@ void CpuPassStrategy::EnableMKLDNN() {
           "softplus_activation_mkldnn_fuse_pass",  //
           "shuffle_channel_mkldnn_detect_pass",    //
           "elt_act_mkldnn_fuse_pass",              //
+          "operator_scale_onednn_fuse_pass",       //
           // TODO(intel): Please fix the bug on windows.
           // https://github.com/PaddlePaddle/Paddle/issues/29710
           // "mkldnn_inplace_pass",  // This pass should be activated after

@@ -419,6 +420,7 @@ void CpuPassStrategy::EnableMkldnnInt8() {
   passes_.push_back("scale_matmul_fuse_pass");
   passes_.push_back("reshape_transpose_matmul_mkldnn_fuse_pass");
   passes_.push_back("matmul_elementwise_add_mkldnn_fuse_pass");
+  passes_.push_back("operator_scale_onednn_fuse_pass");
   passes_.push_back("cpu_quantize_placement_pass");
   passes_.push_back("cpu_quantize_pass");
   passes_.push_back("cpu_quantize_squash_pass");
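
Because the pass is registered in CpuPassStrategy, it runs whenever oneDNN inference is enabled; no extra user code is needed. A hedged sketch of how an application would switch that strategy on through the C++ inference API (the model file names below are placeholders):

#include "paddle_inference_api.h"

int main() {
  paddle_infer::Config config;
  // Hypothetical model files; any exported inference model works the same way.
  config.SetModel("model/inference.pdmodel", "model/inference.pdiparams");
  config.DisableGpu();
  config.EnableMKLDNN();  // selects CpuPassStrategy::EnableMKLDNN(), which now
                          // contains operator_scale_onednn_fuse_pass

  auto predictor = paddle_infer::CreatePredictor(config);
  // Feed inputs and call predictor->Run() as usual; fc/matmul(_v2) + scale
  // patterns in the graph are folded before the first run.
  return 0;
}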

paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc

Lines changed: 6 additions & 0 deletions
@@ -533,6 +533,12 @@ class FCPrimitiveFactory {
           scale, dnnl::algorithm::eltwise_hardswish, alpha, beta);
     }

+    if (ctx.HasAttr("fused_output_scale")) {
+      float scale_alpha = ctx.Attr<float>("fused_output_scale");
+      post_operations.append_eltwise(
+          1.0, dnnl::algorithm::eltwise_linear, scale_alpha, 0.0f);
+    }
+
     attributes.set_post_ops(post_operations);
     return attributes;
   }

paddle/fluid/platform/mkldnn_reuse.h

Lines changed: 6 additions & 0 deletions
@@ -250,6 +250,12 @@ class MatMulV2MKLDNNHandler

     AppendActivation(ctx, post_operations);

+    if (ctx.HasAttr("fused_output_scale")) {
+      float scale_alpha = ctx.Attr<float>("fused_output_scale");
+      post_operations.append_eltwise(
+          1.0, dnnl::algorithm::eltwise_linear, scale_alpha, 0.0f);
+    }
+
     matmul_attrs.set_post_ops(post_operations);
     return matmul_attrs;
   }
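
Both the fc and matmul kernels lower fused_output_scale to the same eltwise_linear post-op. The standalone sketch below shows what the resulting primitive computes, assuming the oneDNN 2.x C++ API used in this diff (where append_eltwise takes a leading post-op scale argument); the shapes and the 0.125 scale are made up for illustration:

#include <vector>
#include "dnnl.hpp"

int main() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  dnnl::stream strm(eng);

  using tag = dnnl::memory::format_tag;
  using dt = dnnl::memory::data_type;
  dnnl::memory::desc src_md({2, 3}, dt::f32, tag::ab);
  dnnl::memory::desc wei_md({3, 4}, dt::f32, tag::ab);
  dnnl::memory::desc dst_md({2, 4}, dt::f32, tag::ab);

  // eltwise_linear computes alpha * x + beta; alpha plays the role of
  // fused_output_scale and beta stays 0, mirroring the kernels above.
  const float fused_output_scale = 0.125f;
  dnnl::post_ops ops;
  ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_linear, fused_output_scale, 0.0f);
  dnnl::primitive_attr attr;
  attr.set_post_ops(ops);

  dnnl::matmul::desc desc(src_md, wei_md, dst_md);
  dnnl::matmul::primitive_desc pd(desc, attr, eng);
  dnnl::matmul matmul(pd);

  std::vector<float> src(6, 1.f), wei(12, 1.f), dst(8, 0.f);
  dnnl::memory src_m(src_md, eng, src.data());
  dnnl::memory wei_m(wei_md, eng, wei.data());
  dnnl::memory dst_m(dst_md, eng, dst.data());

  matmul.execute(strm, {{DNNL_ARG_SRC, src_m},
                        {DNNL_ARG_WEIGHTS, wei_m},
                        {DNNL_ARG_DST, dst_m}});
  strm.wait();
  // Each output element is 3 * 0.125 = 0.375: the matmul result already
  // multiplied by the fused scale, with no separate scale kernel.
  return 0;
}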

python/paddle/fluid/tests/unittests/ir/inference/test_mkldnn_matmul_activation_fuse_pass.py

Lines changed: 80 additions & 45 deletions
@@ -21,7 +21,6 @@


 class TestMatmulActivationMkldnnFusePass(PassAutoScanTest):
-
     def sample_program_config(self, draw):
         transpose_X = draw(st.booleans())
         transpose_Y = draw(st.booleans())

@@ -30,11 +29,25 @@ def sample_program_config(self, draw):
         channel = draw(st.sampled_from([8]))
         input_dim = draw(st.sampled_from([32]))
         activation_type = draw(
-            st.sampled_from([
-                'relu', 'gelu', 'swish', 'mish', 'sqrt', 'hard_swish',
-                'sigmoid', 'abs', 'relu6', 'clip', 'tanh', 'hard_sigmoid',
-                'leaky_relu'
-            ]))
+            st.sampled_from(
+                [
+                    'relu',
+                    'gelu',
+                    'swish',
+                    'mish',
+                    'sqrt',
+                    'hard_swish',
+                    'sigmoid',
+                    'abs',
+                    'relu6',
+                    'clip',
+                    'tanh',
+                    'hard_sigmoid',
+                    'leaky_relu',
+                    'scale',
+                ]
+            )
+        )

         def generate_input(type):
             if transpose_X and transpose_Y:

@@ -55,50 +68,60 @@ def generate_input(type):
             else:
                 return np.random.random(shape_y).astype(np.float32)

-        matmul_op = OpConfig(type='matmul',
-                             inputs={
-                                 'X': ['matmul_X'],
-                                 'Y': ['matmul_Y']
-                             },
-                             outputs={'Out': ['matmul_output']},
-                             attrs={
-                                 'transpose_X': transpose_X,
-                                 'transpose_Y': transpose_Y,
-                                 'alpha': alpha
-                             })
+        matmul_op = OpConfig(
+            type='matmul',
+            inputs={'X': ['matmul_X'], 'Y': ['matmul_Y']},
+            outputs={'Out': ['matmul_output']},
+            attrs={
+                'transpose_X': transpose_X,
+                'transpose_Y': transpose_Y,
+                'alpha': alpha,
+                'use_mkldnn': True,
+            },
+        )

         if activation_type == "relu6":
-            activation_op = OpConfig(activation_type,
-                                     inputs={"X": ["matmul_output"]},
-                                     outputs={"Out": ["activation_output"]},
-                                     threshold=draw(
-                                         st.floats(min_value=1.0,
-                                                   max_value=10.0)))
+            activation_op = OpConfig(
+                activation_type,
+                inputs={"X": ["matmul_output"]},
+                outputs={"Out": ["activation_output"]},
+                threshold=draw(st.floats(min_value=1.0, max_value=10.0)),
+            )
         elif activation_type == "leaky_relu":
-            activation_op = OpConfig(activation_type,
-                                     inputs={"X": ["matmul_output"]},
-                                     outputs={"Out": ["activation_output"]},
-                                     alpha=draw(
-                                         st.floats(min_value=0.1,
-                                                   max_value=1.0)))
+            activation_op = OpConfig(
+                activation_type,
+                inputs={"X": ["matmul_output"]},
+                outputs={"Out": ["activation_output"]},
+                alpha=draw(st.floats(min_value=0.1, max_value=1.0)),
+            )
+        elif activation_type == "scale":
+            activation_op = OpConfig(
+                activation_type,
+                inputs={"X": ["matmul_output"]},
+                outputs={"Out": ["activation_output"]},
+                scale=draw(st.sampled_from([0.125, 0.4, 0.875, 2])),
+            )
         elif activation_type == "swish":
-            activation_op = OpConfig(activation_type,
-                                     inputs={"X": ["matmul_output"]},
-                                     outputs={"Out": ["activation_output"]},
-                                     beta=draw(
-                                         st.floats(min_value=0.1,
-                                                   max_value=1.0)))
+            activation_op = OpConfig(
+                activation_type,
+                inputs={"X": ["matmul_output"]},
+                outputs={"Out": ["activation_output"]},
+                beta=draw(st.floats(min_value=0.1, max_value=1.0)),
+            )
         elif activation_type == "clip":
             activation_op = OpConfig(
                 activation_type,
                 inputs={"X": ["matmul_output"]},
                 outputs={"Out": ["activation_output"]},
                 min=draw(st.floats(min_value=0.1, max_value=0.49)),
-                max=draw(st.floats(min_value=0.5, max_value=1.0)))
+                max=draw(st.floats(min_value=0.5, max_value=1.0)),
+            )
         else:
-            activation_op = OpConfig(activation_type,
-                                     inputs={"X": ["matmul_output"]},
-                                     outputs={"Out": ["activation_output"]})
+            activation_op = OpConfig(
+                activation_type,
+                inputs={"X": ["matmul_output"]},
+                outputs={"Out": ["activation_output"]},
+            )

         model_net = [matmul_op, activation_op]

@@ -107,20 +130,32 @@ def generate_input(type):
             weights={},
             inputs={
                 'matmul_X': TensorConfig(data_gen=partial(generate_input, 'x')),
-                'matmul_Y': TensorConfig(data_gen=partial(generate_input, 'y'))
+                'matmul_Y': TensorConfig(data_gen=partial(generate_input, 'y')),
             },
-            outputs=['activation_output'])
+            outputs=['activation_output'],
+        )

         return program_config

     def sample_predictor_configs(self, program_config):
-        config = self.create_inference_config(use_mkldnn=True)
+        config = self.create_inference_config(
+            use_mkldnn=True,
+            passes=[
+                'matmul_activation_mkldnn_fuse_pass',
+                'operator_scale_onednn_fuse_pass',
+            ],
+        )
         yield config, ['matmul'], (1e-5, 1e-5)

     def test(self):
-        self.run_and_statis(quant=False,
-                            max_examples=30,
-                            passes=['matmul_activation_mkldnn_fuse_pass'])
+        self.run_and_statis(
+            quant=False,
+            max_examples=50,
+            passes=[
+                'matmul_activation_mkldnn_fuse_pass',
+                'operator_scale_onednn_fuse_pass',
+            ],
+        )


 if __name__ == '__main__':
