11from conv_utils import ConvConfig
22
33
4+ def unet_sweep (op : str , input_dtype : str , output_dtype : str ) -> list [ConvConfig ]:
5+ configs = []
6+ for B in [1 , 2 , 4 , 8 ]:
7+ configs .append (ConvConfig (B , 128 , 128 , 16 , 3 , 3 , 320 , 1 , op , input_dtype , output_dtype ))
8+ configs .append (ConvConfig (B , 128 , 128 , 320 , 3 , 3 , 320 , 1 , op , input_dtype , output_dtype ))
9+ configs .append (ConvConfig (B , 64 , 64 , 320 , 3 , 3 , 320 , 2 , op , input_dtype , output_dtype ))
10+ configs .append (ConvConfig (B , 64 , 64 , 320 , 3 , 3 , 640 , 1 , op , input_dtype , output_dtype ))
11+ configs .append (ConvConfig (B , 64 , 64 , 640 , 3 , 3 , 640 , 1 , op , input_dtype , output_dtype ))
12+ configs .append (ConvConfig (B , 64 , 64 , 320 , 1 , 1 , 640 , 1 , op , input_dtype , output_dtype ))
13+ configs .append (ConvConfig (B , 32 , 32 , 640 , 3 , 3 , 640 , 2 , op , input_dtype , output_dtype ))
14+ configs .append (ConvConfig (B , 32 , 32 , 640 , 3 , 3 , 1280 , 1 , op , input_dtype , output_dtype ))
15+ configs .append (ConvConfig (B , 32 , 32 , 1280 , 3 , 3 , 1280 , 1 , op , input_dtype , output_dtype ))
16+ configs .append (ConvConfig (B , 32 , 32 , 640 , 1 , 1 , 1280 , 1 , op , input_dtype , output_dtype ))
17+ configs .append (ConvConfig (B , 32 , 32 , 2560 , 3 , 3 , 1280 , 1 , op , input_dtype , output_dtype ))
18+ configs .append (ConvConfig (B , 32 , 32 , 2560 , 1 , 1 , 1280 , 1 , op , input_dtype , output_dtype ))
19+ configs .append (ConvConfig (B , 32 , 32 , 1920 , 3 , 3 , 1280 , 1 , op , input_dtype , output_dtype ))
20+ configs .append (ConvConfig (B , 32 , 32 , 1920 , 1 , 1 , 1280 , 1 , op , input_dtype , output_dtype ))
21+ configs .append (ConvConfig (B , 64 , 64 , 1280 , 3 , 3 , 1280 , 1 , op , input_dtype , output_dtype ))
22+ configs .append (ConvConfig (B , 64 , 64 , 1920 , 3 , 3 , 640 , 1 , op , input_dtype , output_dtype ))
23+ configs .append (ConvConfig (B , 64 , 64 , 1920 , 1 , 1 , 640 , 1 , op , input_dtype , output_dtype ))
24+ configs .append (ConvConfig (B , 64 , 64 , 1280 , 3 , 3 , 640 , 1 , op , input_dtype , output_dtype ))
25+ configs .append (ConvConfig (B , 64 , 64 , 1280 , 1 , 1 , 640 , 1 , op , input_dtype , output_dtype ))
26+ configs .append (ConvConfig (B , 64 , 64 , 960 , 3 , 3 , 640 , 1 , op , input_dtype , output_dtype ))
27+ configs .append (ConvConfig (B , 64 , 64 , 960 , 1 , 1 , 640 , 1 , op , input_dtype , output_dtype ))
28+ configs .append (ConvConfig (B , 128 , 128 , 640 , 3 , 3 , 640 , 1 , op , input_dtype , output_dtype ))
29+ configs .append (ConvConfig (B , 128 , 128 , 960 , 3 , 3 , 320 , 1 , op , input_dtype , output_dtype ))
30+ configs .append (ConvConfig (B , 128 , 128 , 960 , 1 , 1 , 320 , 1 , op , input_dtype , output_dtype ))
31+ configs .append (ConvConfig (B , 128 , 128 , 640 , 3 , 3 , 320 , 1 , op , input_dtype , output_dtype ))
32+ configs .append (ConvConfig (B , 128 , 128 , 640 , 1 , 1 , 320 , 1 , op , input_dtype , output_dtype ))
33+ configs .append (ConvConfig (B , 128 , 128 , 320 , 3 , 3 , 16 , 1 , op , input_dtype , output_dtype ))
34+ return configs
35+
436def resnet_sweep (op : str , input_dtype : str , output_dtype : str ) -> list [ConvConfig ]:
537 configs = []
6- for B in [1 , 2 , 4 , 8 , 16 , 32 , 48 ]:
38+ for B in [1 , 2 , 4 , 8 ]:
739 configs .append (ConvConfig (B , 112 , 112 , 64 , 7 , 7 , 3 , 2 , op , input_dtype , output_dtype ))
840 configs .append (ConvConfig (B , 56 , 56 , 64 , 3 , 3 , 64 , 1 , op , input_dtype , output_dtype ))
941 configs .append (ConvConfig (B , 28 , 28 , 128 , 3 , 3 , 128 , 2 , op , input_dtype , output_dtype ))
@@ -19,9 +51,40 @@ def resnet_sweep(op: str, input_dtype: str, output_dtype: str) -> list[ConvConfi
1951
2052def get_conv_configs () -> list [tuple [str , ConvConfig ]]:
2153 configs : list [tuple [str , ConvConfig ]] = []
22- resnet_configs = resnet_sweep ("conv_2d_nchw_fchw" , "f32" , "f32" )
23- resnet_configs += resnet_sweep ("conv_2d_nhwc_hwcf_q" , "i8" , "i32" )
2454
25- configs += [("resnet_sweep" , x ) for x in resnet_configs ]
55+ # Resnet
56+ resnet_configs = []
57+ resnet_configs += resnet_sweep ("conv_2d_nhwc_hwcf" , "f16" , "f32" )
58+ resnet_configs += resnet_sweep ("conv_2d_nhwc_hwcf" , "i8" , "i32" )
59+ resnet_configs += resnet_sweep ("conv_2d_nchw_fchw" , "f16" , "f32" )
60+ resnet_configs += resnet_sweep ("conv_2d_nchw_fchw" , "i8" , "i32" )
61+ configs += [("resnet" , x ) for x in resnet_configs ]
62+
63+ # Unet
64+ unet_configs = []
65+ unet_configs += unet_sweep ("conv_2d_nhwc_hwcf" , "f16" , "f32" )
66+ unet_configs += unet_sweep ("conv_2d_nhwc_hwcf" , "i8" , "i32" )
67+ unet_configs += unet_sweep ("conv_2d_nchw_fchw" , "f16" , "f32" )
68+ unet_configs += unet_sweep ("conv_2d_nchw_fchw" , "i8" , "i32" )
69+ configs += [("unet" , x ) for x in unet_configs ]
70+
71+ return configs
72+
73+ # Test function to run only a few chosen shapes
74+ def get_conv_test_configs () -> list [tuple [str , ConvConfig ]]:
75+ configs : list [tuple [str , ConvConfig ]] = []
76+
77+ resnet_configs = []
78+ # resnet_configs += resnet_sweep("conv_2d_nhwc_hwcf", "f16", "f32")
79+ # resnet_configs += resnet_sweep("conv_2d_nhwc_hwcf", "i8", "i32")
80+ # resnet_configs += resnet_sweep("conv_2d_nchw_fchw", "f16", "f32")
81+ # resnet_configs += resnet_sweep("conv_2d_nchw_fchw", "i8", "i32")
82+ configs += [("resnet" , x ) for x in resnet_configs ]
83+
84+ unet_configs = []
85+ # unet_configs.append(ConvConfig(1,128,128,16,3,3,320,1, "conv_2d_nhwc_hwcf_q", "i8", "i32"))
86+ # unet_configs.append(ConvConfig(1,32,32,640,1,1,1280,1, "conv_2d_nhwc_hwcf_q", "i8", "i32"))
87+
88+ configs += [("unet" , x ) for x in unet_configs ]
2689
2790 return configs
0 commit comments