Skip to content

Commit c1e0ecc

Browse files
committed
Revisit maxpool specializations
1 parent 8aef628 commit c1e0ecc

File tree

9 files changed

+882
-9112
lines changed

9 files changed

+882
-9112
lines changed

include/api/mli_krn_maxpool_spec_api.h

Lines changed: 56 additions & 152 deletions
Large diffs are not rendered by default.

lib/gen/mli_krn_maxpool_gen.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -47,56 +47,56 @@
4747

4848
corefunc = "maxpool_chw_nopad"
4949
stride = 1
50+
kernel_range = [2]
51+
channel_range = [0]
52+
f_list.extend([Func(fbase, k, k, ch, stride, stride, corefunc, "nopad") for k in kernel_range for ch in channel_range])
53+
54+
corefunc = "maxpool_chw_nopad"
55+
stride = 0
5056
kernel_range = range(2,11)
51-
channel_range = [0,1,3]
57+
channel_range = [0]
5258
f_list.extend([Func(fbase, k, k, ch, stride, stride, corefunc, "nopad") for k in kernel_range for ch in channel_range])
5359

54-
#stride = 1, 1xk and kx1 versions
60+
#stride = 0, 1xk and kx1 versions
5561
corefunc = "maxpool_chw_nopad"
56-
stride = 1
62+
stride = 0
5763
kernel_range = range(2,4)
58-
channel_range = [0,1]
64+
channel_range = [0]
5965
f_list.extend([Func(fbase, 1, k, ch, stride, stride, corefunc, "nopad") for k in kernel_range for ch in channel_range])
6066
f_list.extend([Func(fbase, k, 1, ch, stride, stride, corefunc, "nopad") for k in kernel_range for ch in channel_range])
6167

6268
corefunc = "maxpool_chw_krnpad_small"
63-
stride = 1
69+
stride = 0
6470
kernel_range = [2, 3]
65-
channel_range = [0,1,3]
71+
channel_range = [0]
6672
f_list.extend([Func(fbase, k, k, ch, stride, stride, corefunc, "krnpad") for k in kernel_range for ch in channel_range])
6773

6874
corefunc = "maxpool_chw_pad"
69-
stride = 1
75+
stride = 0
7076
kernel_range = range(4,11)
71-
channel_range = [0,1,3]
77+
channel_range = [0]
7278
f_list.extend([Func(fbase, k, k, ch, stride, stride, corefunc, "krnpad") for k in kernel_range for ch in channel_range])
7379

74-
#stride = 1, 1xk and kx1 versions
80+
# 1xk and kx1 versions
7581
corefunc = "maxpool_chw_pad"
76-
stride = 1
82+
stride = 0
7783
kernel_range = range(2,4)
78-
channel_range = [0,1]
84+
channel_range = [0]
7985
f_list.extend([Func(fbase, 1, k, ch, stride, stride, corefunc, "krnpad") for k in kernel_range for ch in channel_range])
8086
f_list.extend([Func(fbase, k, 1, ch, stride, stride, corefunc, "krnpad") for k in kernel_range for ch in channel_range])
8187

8288
#fix single dimension, others flex
8389
corefunc = "maxpool_chw_pad"
84-
stride = 1
90+
stride = 0
8591
f_list.extend([Func(fbase, 1, 0, 0, stride, stride, corefunc, "")]) #k_width == 1
8692
f_list.extend([Func(fbase, 0, 1, 0, stride, stride, corefunc, "")]) #k_heigth == 1
87-
f_list.extend([Func(fbase, 0, 0, 1, stride, stride, corefunc, "")]) #channels == 1
8893

8994
corefunc = "maxpool_chw_pad"
9095
stride = 0
9196
kernel_range = [2,3]
92-
channel_range = [0,1]
97+
channel_range = [0]
9398
f_list.extend([Func(fbase, k, k, ch, stride, stride, corefunc, "") for k in kernel_range for ch in channel_range])
9499

95-
corefunc = "maxpool_chw_krnpad_small"
96-
stride = 0
97-
kernel_range = [2,3]
98-
channel_range = [0]
99-
f_list.extend([Func(fbase, k, k, ch, stride, stride, corefunc, "krnpad") for k in kernel_range for ch in channel_range])
100100

101101
#at last add the generic function that can be used in the else branch in the wrapper.
102102
corefunc = "maxpool_chw_pad"
@@ -115,7 +115,7 @@
115115
c.set_wrapper_variables({'padding_left' : "cfg->padding_left"})
116116
c.set_wrapper_variables({'padding_right' : "cfg->padding_right"})
117117
c.set_wrapper_hierarchy(['stride_w', 'stride_h', 'kernel_w', 'kernel_h', 'channels', 'padding'])
118-
c.set_wrapper_if_tree(True)
118+
c.set_wrapper_if_tree(False)
119119

