Skip to content

Commit 025616d

Browse files
committed
fix specialization for small maxpooling kernels
1 parent fc6a332 commit 025616d

File tree

7 files changed

+423
-88
lines changed

7 files changed

+423
-88
lines changed

examples/example_cifar10_caffe/cifar10_model_chw.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -434,7 +434,7 @@ static void check_result(
434434
//========================================================================================
435435
#if (MODEL_BIT_DEPTH != MODEL_FX_8)
436436
static inline mli_status maxpool_chw(const mli_tensor *in, const mli_pool_cfg *cfg, mli_tensor *out) {
437-
return mli_krn_maxpool_chw_fx16_k3x3(in, cfg, out);
437+
return mli_krn_maxpool_chw_fx16_k3x3_krnpad(in, cfg, out);
438438
}
439439

440440
static inline mli_status avepool_chw(const mli_tensor *in, const mli_pool_cfg *cfg, mli_tensor *out) {
@@ -455,7 +455,7 @@ static inline mli_status mli_krn_permute_fx(const mli_tensor *in, const mli_perm
455455

456456
#else // MODEL_BIT_DEPTH == MODEL_FX_8
457457
static inline mli_status maxpool_chw(const mli_tensor *in, const mli_pool_cfg *cfg, mli_tensor *out) {
458-
return mli_krn_maxpool_chw_fx8_k3x3(in, cfg, out);
458+
return mli_krn_maxpool_chw_fx8_k3x3_krnpad(in, cfg, out);
459459
}
460460

461461
static inline mli_status avepool_chw(const mli_tensor *in, const mli_pool_cfg *cfg, mli_tensor *out) {

include/api/mli_krn_maxpool_spec_api.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ mli_status mli_krn_maxpool_chw_fx16_k2x2(const mli_tensor * in, const mli_pool_c
101101
mli_status mli_krn_maxpool_chw_fx16_k2x2_ch1(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
102102
mli_status mli_krn_maxpool_chw_fx16_k3x3(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
103103
mli_status mli_krn_maxpool_chw_fx16_k3x3_ch1(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
104+
mli_status mli_krn_maxpool_chw_fx16_k2x2_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
105+
mli_status mli_krn_maxpool_chw_fx16_k3x3_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
104106
mli_status mli_krn_maxpool_chw_fx16_generic(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
105107

106108
mli_status mli_krn_maxpool_chw_fx8_k2x2_str1_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
@@ -180,6 +182,8 @@ mli_status mli_krn_maxpool_chw_fx8_k2x2(const mli_tensor * in, const mli_pool_cf
180182
mli_status mli_krn_maxpool_chw_fx8_k2x2_ch1(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
181183
mli_status mli_krn_maxpool_chw_fx8_k3x3(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
182184
mli_status mli_krn_maxpool_chw_fx8_k3x3_ch1(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
185+
mli_status mli_krn_maxpool_chw_fx8_k2x2_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
186+
mli_status mli_krn_maxpool_chw_fx8_k3x3_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
183187
mli_status mli_krn_maxpool_chw_fx8_generic(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
184188

185189
#ifdef __cplusplus

lib/gen/func.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -124,10 +124,15 @@ def print_padding_condition(self, split=False):
124124
else:
125125
cond = "(1)"
126126
elif self.padding == "nopad":
127-
cond = "(padding_top == 0) && "
128-
cond += "(padding_bot == 0) && "
129-
cond += "(padding_left == 0) && "
130-
cond += "(padding_right == 0)"
127+
cond = "(padding_top == 0) && "
128+
cond += "(padding_bot == 0) && "
129+
cond += "(padding_left == 0) && "
130+
cond += "(padding_right == 0)"
131+
elif self.padding == "krnpad" and (self.kernel_h > 0) and (self.kernel_w > 0):
132+
cond = "(padding_top <= " + str(int((self.kernel_h - 1) / 2)) + ") && "
133+
cond += "(padding_bot <= " + str(int(self.kernel_h / 2)) + ") && "
134+
cond += "(padding_left <= " + str(int((self.kernel_w -1) / 2)) + ") && "
135+
cond += "(padding_right <= " + str(int(self.kernel_w / 2)) + ")"
131136
else:
132137
cond = "(1)"
133138
return cond

lib/gen/mli_krn_maxpool_gen.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -53,35 +53,41 @@
5353
channel_range = [0,1,3]
5454
f_list.extend([Func(fbase, k, k, ch, stride, stride, corefunc, "krnpad") for k in kernel_range for ch in channel_range])
5555

56-
corefunc = "maxpool_chw_krnpad"
56+
corefunc = "maxpool_chw_pad"
5757
stride = 1
5858
kernel_range = range(4,11)
5959
channel_range = [0,1,3]
6060
f_list.extend([Func(fbase, k, k, ch, stride, stride, corefunc, "krnpad") for k in kernel_range for ch in channel_range])
6161

6262
#stride = 1, 1xk and kx1 versions
63-
corefunc = "maxpool_chw_krnpad"
63+
corefunc = "maxpool_chw_pad"
6464
stride = 1
6565
kernel_range = range(2,4)
6666
channel_range = [0,1]
6767
f_list.extend([Func(fbase, 1, k, ch, stride, stride, corefunc, "krnpad") for k in kernel_range for ch in channel_range])
6868
f_list.extend([Func(fbase, k, 1, ch, stride, stride, corefunc, "krnpad") for k in kernel_range for ch in channel_range])
6969

7070
#fix single dimension, others flex
71-
corefunc = "maxpool_chw_krnpad"
71+
corefunc = "maxpool_chw_pad"
7272
stride = 1
7373
f_list.extend([Func(fbase, 1, 0, 0, stride, stride, corefunc, "")]) #k_width == 1
7474
f_list.extend([Func(fbase, 0, 1, 0, stride, stride, corefunc, "")]) #k_height == 1
7575
f_list.extend([Func(fbase, 0, 0, 1, stride, stride, corefunc, "")]) #channels == 1
7676

77-
corefunc = "maxpool_chw_krnpad_small"
77+
corefunc = "maxpool_chw_pad"
7878
stride = 0
7979
kernel_range = [2,3]
8080
channel_range = [0,1]
8181
f_list.extend([Func(fbase, k, k, ch, stride, stride, corefunc, "") for k in kernel_range for ch in channel_range])
8282

83+
corefunc = "maxpool_chw_krnpad_small"
84+
stride = 0
85+
kernel_range = [2,3]
86+
channel_range = [0]
87+
f_list.extend([Func(fbase, k, k, ch, stride, stride, corefunc, "krnpad") for k in kernel_range for ch in channel_range])
88+
8389
#at last add the generic function that can be used in the else branch in the wrapper.
84-
corefunc = "maxpool_chw_krnpad"
90+
corefunc = "maxpool_chw_pad"
8591
default_func = Func(fbase, 0, 0, 0, 0, 0, corefunc, generic=True)
8692
f_list.append(default_func)
8793

lib/src/kernels/pooling/mli_krn_maxpool_chw.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -379,7 +379,7 @@ static inline void __attribute__((always_inline)) maxpool_chw_small(
379379
}
380380

381381
template <typename io_T>
382-
static inline void __attribute__((always_inline)) maxpool_chw_krnpad(
382+
static inline void __attribute__((always_inline)) maxpool_chw_pad(
383383
const MLI_PTR(io_T) __restrict in_ftrs,
384384
MLI_OUT_PTR(io_T) __restrict out_ftrs,
385385
const int row_beg,

0 commit comments

Comments
 (0)