Skip to content

Commit 6bc3164

Browse files
authored
[NNAdapter] Add bn+conv2d, transpose+inverse_transpose and identity transpose fusers (#9690)
1 parent 663a438 commit 6bc3164

File tree

6 files changed

+479
-0
lines changed

6 files changed

+479
-0
lines changed
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once

#include "core/types.h"

namespace nnadapter {

/*
 * Fuse NNADAPTER_BATCH_NORMALIZATION followed by NNADAPTER_CONV_2D into a
 * single NNADAPTER_CONV_2D, by folding the batch normalization's constant
 * scale, bias, mean, variance and epsilon operands into the conv2d's filter
 * and bias.
 *
 *      in                         in
 *       |                          |
 *  batch_norm        ==>        conv2d (rewritten filter/bias)
 *       |                          |
 *     conv2d                      out
 *       |
 *      out
 */
void FuseBatchNormConv2DIntoConv2D(core::Model *model);

}  // namespace nnadapter
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once

#include "core/types.h"

namespace nnadapter {

/*
 * Fold a pair of consecutive transpose operations whose permutations cancel
 * each other into a single identity transpose.
 *
 * Before:
 *
 *        in
 *         |
 *     transpose (perm=[0,2,3,1])
 *         |
 *     transpose (perm=[0,3,1,2])
 *         |
 *        out
 *
 * After:
 *
 *        in
 *         |
 *     transpose (perm=[0,1,2,3])
 *         |
 *        out
 *
 */
void FuseTransposeInverseTransposeIntoIdentityTranspose(core::Model *model);

}  // namespace nnadapter
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once

#include "core/types.h"

namespace nnadapter {

// Remove a transpose operation with an identity permutation, such as
// perm=[0,1], perm=[0,1,2] or perm=[0,1,2,3].
void RemoveIdentityTranspose(core::Model *model);

}  // namespace nnadapter
Lines changed: 209 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,209 @@
1+
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "optimizer/fuse_batch_norm_conv2d_into_conv2d.h"
#include <algorithm>
#include <cmath>
#include <iostream>
#include <map>
#include <vector>
#include "optimizer/pattern_matcher.h"
#include "utility/debug.h"
#include "utility/logging.h"
#include "utility/micros.h"
#include "utility/modeling.h"
#include "utility/utility.h"
26+
27+
namespace nnadapter {
28+
29+
// Matches the subgraph "batch_norm -> conv2d" and folds the batch
// normalization's constant parameters into the conv2d's filter and bias.
class BatchNormConv2DFuser : public PatternMatcher {
 public:
  explicit BatchNormConv2DFuser(NNAdapterOperationType batch_norm_type,
                                NNAdapterOperationType conv2d_type)
      : batch_norm_type_(batch_norm_type), conv2d_type_(conv2d_type) {}
  // Registers the operand/operation patterns of the subgraph to match.
  void BuildPattern() override;
  // Rewrites the conv2d of a matched subgraph; returns true only when the
  // matched intermediate operands and operations should be deleted.
  bool HandleMatchedResults(core::Model* model,
                            const std::map<std::string, Node*>& nodes) override;

 private:
  // Operation types to match; only one pair is currently supported.
  NNAdapterOperationType batch_norm_type_{NNADAPTER_BATCH_NORMALIZATION};
  NNAdapterOperationType conv2d_type_{NNADAPTER_CONV_2D};
};
42+
43+
void BatchNormConv2DFuser::BuildPattern() {
44+
// Operation patterns
45+
auto batch_norm_pattern =
46+
CreatePattern("batch_norm", batch_norm_type_)->IsIntermediate();
47+
auto conv2d_pattern = CreatePattern("conv2d", conv2d_type_);
48+
// Operand patterns
49+
auto batch_norm_input_pattern =
50+
CreatePattern("batch_norm_input")
51+
->IsOperationInputOperand(batch_norm_type_, 0);
52+
auto batch_norm_scale_pattern =
53+
CreatePattern("batch_norm_scale")
54+
->IsOperationInputOperand(batch_norm_type_, 1)
55+
->IsConstantOperand()
56+
->IsIntermediate();
57+
auto batch_norm_bias_pattern =
58+
CreatePattern("batch_norm_bias")
59+
->IsOperationInputOperand(batch_norm_type_, 2)
60+
->IsConstantOperand()
61+
->IsIntermediate();
62+
auto batch_norm_mean_pattern =
63+
CreatePattern("batch_norm_mean")
64+
->IsOperationInputOperand(batch_norm_type_, 3)
65+
->IsConstantOperand()
66+
->IsIntermediate();
67+
auto batch_norm_variance_pattern =
68+
CreatePattern("batch_norm_variance")
69+
->IsOperationInputOperand(batch_norm_type_, 4)
70+
->IsConstantOperand()
71+
->IsIntermediate();
72+
auto batch_norm_epsilon_pattern =
73+
CreatePattern("batch_norm_epsilon")
74+
->IsOperationInputOperand(batch_norm_type_, 5)
75+
->IsConstantOperand()
76+
->IsIntermediate();
77+
auto batch_norm_output_pattern =
78+
CreatePattern("batch_norm_output")
79+
->IsOperationOutputOperand(batch_norm_type_, 0)
80+
->IsOperationInputOperand(conv2d_type_, 0)
81+
->IsIntermediate();
82+
auto conv2d_filter_pattern = CreatePattern("conv2d_filter")
83+
->IsOperationInputOperand(conv2d_type_, 1)
84+
->IsConstantOperand();
85+
auto conv2d_bias_pattern = CreatePattern("conv2d_bias")
86+
->IsOperationInputOperand(conv2d_type_, 2)
87+
->IsConstantOperand();
88+
auto conv2d_output_pattern =
89+
CreatePattern("conv2d_output")->IsOperationOutputOperand(conv2d_type_, 0);
90+
// Create the topological connections for the above patterns
91+
std::vector<Pattern*> batch_norm_input_patterns{batch_norm_input_pattern,
92+
batch_norm_scale_pattern,
93+
batch_norm_bias_pattern,
94+
batch_norm_mean_pattern,
95+
batch_norm_variance_pattern,
96+
batch_norm_epsilon_pattern};
97+
std::vector<Pattern*> conv2d_input_patterns{
98+
batch_norm_output_pattern, conv2d_filter_pattern, conv2d_bias_pattern};
99+
batch_norm_input_patterns >> *batch_norm_pattern >>
100+
*batch_norm_output_pattern;
101+
conv2d_input_patterns >> *conv2d_pattern >> *conv2d_output_pattern;
102+
}
103+
104+
bool BatchNormConv2DFuser::HandleMatchedResults(
    core::Model* model, const std::map<std::string, Node*>& nodes) {
  // Get the operands and operations from the matched subgraph nodes.
  auto batch_norm_operation = nodes.at("batch_norm")->operation;
  auto conv2d_operation = nodes.at("conv2d")->operation;
  auto conv2d_input_operand = conv2d_operation->input_operands[0];
  auto& conv2d_input_type = conv2d_input_operand->type;
  auto conv2d_output_operand = conv2d_operation->output_operands[0];
  auto& conv2d_output_type = conv2d_output_operand->type;
  auto conv2d_filter_operand = conv2d_operation->input_operands[1];
  auto& conv2d_filter_type = conv2d_filter_operand->type;
  auto conv2d_bias_operand = conv2d_operation->input_operands[2];
  // Check the precisions first so that the unsupported quantized case returns
  // before any folding arithmetic is performed.
  if (IsInt8SymmPerLayerQuantType(conv2d_input_type.precision) &&
      (IsInt8SymmPerLayerQuantType(conv2d_filter_type.precision) ||
       IsInt8SymmPerChannelQuantType(conv2d_filter_type.precision)) &&
      IsInt8SymmPerLayerQuantType(conv2d_output_type.precision)) {
    // TODO(hong19860320) Add bn+conv2d fusion for the quantized conv2d
    return false;
  }
  NNADAPTER_CHECK_EQ(conv2d_input_type.precision, NNADAPTER_FLOAT32);
  NNADAPTER_CHECK_EQ(conv2d_filter_type.precision, NNADAPTER_FLOAT32);
  NNADAPTER_CHECK_EQ(conv2d_output_type.precision, NNADAPTER_FLOAT32);
  auto batch_norm_scale_data =
      reinterpret_cast<float*>(batch_norm_operation->input_operands[1]->buffer);
  auto batch_norm_bias_data =
      reinterpret_cast<float*>(batch_norm_operation->input_operands[2]->buffer);
  auto batch_norm_mean_data =
      reinterpret_cast<float*>(batch_norm_operation->input_operands[3]->buffer);
  auto batch_norm_variance_data =
      reinterpret_cast<float*>(batch_norm_operation->input_operands[4]->buffer);
  auto batch_norm_epsilon = *reinterpret_cast<float*>(
      batch_norm_operation->input_operands[5]->buffer);
  auto conv2d_group =
      *reinterpret_cast<int32_t*>(conv2d_operation->input_operands[6]->buffer);
  auto conv2d_input_channel_size = conv2d_input_type.dimensions.data[1];
  NNADAPTER_CHECK_NE(conv2d_input_channel_size, NNADAPTER_UNKNOWN);
  auto conv2d_output_channel_size = conv2d_filter_type.dimensions.data[0];
  auto conv2d_input_channel_group = conv2d_input_channel_size / conv2d_group;
  auto conv2d_output_channel_group = conv2d_output_channel_size / conv2d_group;
  auto conv2d_filter_inner_size = conv2d_filter_type.dimensions.data[2] *
                                  conv2d_filter_type.dimensions.data[3];
  // The formula for BATCH_NORMALIZATION:
  //   output = scale * (input - mean) / sqrt(variance + epsilon) + bias
  // Equivalent to the affine form output = alpha * input + beta, where
  //   alpha = scale / sqrt(variance + epsilon)
  //   beta  = -scale * mean / sqrt(variance + epsilon) + bias
  std::vector<double> batch_norm_alpha(conv2d_input_channel_size),
      batch_norm_beta(conv2d_input_channel_size);
  for (int64_t i = 0; i < conv2d_input_channel_size; i++) {
    double coeff = batch_norm_scale_data[i] /
                   std::sqrt(static_cast<double>(batch_norm_variance_data[i]) +
                             batch_norm_epsilon);
    batch_norm_alpha[i] = coeff;
    batch_norm_beta[i] =
        -batch_norm_mean_data[i] * coeff + batch_norm_bias_data[i];
  }
  // Fold alpha into the filter along the input-channel axis, and accumulate
  // each output channel's contribution of beta into the conv2d bias:
  //   conv(alpha * x + beta) = (W * alpha) conv x + sum(W * beta) + b
  auto conv2d_filter_data =
      reinterpret_cast<float*>(conv2d_filter_operand->buffer);
  auto conv2d_bias_data = reinterpret_cast<float*>(conv2d_bias_operand->buffer);
  for (int64_t g = 0; g < conv2d_group; g++) {
    for (int64_t i = 0; i < conv2d_output_channel_group; i++) {
      float sum = 0.0f;
      for (int64_t j = 0; j < conv2d_input_channel_group; j++) {
        for (int64_t k = 0; k < conv2d_filter_inner_size; k++) {
          // Filter layout: [output_channel][input_channel_per_group][h][w].
          auto offset =
              g * conv2d_output_channel_group * conv2d_input_channel_group *
                  conv2d_filter_inner_size +
              i * conv2d_input_channel_group * conv2d_filter_inner_size +
              j * conv2d_filter_inner_size + k;
          auto value = conv2d_filter_data[offset];
          conv2d_filter_data[offset] =
              value * batch_norm_alpha[g * conv2d_input_channel_group + j];
          sum += value * batch_norm_beta[g * conv2d_input_channel_group + j];
        }
      }
      conv2d_bias_data[g * conv2d_output_channel_group + i] += sum;
    }
  }
  // Replace the input operand of NNADAPTER_CONV_2D with the input operand
  // of NNADAPTER_BATCH_NORMALIZATION, bypassing the fused batch_norm.
  conv2d_operation->input_operands[0] = batch_norm_operation->input_operands[0];
  // The matched intermediate operands and operations will be deleted only when
  // it returns true.
  return true;
}
191+
192+
NNADAPTER_EXPORT void FuseBatchNormConv2DIntoConv2D(core::Model* model) {
  for (auto batch_norm_type : {NNADAPTER_BATCH_NORMALIZATION}) {
    for (auto conv2d_type : {NNADAPTER_CONV_2D}) {
      NNADAPTER_VLOG(5) << "Apply BatchNormConv2DFuser for batch_norm_type:"
                        << OperationTypeToString(batch_norm_type)
                        << " conv2d_type:"
                        << OperationTypeToString(conv2d_type);
      // Keep re-applying the fuser until a pass makes no more matches.
      while (true) {
        BatchNormConv2DFuser batch_norm_conv2d_fuser(batch_norm_type,
                                                     conv2d_type);
        if (batch_norm_conv2d_fuser.Apply(model) == 0) break;
      }
    }
  }
}
208+
209+
} // namespace nnadapter

0 commit comments

Comments
 (0)