Skip to content

Commit 1f8b83e

Browse files
authored
[Pass] Fix conv_bn_fuse_pass when the conv weight uses per-layer quantization (#10097)
1 parent 4ee3f6d commit 1f8b83e

File tree

1 file changed: 13 additions and 3 deletions

lite/core/optimizer/mir/fusion/conv_bn_fuser.cc

Lines changed: 13 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -158,14 +158,24 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
158158
auto conv_weight_d = conv_weight_t->mutable_data<int8_t>();
159159
// compute new conv_weight for int8
160160
auto weight_scale = conv_op_desc->GetInputScale(weight_name);
161+
std::vector<float> weight_scale_dup(alpha_tensor.numel(), 0);
162+
if (weight_scale.size() == 1) {
163+
for (int i = 0; i < weight_scale_dup.size(); i++) {
164+
weight_scale_dup[i] = weight_scale[0];
165+
}
166+
} else {
167+
for (int i = 0; i < weight_scale_dup.size(); i++) {
168+
weight_scale_dup[i] = weight_scale[i];
169+
}
170+
}
161171
if (conv_type_ == "conv2d_transpose") {
162172
int cout = conv_weight_t->dims()[1] * groups;
163173
int cin_group = conv_weight_t->dims()[0] / groups;
164174
int c_size = cout * conv_weight_t->dims()[2] * conv_weight_t->dims()[3];
165175
int hw = conv_weight_t->dims()[2] * conv_weight_t->dims()[3];
166176
for (int k = 0; k < cin_group; ++k) {
167177
for (int i = 0; i < cout; ++i) {
168-
weight_scale[i] *= fabsf(alpha_data[i]);
178+
weight_scale_dup[i] *= fabsf(alpha_data[i]);
169179
if (alpha_data[i] < 0.f) {
170180
auto ptr_row = conv_weight_d + k * c_size + i * hw;
171181
for (int j = 0; j < hw; ++j) {
@@ -176,7 +186,7 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
176186
}
177187
} else {
178188
for (int i = 0; i < h; ++i) {
179-
weight_scale[i] *= fabsf(alpha_data[i]);
189+
weight_scale_dup[i] *= fabsf(alpha_data[i]);
180190
if (alpha_data[i] < 0.f) {
181191
auto ptr_row = conv_weight_d + i * w;
182192
for (int j = 0; j < w; ++j) {
@@ -185,7 +195,7 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
185195
}
186196
}
187197
}
188-
conv_op_desc->SetInputScale(weight_name, weight_scale);
198+
conv_op_desc->SetInputScale(weight_name, weight_scale_dup);
189199
} else if (is_weight_quantization) {
190200
std::string scale_name = conv_weight_name + "_quant_scale";
191201
if (conv_op_desc->HasAttr(scale_name)) {

0 commit comments

Comments
 (0)