
Commit 38eb2ed

update avepool specializations
1 parent c1e0ecc commit 38eb2ed

9 files changed: +726 −2688 lines changed

include/api/mli_krn_avepool_spec_api.h

Lines changed: 46 additions & 66 deletions
@@ -26,84 +26,64 @@ extern "C" {
 
 mli_status mli_krn_avepool_chw_fx16_k2x2_str1_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
 mli_status mli_krn_avepool_chw_fx16_k4x4_str1_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx16_k5x5_str1_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx16_k7x7_str1_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx16_k9x9_str1_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx16_k4x2_str1_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx16_k4x4_str1_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx16_k4x6_str1_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx16_k4x8_str1_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx16_k6x2_str1_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx16_k6x4_str1_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx16_k6x6_str1_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx16_k6x8_str1_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx16_k8x2_str1_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx16_k8x4_str1_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx16_k8x6_str1_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx16_k8x8_str1_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx16_k4x2_str1_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx16_k6x2_str1_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx16_k6x4_str1_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx16_k6x6_str1_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx16_k6x8_str1_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx16_k8x2_str1_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx16_k8x4_str1_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx16_k8x6_str1_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx16_k8x8_str1_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx16_k3x3_str1_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx16_k5x5_str1_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx16_k7x7_str1_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx16_k9x9_str1_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx16_k2x2_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
 mli_status mli_krn_avepool_chw_fx16_k3x3_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx16_k4x4_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
 mli_status mli_krn_avepool_chw_fx16_k5x5_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx16_k6x6_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
 mli_status mli_krn_avepool_chw_fx16_k7x7_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx16_k8x8_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
 mli_status mli_krn_avepool_chw_fx16_k9x9_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx16_k10x10_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
+mli_status mli_krn_avepool_chw_fx16_k2x2_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
+mli_status mli_krn_avepool_chw_fx16_k4x4_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
+mli_status mli_krn_avepool_chw_fx16_k6x6_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
+mli_status mli_krn_avepool_chw_fx16_k8x8_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
+mli_status mli_krn_avepool_chw_fx16_k3x3_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
+mli_status mli_krn_avepool_chw_fx16_k5x5_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
+mli_status mli_krn_avepool_chw_fx16_k7x7_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
+mli_status mli_krn_avepool_chw_fx16_k9x9_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
+mli_status mli_krn_avepool_chw_fx16_k4x4_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
+mli_status mli_krn_avepool_chw_fx16_k6x6_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
+mli_status mli_krn_avepool_chw_fx16_k8x8_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
+mli_status mli_krn_avepool_chw_fx16_k1xn_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
+mli_status mli_krn_avepool_chw_fx16_k1x2_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
+mli_status mli_krn_avepool_chw_fx16_k1x3_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
+mli_status mli_krn_avepool_chw_fx16_knx1_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
+mli_status mli_krn_avepool_chw_fx16_k2x1_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
+mli_status mli_krn_avepool_chw_fx16_k3x1_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
+mli_status mli_krn_avepool_chw_fx16_k1xn_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
+mli_status mli_krn_avepool_chw_fx16_k1x2_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
+mli_status mli_krn_avepool_chw_fx16_k1x3_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
+mli_status mli_krn_avepool_chw_fx16_knx1_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
+mli_status mli_krn_avepool_chw_fx16_k2x1_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
+mli_status mli_krn_avepool_chw_fx16_k3x1_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
 mli_status mli_krn_avepool_chw_fx16_generic(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
 
 mli_status mli_krn_avepool_chw_fx8_k2x2_str1_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
 mli_status mli_krn_avepool_chw_fx8_k4x4_str1_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx8_k5x5_str1_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx8_k7x7_str1_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx8_k9x9_str1_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx8_k4x2_str1_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx8_k4x4_str1_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx8_k4x6_str1_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx8_k4x8_str1_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx8_k6x2_str1_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx8_k6x4_str1_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx8_k6x6_str1_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx8_k6x8_str1_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx8_k8x2_str1_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx8_k8x4_str1_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx8_k8x6_str1_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx8_k8x8_str1_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx8_k4x2_str1_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx8_k6x2_str1_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx8_k6x4_str1_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx8_k6x6_str1_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx8_k6x8_str1_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx8_k8x2_str1_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx8_k8x4_str1_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx8_k8x6_str1_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx8_k8x8_str1_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx8_k3x3_str1_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx8_k5x5_str1_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx8_k7x7_str1_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx8_k9x9_str1_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx8_k2x2_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
 mli_status mli_krn_avepool_chw_fx8_k3x3_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx8_k4x4_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
 mli_status mli_krn_avepool_chw_fx8_k5x5_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx8_k6x6_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
 mli_status mli_krn_avepool_chw_fx8_k7x7_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx8_k8x8_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
 mli_status mli_krn_avepool_chw_fx8_k9x9_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
-mli_status mli_krn_avepool_chw_fx8_k10x10_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
+mli_status mli_krn_avepool_chw_fx8_k2x2_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
+mli_status mli_krn_avepool_chw_fx8_k4x4_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
+mli_status mli_krn_avepool_chw_fx8_k6x6_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
+mli_status mli_krn_avepool_chw_fx8_k8x8_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
+mli_status mli_krn_avepool_chw_fx8_k3x3_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
+mli_status mli_krn_avepool_chw_fx8_k5x5_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
+mli_status mli_krn_avepool_chw_fx8_k7x7_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
+mli_status mli_krn_avepool_chw_fx8_k9x9_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
+mli_status mli_krn_avepool_chw_fx8_k4x4_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
+mli_status mli_krn_avepool_chw_fx8_k6x6_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
+mli_status mli_krn_avepool_chw_fx8_k8x8_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
+mli_status mli_krn_avepool_chw_fx8_k1xn_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
+mli_status mli_krn_avepool_chw_fx8_k1x2_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
+mli_status mli_krn_avepool_chw_fx8_k1x3_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
+mli_status mli_krn_avepool_chw_fx8_knx1_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
+mli_status mli_krn_avepool_chw_fx8_k2x1_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
+mli_status mli_krn_avepool_chw_fx8_k3x1_krnpad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
+mli_status mli_krn_avepool_chw_fx8_k1xn_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
+mli_status mli_krn_avepool_chw_fx8_k1x2_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
+mli_status mli_krn_avepool_chw_fx8_k1x3_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
+mli_status mli_krn_avepool_chw_fx8_knx1_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
+mli_status mli_krn_avepool_chw_fx8_k2x1_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
+mli_status mli_krn_avepool_chw_fx8_k3x1_nopad(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
 mli_status mli_krn_avepool_chw_fx8_generic(const mli_tensor * in, const mli_pool_cfg * cfg, mli_tensor * out);
 
 #ifdef __cplusplus
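Note on the new API surface: the rectangular str1 specializations are dropped in favor of square kernels (k2x2 through k9x9), row kernels (k1xn, k1x2, k1x3), and column kernels (knx1, k2x1, k3x1), each in krnpad and nopad flavors. As a minimal usage sketch, a caller that knows its pooling shape at compile time can invoke a specialization directly; the mli_pool_cfg field names below (kernel_width, stride_width, ...) follow common MLI conventions and are assumptions, since this diff does not show the struct:

/* Sketch: direct call to one specialization, bypassing the generic
 * wrapper. cfg field names are assumed from MLI conventions. */
mli_status avepool_3x3_fx16(const mli_tensor *in, mli_tensor *out) {
    mli_pool_cfg cfg = {0};
    cfg.kernel_width  = 3;
    cfg.kernel_height = 3;
    cfg.stride_width  = 1;
    cfg.stride_height = 1;
    /* all padding_* fields stay 0, as a _nopad variant requires */
    return mli_krn_avepool_chw_fx16_k3x3_nopad(in, &cfg, out);
}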

lib/gen/codegen.py

Lines changed: 2 additions & 2 deletions
@@ -111,13 +111,13 @@ def print_if_list(self, func_list, hierargy_list, default_func):
         for func in sorted_list:
             if func.generic:
                 continue #call to generic should be done at the end
-            string += sep + "if ("
+            string += sep + "if "
             cond = func.print_condition()
             if (len(cond) <= func.max_len_of_line):
                 string += cond
             else:
                 string += func.print_condition(split=True)
-            string += ") {\n"
+            string += " {\n"
             string += " "*8 + "return " + func.print_call()
             string += " "*4 + "}"
             sep = " else "
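With "if (" reduced to "if ", the parentheses around the whole condition now come from print_condition itself (see func.py below), so single-condition checks no longer get a doubled pair. The emitted dispatch keeps the shape sketched here (function names taken from the header above; padding checks omitted; per the source comment, the generic call is appended after the loop):

if ((kernel_w == 3) && (kernel_h == 3)) {
    return mli_krn_avepool_chw_fx16_k3x3_krnpad(in, cfg, out);
} else if ((kernel_w == 5) && (kernel_h == 5)) {
    return mli_krn_avepool_chw_fx16_k5x5_krnpad(in, cfg, out);
} else {
    return mli_krn_avepool_chw_fx16_generic(in, cfg, out);
}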

lib/gen/func.py

Lines changed: 12 additions & 1 deletion
@@ -140,6 +140,7 @@ def print_padding_condition(self, split=False):
     def print_condition(self, split=False):
         indent = ""
         newline = ""
+        cond_count = 0
         if (split):
             indent = " "*12
             newline = "\n"
@@ -148,22 +149,32 @@ def print_condition(self, split=False):
         if (self.stride_w > 0):
             cond += sep + "(stride_w == " + str(self.stride_w) + ")"
             sep = " && "
+            cond_count += 1
         if (self.stride_h > 0):
             cond += sep + "(stride_h == " + str(self.stride_h) + ")"
             sep = " && "
+            cond_count += 1
         if (self.kernel_w > 0):
             cond += sep + newline + indent + "(kernel_w == " + str(self.kernel_w) + ")"
             sep = " && "
+            cond_count += 1
         if (self.kernel_h > 0):
             cond += sep + "(kernel_h == " + str(self.kernel_h) + ")"
             sep = " && "
+            cond_count += 1
         if (self.channels > 0):
             cond += sep + newline + indent + "(channels == " + str(self.channels) + ")"
             sep = " && "
+            cond_count += 1
         if (self.print_padding_condition() != "(1)"):
             #skip padding
             cond += sep + newline + indent + self.print_padding_condition(split=split)
-        return cond
+            cond_count += 1
+
+        if (cond_count > 1):
+            return "(" + cond + ")"
+        else:
+            return cond
 
     def get_types(self):
         if self.datatype == "fx16":
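The cond_count bookkeeping adds the outer parentheses only when two or more sub-conditions were concatenated, so the generated C reads cleanly in both cases. Illustrative output for the two shapes (padding checks omitted):

/* one sub-condition: print_condition returns "(kernel_w == 3)" */
if (kernel_w == 3) { /* ... */ }
/* several: one extra outer pair is added around the whole expression */
if ((stride_w == 1) && (stride_h == 1) && (kernel_w == 5) && (kernel_h == 5)) { /* ... */ }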

lib/gen/mli_krn_avepool_gen.py

Lines changed: 17 additions & 27 deletions
@@ -58,48 +58,38 @@
 ch = 0
 f_list.extend([Func(fbase, k, k, ch, stride, stride, corefunc, "nopad")])
 
-#stride = 1, any kernel size, any channel size
 corefunc = "avepool_chw_krnpad"
-stride = 1
-kernel_range = range(5, 11, 2)
+stride = 0
+kernel_range = range(3, 11, 2)
 ch = 0
 f_list.extend([Func(fbase, k, k, ch, stride, stride, corefunc, "krnpad") for k in kernel_range])
 
 corefunc = "avepool_chw_krnpad_k4_Nx2_N_even"
-stride = 1
-width_range = range(4, 9, 2)
-height_range = range(2, 9, 2)
+stride = 0
+kernel_range = range(2, 9, 2)
 ch = 0
-f_list.extend([Func(fbase, w, h, ch, stride, stride, corefunc, "krnpad") for w in width_range for h in height_range])
+f_list.extend([Func(fbase, k, k, ch, stride, stride, corefunc, "krnpad") for k in kernel_range])
 
-corefunc = "avepool_chw_nopad_k4_Nx2_N_even"
-stride = 1
-w = 4
-h = 2
+corefunc = "avepool_chw_nopad"
+stride = 0
+kernel_range = range(3, 11, 2)
 ch = 0
-f_list.extend([Func(fbase, w, h, ch, stride, stride, corefunc, "nopad")])
+f_list.extend([Func(fbase, k, k, ch, stride, stride, corefunc, "nopad") for k in kernel_range])
 
-#stride = 1, any kernel size, any channel size
 corefunc = "avepool_chw_nopad_k4_Nx2_N_even"
-stride = 1
-width_range = range(6, 9, 2)
-height_range = range(2, 9, 2)
-ch = 0
-f_list.extend([Func(fbase, w, h, ch, stride, stride, corefunc, "nopad") for w in width_range for h in height_range])
-
-#stride = 1, any kernel size, any channel size
-corefunc = "avepool_chw_nopad"
-stride = 1
-kernel_range = range(3, 11, 2)
+stride = 0
+kernel_range = range(4, 9, 2)
 ch = 0
 f_list.extend([Func(fbase, k, k, ch, stride, stride, corefunc, "nopad") for k in kernel_range])
 
-#here construct the specializations for any stride, and multiple kernel sizes > 1
 corefunc = "avepool_chw_krnpad"
 stride = 0
-kernel_range = range(2, 11)
+kernel_range = [0, 2, 3]
 ch = 0
-f_list.extend([Func(fbase, k, k, ch, stride, stride, corefunc, "krnpad") for k in kernel_range])
+f_list.extend([Func(fbase, 1, k, ch, stride, stride, corefunc, "krnpad") for k in kernel_range])
+f_list.extend([Func(fbase, k, 1, ch, stride, stride, corefunc, "krnpad") for k in kernel_range])
+f_list.extend([Func(fbase, 1, k, ch, stride, stride, corefunc, "nopad") for k in kernel_range])
+f_list.extend([Func(fbase, k, 1, ch, stride, stride, corefunc, "nopad") for k in kernel_range])
 
 #at last add the generic function that can be used in the else branch in the wrapper.
 corefunc = "avepool_chw_krnpad"
@@ -117,7 +107,7 @@
 c.set_wrapper_variables({'padding_left' : "cfg->padding_left"})
 c.set_wrapper_variables({'padding_right' : "cfg->padding_right"})
 c.set_wrapper_hierarchy(['stride_w', 'stride_h', 'kernel_w', 'kernel_h', 'padding'])
-c.set_wrapper_if_tree(True)
+c.set_wrapper_if_tree(False)
 
 if "fx16" in sys.argv or no_args:
     f = open(output_file_fx16, "wb")
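A kernel dimension of 0 skips that dimension's check in Func.print_condition, so [0, 2, 3] paired with a fixed 1 yields the k1xn/knx1 (any-length) and k1x2/k2x1/k1x3/k3x1 (fixed-length) variants declared in the header above; set_wrapper_if_tree(False) presumably switches the wrapper from a nested decision tree to a flat if/else chain. A sketch of the resulting row/column dispatch (ordering and omitted padding checks are illustrative; the exact order comes from the generator's sort):

if ((kernel_w == 1) && (kernel_h == 2)) {
    return mli_krn_avepool_chw_fx16_k1x2_krnpad(in, cfg, out);
} else if (kernel_w == 1) {
    /* k1xn: kernel_h was 0, so height stays unconstrained */
    return mli_krn_avepool_chw_fx16_k1xn_krnpad(in, cfg, out);
} else if (kernel_h == 1) {
    /* knx1: width stays unconstrained */
    return mli_krn_avepool_chw_fx16_knx1_krnpad(in, cfg, out);
} else {
    return mli_krn_avepool_chw_fx16_generic(in, cfg, out);
}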