120120
if "fx16" in sys.argv or no_args:
121121
f = open(output_file_fx16, "wb")

lib/src/kernels/convolution/mli_krn_conv2d_chw_fx16.cc

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3680,13 +3680,19 @@ mli_status mli_krn_conv2d_chw_fx16(
36803680
return mli_krn_conv2d_chw_fx16_ch1_str1(in, weights, bias, cfg, out);
36813681
} else if ((stride_w == 1) && (stride_h == 1)) {
36823682
return mli_krn_conv2d_chw_fx16_str1(in, weights, bias, cfg, out);
3683-
} else if ((kernel_w == 3) && (kernel_h == 3) && (channels == 1)) {
3683+
} else if (
3684+
(kernel_w == 3) && (kernel_h == 3) &&
3685+
(channels == 1) &&
3686+
(padding_top <= 1) && (padding_bot <= 1) && (padding_left <= 1) && (padding_right <= 1)) {
36843687
return mli_krn_conv2d_chw_fx16_k3x3_ch1_krnpad(in, weights, bias, cfg, out);
3685-
} else if ((kernel_w == 3) && (kernel_h == 3)) {
3688+
} else if ((kernel_w == 3) && (kernel_h == 3) && (padding_top <= 1) && (padding_bot <= 1) && (padding_left <= 1) && (padding_right <= 1)) {
36863689
return mli_krn_conv2d_chw_fx16_k3x3_krnpad(in, weights, bias, cfg, out);
3687-
} else if ((kernel_w == 2) && (kernel_h == 2) && (channels == 1)) {
3690+
} else if (
3691+
(kernel_w == 2) && (kernel_h == 2) &&
3692+
(channels == 1) &&
3693+
(padding_top <= 0) && (padding_bot <= 1) && (padding_left <= 0) && (padding_right <= 1)) {
36883694
return mli_krn_conv2d_chw_fx16_k2x2_ch1_krnpad(in, weights, bias, cfg, out);
3689-
} else if ((kernel_w == 2) && (kernel_h == 2)) {
3695+
} else if ((kernel_w == 2) && (kernel_h == 2) && (padding_top <= 0) && (padding_bot <= 1) && (padding_left <= 0) && (padding_right <= 1)) {
36903696
return mli_krn_conv2d_chw_fx16_k2x2_krnpad(in, weights, bias, cfg, out);
36913697
} else if (
36923698
(kernel_w == 1) && (kernel_h == 1) &&

lib/src/kernels/convolution/mli_krn_conv2d_chw_fx8.cc

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3475,13 +3475,19 @@ mli_status mli_krn_conv2d_chw_fx8(
34753475
return mli_krn_conv2d_chw_fx8_ch1_str1(in, weights, bias, cfg, out);
34763476
} else if ((stride_w == 1) && (stride_h == 1)) {
34773477
return mli_krn_conv2d_chw_fx8_str1(in, weights, bias, cfg, out);
3478-
} else if ((kernel_w == 3) && (kernel_h == 3) && (channels == 1)) {
3478+
} else if (
3479+
(kernel_w == 3) && (kernel_h == 3) &&
3480+
(channels == 1) &&
3481+
(padding_top <= 1) && (padding_bot <= 1) && (padding_left <= 1) && (padding_right <= 1)) {
34793482
return mli_krn_conv2d_chw_fx8_k3x3_ch1_krnpad(in, weights, bias, cfg, out);
3480-
} else if ((kernel_w == 3) && (kernel_h == 3)) {
3483+
} else if ((kernel_w == 3) && (kernel_h == 3) && (padding_top <= 1) && (padding_bot <= 1) && (padding_left <= 1) && (padding_right <= 1)) {
34813484
return mli_krn_conv2d_chw_fx8_k3x3_krnpad(in, weights, bias, cfg, out);
3482-
} else if ((kernel_w == 2) && (kernel_h == 2) && (channels == 1)) {
3485+
} else if (
3486+
(kernel_w == 2) && (kernel_h == 2) &&
3487+
(channels == 1) &&
3488+
(padding_top <= 0) && (padding_bot <= 1) && (padding_left <= 0) && (padding_right <= 1)) {
34833489
return mli_krn_conv2d_chw_fx8_k2x2_ch1_krnpad(in, weights, bias, cfg, out);
3484-
} else if ((kernel_w == 2) && (kernel_h == 2)) {
3490+
} else if ((kernel_w == 2) && (kernel_h == 2) && (padding_top <= 0) && (padding_bot <= 1) && (padding_left <= 0) && (padding_right <= 1)) {
34853491
return mli_krn_conv2d_chw_fx8_k2x2_krnpad(in, weights, bias, cfg, out);
34863492
} else if (
34873493
(kernel_w == 1) && (kernel_h == 1) &&

lib/src/kernels/convolution/mli_krn_conv2d_chw_fx8w16d.cc

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3475,13 +3475,19 @@ mli_status mli_krn_conv2d_chw_fx8w16d(
34753475
return mli_krn_conv2d_chw_fx8w16d_ch1_str1(in, weights, bias, cfg, out);
34763476
} else if ((stride_w == 1) && (stride_h == 1)) {
34773477
return mli_krn_conv2d_chw_fx8w16d_str1(in, weights, bias, cfg, out);
3478-
} else if ((kernel_w == 3) && (kernel_h == 3) && (channels == 1)) {
3478+
} else if (
3479+
(kernel_w == 3) && (kernel_h == 3) &&
3480+
(channels == 1) &&
3481+
(padding_top <= 1) && (padding_bot <= 1) && (padding_left <= 1) && (padding_right <= 1)) {
34793482
return mli_krn_conv2d_chw_fx8w16d_k3x3_ch1_krnpad(in, weights, bias, cfg, out);
3480-
} else if ((kernel_w == 3) && (kernel_h == 3)) {
3483+
} else if ((kernel_w == 3) && (kernel_h == 3) && (padding_top <= 1) && (padding_bot <= 1) && (padding_left <= 1) && (padding_right <= 1)) {
34813484
return mli_krn_conv2d_chw_fx8w16d_k3x3_krnpad(in, weights, bias, cfg, out);
3482-
} else if ((kernel_w == 2) && (kernel_h == 2) && (channels == 1)) {
3485+
} else if (
3486+
(kernel_w == 2) && (kernel_h == 2) &&
3487+
(channels == 1) &&
3488+
(padding_top <= 0) && (padding_bot <= 1) && (padding_left <= 0) && (padding_right <= 1)) {
34833489
return mli_krn_conv2d_chw_fx8w16d_k2x2_ch1_krnpad(in, weights, bias, cfg, out);
3484-
} else if ((kernel_w == 2) && (kernel_h == 2)) {
3490+
} else if ((kernel_w == 2) && (kernel_h == 2) && (padding_top <= 0) && (padding_bot <= 1) && (padding_left <= 0) && (padding_right <= 1)) {
34853491
return mli_krn_conv2d_chw_fx8w16d_k2x2_krnpad(in, weights, bias, cfg, out);
34863492
} else if (
34873493
(kernel_w == 1) && (kernel_h == 1) &&

lib/src/kernels/pooling/mli_krn_avepool_chw_fx16.cc

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3292,72 +3292,90 @@ mli_status mli_krn_avepool_chw_fx16(const mli_tensor * in, const mli_pool_cfg *
32923292
{
32933293
if (kernel_w == 10) {
32943294
if (kernel_h == 10) {
3295-
{
3295+
if ((padding_top <= 4) && (padding_bot <= 5) && (padding_left <= 4) && (padding_right <= 5)) {
32963296
return mli_krn_avepool_chw_fx16_k10x10_krnpad(in, cfg, out);
3297+
} else {
3298+
return mli_krn_avepool_chw_fx16_generic(in, cfg, out);
32973299
}
32983300
} else {
32993301
return mli_krn_avepool_chw_fx16_generic(in, cfg, out);
33003302
}
33013303
} else if (kernel_w == 9) {
33023304
if (kernel_h == 9) {
3303-
{
3305+
if ((padding_top <= 4) && (padding_bot <= 4) && (padding_left <= 4) && (padding_right <= 4)) {
33043306
return mli_krn_avepool_chw_fx16_k9x9_krnpad(in, cfg, out);
3307+
} else {
3308+
return mli_krn_avepool_chw_fx16_generic(in, cfg, out);
33053309
}
33063310
} else {
33073311
return mli_krn_avepool_chw_fx16_generic(in, cfg, out);
33083312
}
33093313
} else if (kernel_w == 8) {
33103314
if (kernel_h == 8) {
3311-
{
3315+
if ((padding_top <= 3) && (padding_bot <= 4) && (padding_left <= 3) && (padding_right <= 4)) {
33123316
return mli_krn_avepool_chw_fx16_k8x8_krnpad(in, cfg, out);
3317+
} else {
3318+
return mli_krn_avepool_chw_fx16_generic(in, cfg, out);
33133319
}
33143320
} else {
33153321
return mli_krn_avepool_chw_fx16_generic(in, cfg, out);
33163322
}
33173323
} else if (kernel_w == 7) {
33183324
if (kernel_h == 7) {
3319-
{
3325+
if ((padding_top <= 3) && (padding_bot <= 3) && (padding_left <= 3) && (padding_right <= 3)) {
33203326
return mli_krn_avepool_chw_fx16_k7x7_krnpad(in, cfg, out);
3327+
} else {
3328+
return mli_krn_avepool_chw_fx16_generic(in, cfg, out);
33213329
}
33223330
} else {
33233331
return mli_krn_avepool_chw_fx16_generic(in, cfg, out);
33243332
}
33253333
} else if (kernel_w == 6) {
33263334
if (kernel_h == 6) {
3327-
{
3335+
if ((padding_top <= 2) && (padding_bot <= 3) && (padding_left <= 2) && (padding_right <= 3)) {
33283336
return mli_krn_avepool_chw_fx16_k6x6_krnpad(in, cfg, out);
3337+
} else {
3338+
return mli_krn_avepool_chw_fx16_generic(in, cfg, out);
33293339
}
33303340
} else {
33313341
return mli_krn_avepool_chw_fx16_generic(in, cfg, out);
33323342
}
33333343
} else if (kernel_w == 5) {
33343344
if (kernel_h == 5) {
3335-
{
3345+
if ((padding_top <= 2) && (padding_bot <= 2) && (padding_left <= 2) && (padding_right <= 2)) {
33363346
return mli_krn_avepool_chw_fx16_k5x5_krnpad(in, cfg, out);
3347+
} else {
3348+
return mli_krn_avepool_chw_fx16_generic(in, cfg, out);
33373349
}
33383350
} else {
33393351
return mli_krn_avepool_chw_fx16_generic(in, cfg, out);
33403352
}
33413353
} else if (kernel_w == 4) {
33423354
if (kernel_h == 4) {
3343-
{
3355+
if ((padding_top <= 1) && (padding_bot <= 2) && (padding_left <= 1) && (padding_right <= 2)) {
33443356
return mli_krn_avepool_chw_fx16_k4x4_krnpad(in, cfg, out);
3357+
} else {
3358+
return mli_krn_avepool_chw_fx16_generic(in, cfg, out);
33453359
}
33463360
} else {
33473361
return mli_krn_avepool_chw_fx16_generic(in, cfg, out);
33483362
}
33493363
} else if (kernel_w == 3) {
33503364
if (kernel_h == 3) {
3351-
{
3365+
if ((padding_top <= 1) && (padding_bot <= 1) && (padding_left <= 1) && (padding_right <= 1)) {
33523366
return mli_krn_avepool_chw_fx16_k3x3_krnpad(in, cfg, out);
3367+
} else {
3368+
return mli_krn_avepool_chw_fx16_generic(in, cfg, out);
33533369
}
33543370
} else {
33553371
return mli_krn_avepool_chw_fx16_generic(in, cfg, out);
33563372
}
33573373
} else if (kernel_w == 2) {
33583374
if (kernel_h == 2) {
3359-
{
3375+
if ((padding_top <= 0) && (padding_bot <= 1) && (padding_left <= 0) && (padding_right <= 1)) {
33603376
return mli_krn_avepool_chw_fx16_k2x2_krnpad(in, cfg, out);
3377+
} else {
3378+
return mli_krn_avepool_chw_fx16_generic(in, cfg, out);
33613379
}
33623380
} else {
33633381
return mli_krn_avepool_chw_fx16_generic(in, cfg, out);

lib/src/kernels/pooling/mli_krn_avepool_chw_fx8.cc

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3292,72 +3292,90 @@ mli_status mli_krn_avepool_chw_fx8(const mli_tensor * in, const mli_pool_cfg * c
32923292
{
32933293
if (kernel_w == 10) {
32943294
if (kernel_h == 10) {
3295-
{
3295+
if ((padding_top <= 4) && (padding_bot <= 5) && (padding_left <= 4) && (padding_right <= 5)) {
32963296
return mli_krn_avepool_chw_fx8_k10x10_krnpad(in, cfg, out);
3297+
} else {
3298+
return mli_krn_avepool_chw_fx8_generic(in, cfg, out);
32973299
}
32983300
} else {
32993301
return mli_krn_avepool_chw_fx8_generic(in, cfg, out);
33003302
}
33013303
} else if (kernel_w == 9) {
33023304
if (kernel_h == 9) {
3303-
{
3305+
if ((padding_top <= 4) && (padding_bot <= 4) && (padding_left <= 4) && (padding_right <= 4)) {
33043306
return mli_krn_avepool_chw_fx8_k9x9_krnpad(in, cfg, out);
3307+
} else {
3308+
return mli_krn_avepool_chw_fx8_generic(in, cfg, out);
33053309
}
33063310
} else {
33073311
return mli_krn_avepool_chw_fx8_generic(in, cfg, out);
33083312
}
33093313
} else if (kernel_w == 8) {
33103314
if (kernel_h == 8) {
3311-
{
3315+
if ((padding_top <= 3) && (padding_bot <= 4) && (padding_left <= 3) && (padding_right <= 4)) {
33123316
return mli_krn_avepool_chw_fx8_k8x8_krnpad(in, cfg, out);
3317+
} else {
3318+
return mli_krn_avepool_chw_fx8_generic(in, cfg, out);
33133319
}
33143320
} else {
33153321
return mli_krn_avepool_chw_fx8_generic(in, cfg, out);
33163322
}
33173323
} else if (kernel_w == 7) {
33183324
if (kernel_h == 7) {
3319-
{
3325+
if ((padding_top <= 3) && (padding_bot <= 3) && (padding_left <= 3) && (padding_right <= 3)) {
33203326
return mli_krn_avepool_chw_fx8_k7x7_krnpad(in, cfg, out);
3327+
} else {
3328+
return mli_krn_avepool_chw_fx8_generic(in, cfg, out);
33213329
}
33223330
} else {
33233331
return mli_krn_avepool_chw_fx8_generic(in, cfg, out);
33243332
}
33253333
} else if (kernel_w == 6) {
33263334
if (kernel_h == 6) {
3327-
{
3335+
if ((padding_top <= 2) && (padding_bot <= 3) && (padding_left <= 2) && (padding_right <= 3)) {
33283336
return mli_krn_avepool_chw_fx8_k6x6_krnpad(in, cfg, out);
3337+
} else {
3338+
return mli_krn_avepool_chw_fx8_generic(in, cfg, out);
33293339
}
33303340
} else {
33313341
return mli_krn_avepool_chw_fx8_generic(in, cfg, out);
33323342
}
33333343
} else if (kernel_w == 5) {
33343344
if (kernel_h == 5) {
3335-
{
3345+
if ((padding_top <= 2) && (padding_bot <= 2) && (padding_left <= 2) && (padding_right <= 2)) {
33363346
return mli_krn_avepool_chw_fx8_k5x5_krnpad(in, cfg, out);
3347+
} else {
3348+
return mli_krn_avepool_chw_fx8_generic(in, cfg, out);
33373349
}
33383350
} else {
33393351
return mli_krn_avepool_chw_fx8_generic(in, cfg, out);
33403352
}
33413353
} else if (kernel_w == 4) {
33423354
if (kernel_h == 4) {
3343-
{
3355+
if ((padding_top <= 1) && (padding_bot <= 2) && (padding_left <= 1) && (padding_right <= 2)) {
33443356
return mli_krn_avepool_chw_fx8_k4x4_krnpad(in, cfg, out);
3357+
} else {
3358+
return mli_krn_avepool_chw_fx8_generic(in, cfg, out);
33453359
}
33463360
} else {
33473361
return mli_krn_avepool_chw_fx8_generic(in, cfg, out);
33483362
}
33493363
} else if (kernel_w == 3) {
33503364
if (kernel_h == 3) {
3351-
{
3365+
if ((padding_top <= 1) && (padding_bot <= 1) && (padding_left <= 1) && (padding_right <= 1)) {
33523366
return mli_krn_avepool_chw_fx8_k3x3_krnpad(in, cfg, out);
3367+
} else {
3368+
return mli_krn_avepool_chw_fx8_generic(in, cfg, out);
33533369
}
33543370
} else {
33553371
return mli_krn_avepool_chw_fx8_generic(in, cfg, out);
33563372
}
33573373
} else if (kernel_w == 2) {
33583374
if (kernel_h == 2) {
3359-
{
3375+
if ((padding_top <= 0) && (padding_bot <= 1) && (padding_left <= 0) && (padding_right <= 1)) {
33603376
return mli_krn_avepool_chw_fx8_k2x2_krnpad(in, cfg, out);
3377+
} else {
3378+
return mli_krn_avepool_chw_fx8_generic(in, cfg, out);
33613379
}
33623380
} else {
33633381
return mli_krn_avepool_chw_fx8_generic(in, cfg, out);

0 commit comments

Comments
 (0)