Skip to content

Commit f7f9cc2

Browse files
[cherry-pick] fix Int8 conv compute error and support conv+hardswish fusion (#7657)
1 parent 2bcccb3 commit f7f9cc2

13 files changed

+2606
-1139
lines changed

lite/backends/arm/math/conv3x3_winograd_int8.cc

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,8 @@ void conv_compute_2x2_3x3_int8(const int8_t* input,
122122
(int32_t*)(g_trans_remain_tmp_data + threads * 128); // NOLINT
123123
auto act_type = act_param.active_type;
124124
int flag_act = 0; // relu: 1, relu6: 2, leakey: 3
125-
float alpha[4] = {0.f, 0.f, 0.f, 0.f};
125+
float alpha[12] = {
126+
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
126127
bool flag_bias = (bias == nullptr) ? false : true;
127128
if (act_param.has_active) {
128129
if (act_type == lite_api::ActivationType::kRelu) {
@@ -141,6 +142,13 @@ void conv_compute_2x2_3x3_int8(const int8_t* input,
141142
alpha[1] = local_alpha;
142143
alpha[2] = local_alpha;
143144
alpha[3] = local_alpha;
145+
} else if (act_type == lite_api::ActivationType::kHardSwish) {
146+
flag_act = 4;
147+
for (int i = 0; i < 4; i++) {
148+
alpha[i] = 1.f / act_param.hard_swish_scale;
149+
alpha[i + 4] = act_param.hard_swish_offset;
150+
alpha[i + 8] = act_param.hard_swish_threshold;
151+
}
144152
}
145153
}
146154
// begin compute
@@ -435,7 +443,8 @@ void conv_compute_4x4_3x3_int8(const int8_t* input,
435443
g_trans_tmp_output_data + threads * 192;
436444
auto act_type = act_param.active_type;
437445
int flag_act = 0; // relu: 1, relu6: 2, leakey: 3
438-
float alpha[4] = {0.f, 0.f, 0.f, 0.f};
446+
float alpha[12] = {
447+
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
439448
bool flag_bias = (bias == nullptr) ? false : true;
440449
if (act_param.has_active) {
441450
if (act_type == lite_api::ActivationType::kRelu) {
@@ -454,6 +463,13 @@ void conv_compute_4x4_3x3_int8(const int8_t* input,
454463
alpha[1] = local_alpha;
455464
alpha[2] = local_alpha;
456465
alpha[3] = local_alpha;
466+
} else if (act_type == lite_api::ActivationType::kHardSwish) {
467+
flag_act = 4;
468+
for (int i = 0; i < 4; i++) {
469+
alpha[i] = 1.f / act_param.hard_swish_scale;
470+
alpha[i + 4] = act_param.hard_swish_offset;
471+
alpha[i + 8] = act_param.hard_swish_threshold;
472+
}
457473
}
458474
}
459475
// begin compute

lite/backends/arm/math/conv3x3s1_direct_int8.cc

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ void conv_3x3s1_direct_int8(const int8_t* din,
4747
auto act_param = param.activation_param;
4848
auto act_type = act_param.active_type;
4949
int flag_act = 0; // relu: 1, relu6: 2, leakey: 3
50-
float alpha[4] = {0.f, 0.f, 0.f, 0.f};
50+
float alpha[12] = {
51+
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
5152
if (act_param.has_active) {
5253
if (act_type == lite_api::ActivationType::kRelu) {
5354
flag_act = 1;
@@ -65,6 +66,13 @@ void conv_3x3s1_direct_int8(const int8_t* din,
6566
alpha[1] = local_alpha;
6667
alpha[2] = local_alpha;
6768
alpha[3] = local_alpha;
69+
} else if (act_type == lite_api::ActivationType::kHardSwish) {
70+
flag_act = 4;
71+
for (int i = 0; i < 4; i++) {
72+
alpha[i] = 1.f / act_param.hard_swish_scale;
73+
alpha[i + 4] = act_param.hard_swish_offset;
74+
alpha[i + 8] = act_param.hard_swish_threshold;
75+
}
6876
}
6977
}
7078
int pad_h = paddings[0];

lite/backends/arm/math/conv3x3s2_direct_int8.cc

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ void conv_3x3s2_direct_int8(const int8_t* din,
5151
auto act_param = param.activation_param;
5252
auto act_type = act_param.active_type;
5353
int flag_act = 0; // relu: 1, relu6: 2, leakey: 3
54-
float alpha[4] = {0.f, 0.f, 0.f, 0.f};
54+
float alpha[12] = {
55+
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
5556
if (act_param.has_active) {
5657
if (act_type == lite_api::ActivationType::kRelu) {
5758
flag_act = 1;
@@ -69,6 +70,13 @@ void conv_3x3s2_direct_int8(const int8_t* din,
6970
alpha[1] = local_alpha;
7071
alpha[2] = local_alpha;
7172
alpha[3] = local_alpha;
73+
} else if (act_type == lite_api::ActivationType::kHardSwish) {
74+
flag_act = 4;
75+
for (int i = 0; i < 4; i++) {
76+
alpha[i] = 1.f / act_param.hard_swish_scale;
77+
alpha[i + 4] = act_param.hard_swish_offset;
78+
alpha[i + 8] = act_param.hard_swish_threshold;
79+
}
7280
}
7381
}
7482
int pad_h = paddings[0];
@@ -503,7 +511,8 @@ void conv_3x3s2_direct_int8(const int8_t* din,
503511
auto act_param = param.activation_param;
504512
auto act_type = act_param.active_type;
505513
int flag_act = 0; // relu: 1, relu6: 2, leakey: 3
506-
float alpha[4] = {0.f, 0.f, 0.f, 0.f};
514+
float alpha[12] = {
515+
0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f};
507516
if (act_param.has_active) {
508517
if (act_type == lite_api::ActivationType::kRelu) {
509518
flag_act = 1;
@@ -521,6 +530,13 @@ void conv_3x3s2_direct_int8(const int8_t* din,
521530
alpha[1] = local_alpha;
522531
alpha[2] = local_alpha;
523532
alpha[3] = local_alpha;
533+
} else if (act_type == lite_api::ActivationType::kHardSwish) {
534+
flag_act = 4;
535+
for (int i = 0; i < 4; i++) {
536+
alpha[i] = 1.f / act_param.hard_swish_scale;
537+
alpha[i + 4] = act_param.hard_swish_offset;
538+
alpha[i + 8] = act_param.hard_swish_threshold;
539+
}
524540
}
525541
}
526542
int pad_h = paddings[0];

0 commit comments

Comments
 (0)