|
22 | 22 | # fill the basic information |
23 | 23 | #------------------------------------------------------------ |
24 | 24 | func_body_template_file_chw = "mli_krn_conv2d_chw_func_body.txt" |
25 | | -func_body_template_file_hwc = "mli_krn_conv2d_hwc_func_body.txt" |
| 25 | +func_body_template_file_hwc = "mli_krn_conv2d_nhwc_func_body.txt" |
26 | 26 | file_template = "filetemplate.txt" |
27 | 27 | file_header_template = "header_filetemplate.txt" |
28 | 28 | function_group = "Convolution 2d" |
|
31 | 31 | output_file_chw_fx16 = "..\..\lib\src\kernels\convolution\mli_krn_conv2d_chw_fx16.cc" |
32 | 32 | output_file_chw_fx8 = "..\..\lib\src\kernels\convolution\mli_krn_conv2d_chw_fx8.cc" |
33 | 33 | output_file_chw_fx8w16d = "..\..\lib\src\kernels\convolution\mli_krn_conv2d_chw_fx8w16d.cc" |
34 | | -output_file_hwc_sa8_sa8_sa32 = "..\..\lib\src\kernels\convolution\mli_krn_conv2d_hwc_sa8_sa8_sa32.cc" |
| 34 | +output_file_nhwc_sa8_sa8_sa32 = "..\..\lib\src\kernels\convolution\mli_krn_conv2d_nhwc_sa8_sa8_sa32.cc" |
35 | 35 |
|
36 | 36 | f_list_chw_fx16 = [] |
37 | | -f_list_hwc_sa8 = [] |
| 37 | +f_list_nhwc_sa8 = [] |
38 | 38 | f_args = [("const mli_tensor *", "in"), |
39 | 39 | ("const mli_tensor *", "weights"), |
40 | 40 | ("const mli_tensor *", "bias"), |
|
235 | 235 | # Create a list of specialization functions for SA8 SA8 SA32 |
236 | 236 | #------------------------------------------------------------ |
237 | 237 |
|
238 | | -fbase = ("krn", "conv2d", "hwc", "sa8_sa8_sa32", f_args) |
| 238 | +fbase = ("krn", "conv2d", "nhwc", "sa8_sa8_sa32", f_args) |
239 | 239 |
|
240 | | -corefunc = "convolution2D_hwc_krnpad" |
| 240 | +corefunc = "convolution2D_nhwc_krnpad" |
241 | 241 | stride = 0 |
242 | 242 | kernel_range = range(3, 6, 2) |
243 | 243 | ch = 0 |
244 | | -f_list_hwc_sa8.extend([Func(fbase, k, k, ch, stride, stride, corefunc, "krnpad") for k in kernel_range]) |
| 244 | +f_list_nhwc_sa8.extend([Func(fbase, k, k, ch, stride, stride, corefunc, "krnpad") for k in kernel_range]) |
245 | 245 |
|
246 | | -corefunc = "convolution2D_hwc_nopad" |
| 246 | +corefunc = "convolution2D_nhwc_nopad" |
247 | 247 | stride = 0 |
248 | 248 | kernel_range = range(3, 6, 2) |
249 | 249 | ch = 0 |
250 | | -f_list_hwc_sa8.extend([Func(fbase, k, k, ch, stride, stride, corefunc, "nopad") for k in kernel_range]) |
| 250 | +f_list_nhwc_sa8.extend([Func(fbase, k, k, ch, stride, stride, corefunc, "nopad") for k in kernel_range]) |
251 | 251 |
|
252 | | -corefunc = "pointwise_convolution2D_hwc_nopad" |
| 252 | +corefunc = "pointwise_convolution2D_nhwc_nopad" |
253 | 253 | stride = 0 |
254 | 254 | k = 1 |
255 | 255 | ch = 0 |
256 | | -f_list_hwc_sa8.extend([Func(fbase, k, k, ch, stride, stride, corefunc, "nopad")]) |
| 256 | +f_list_nhwc_sa8.extend([Func(fbase, k, k, ch, stride, stride, corefunc, "nopad")]) |
257 | 257 |
|
258 | | -corefunc = "convolution2D_hwc_krnpad" |
| 258 | +corefunc = "convolution2D_nhwc_krnpad" |
259 | 259 | default_func_hwc = Func(fbase, 0, 0, 0, 0, 0, corefunc, generic=True) |
260 | | -f_list_hwc_sa8.append(default_func_hwc) |
| 260 | +f_list_nhwc_sa8.append(default_func_hwc) |
261 | 261 |
|
262 | 262 | #------------------------------------------------------------ |
263 | 263 | # Generate the HWC output file |
|
275 | 275 |
|
276 | 276 | if "sa8_sa8_sa32" in sys.argv or no_args: |
277 | 277 | #Create SA8 HWC C output file |
278 | | - f = open(output_file_hwc_sa8_sa8_sa32, "wb") |
279 | | - f.write(c.print_file(f_list_hwc_sa8, default_func_hwc, func_body_template_file_hwc, file_template, include_list_hwc, define_list)) |
| 278 | + f = open(output_file_nhwc_sa8_sa8_sa32, "wb") |
| 279 | + f.write(c.print_file(f_list_nhwc_sa8, default_func_hwc, func_body_template_file_hwc, file_template, include_list_hwc, define_list)) |
280 | 280 | f.close() |
281 | 281 |
|
282 | 282 |
|
|
285 | 285 | #------------------------------------------------------------ |
286 | 286 | if "header" in sys.argv or no_args: |
287 | 287 | fh = open(output_header_file, "wb") |
288 | | - fh.write(c.print_proto_file([f_list_chw_fx16, f_list_chw_fx8, f_list_chw_fx8w16d, f_list_hwc_sa8], function_group, capital_header_file_name, file_header_template)) |
| 288 | + fh.write(c.print_proto_file([f_list_chw_fx16, f_list_chw_fx8, f_list_chw_fx8w16d, f_list_nhwc_sa8], function_group, capital_header_file_name, file_header_template)) |
289 | 289 | fh.close() |
290 | 290 |
|
0 commit comments