@@ -90,12 +90,9 @@ pub struct Args {
9090 #[ arg( long, default_value = "60" ) ]
9191 num_epochs : usize ,
9292
93- /// Resnet Model Config
94- #[ arg( long, default_value = "resnet34" ) ]
95- resnet_prefab : String ,
96-
97- /// Resnet Pretrained
98- #[ arg( long, default_value = "tv_in1k" ) ]
93+ /// Pretrained Resnet Model.
94+ /// Use "list" to list all available pretrained models.
95+ #[ arg( long, default_value = "resnet34.tv_in1k" ) ]
9996 resnet_pretrained : String ,
10097
10198 /// Drop Block Prob
@@ -164,6 +161,26 @@ fn main() -> anyhow::Result<()> {
164161pub fn train < B : AutodiffBackend > ( args : & Args ) -> anyhow:: Result < ( ) > {
165162 let device: B :: Device = Default :: default ( ) ;
166163
164+ // TODO: lift to clap parser.
165+ if args. resnet_pretrained == "list" {
166+ println ! ( "Available pretrained models:" ) ;
167+ for prefab in PREFAB_RESNET_MAP . items {
168+ if let Some ( weights) = prefab. weights {
169+ for item in weights. items {
170+ println ! ( "- \" {}.{}\" : {}" , prefab. name, item. name, item. description) ;
171+ }
172+ }
173+ }
174+ return Ok ( ( ) ) ;
175+ }
176+ let [ resnet_prefab, resnet_pretrained] = args
177+ . resnet_pretrained
178+ . splitn ( 2 , "." )
179+ . map ( |s| s. to_string ( ) )
180+ . collect :: < Vec < String > > ( )
181+ . try_into ( )
182+ . unwrap ( ) ;
183+
167184 // Remove existing artifacts before to get an accurate learner summary
168185 let artifact_dir: & str = args. artifact_dir . as_ref ( ) ;
169186 std:: fs:: remove_dir_all ( artifact_dir) ;
@@ -173,10 +190,10 @@ pub fn train<B: AutodiffBackend>(args: &Args) -> anyhow::Result<()> {
173190
174191 let disk_cache = DiskCacheConfig :: default ( ) ;
175192
176- let prefab = PREFAB_RESNET_MAP . expect_lookup_prefab ( & args . resnet_prefab ) ;
193+ let prefab = PREFAB_RESNET_MAP . expect_lookup_prefab ( & resnet_prefab) ;
177194
178195 let weights = prefab
179- . expect_lookup_pretrained_weights ( & args . resnet_pretrained )
196+ . expect_lookup_pretrained_weights ( & resnet_pretrained)
180197 . fetch_weights ( & disk_cache)
181198 . expect ( "Failed to fetch pretrained weights" ) ;
182199
@@ -201,8 +218,8 @@ pub fn train<B: AutodiffBackend>(args: &Args) -> anyhow::Result<()> {
201218 train_percentage : args. train_percentage ,
202219 batch_size : args. batch_size ,
203220 num_epochs : args. num_epochs ,
204- resnet_prefab : args . resnet_prefab . clone ( ) ,
205- resnet_pretrained : args . resnet_pretrained . clone ( ) ,
221+ resnet_prefab : resnet_prefab. clone ( ) ,
222+ resnet_pretrained : resnet_pretrained. clone ( ) ,
206223 drop_block_prob : args. drop_block_prob ,
207224 drop_path_prob : args. drop_path_prob ,
208225 learning_rate : args. learning_rate ,
0 commit comments