@@ -52,60 +52,60 @@ tracel-ai/models reference:
5252pub struct Args {
5353 /// Random seed for reproducibility.
5454 #[ arg( short, long, default_value = "0" ) ]
55- seed : u64 ,
55+ pub seed : u64 ,
5656
5757 /// Train percentage.
5858 #[ arg( long, default_value = "70" ) ]
5959 pub train_percentage : u8 ,
6060
6161 /// Directory to save the artifacts.
6262 #[ arg( long, default_value = "/tmp/resnet_finetune" ) ]
63- artifact_dir : String ,
63+ pub artifact_dir : String ,
6464
6565 /// Batch size for processing
6666 #[ arg( short, long, default_value_t = 24 ) ]
67- batch_size : usize ,
67+ pub batch_size : usize ,
6868
6969 /// Grads accumulation size for processing
70- #[ arg( short, long, default_value_t = 24 ) ]
71- grads_accumulation : usize ,
70+ #[ arg( short, long, default_value_t = 8 ) ]
71+ pub grads_accumulation : usize ,
7272
7373 /// Category smoothing factor for training.
7474 #[ arg( long, default_value = "0.1" ) ]
75- smoothing : Option < f32 > ,
75+ pub smoothing : Option < f32 > ,
7676
7777 /// Number of workers for data loading.
7878 #[ arg( long, default_value = "4" ) ]
79- num_workers : usize ,
79+ pub num_workers : usize ,
8080
8181 /// Number of epochs to train the model.
8282 #[ arg( long, default_value = "200" ) ]
83- num_epochs : usize ,
83+ pub num_epochs : usize ,
8484
8585 /// Pretrained Resnet Model.
8686 /// Use "list" to list all available pretrained models.
8787 #[ arg( long, default_value = "resnet50.tv_in1k" ) ]
88- pretrained : String ,
88+ pub pretrained : String ,
8989
9090 /// Freeze the body layers during training.
9191 #[ arg( long, default_value = "false" ) ]
92- freeze_layers : bool ,
92+ pub freeze_layers : bool ,
9393
9494 /// Drop Block Prob
9595 #[ arg( long, default_value = "0.2" ) ]
96- drop_block_prob : f64 ,
96+ pub drop_block_prob : f64 ,
9797
9898 /// Drop Path Prob
9999 #[ arg( long, default_value = "0.05" ) ]
100- stochastic_depth_prob : f64 ,
100+ pub stochastic_depth_prob : f64 ,
101101
102102 /// Learning rate
103- #[ arg( long, default_value_t = 5e-3 ) ]
103+ #[ arg( long, default_value_t = 5e-4 ) ]
104104 pub learning_rate : f64 ,
105105
106106 /// Early stopping patience
107107 #[ arg( long, default_value_t = 20 ) ]
108- patience : usize ,
108+ pub patience : usize ,
109109
110110 /// Enable cautious weight decay.
111111 #[ arg( long, default_value = "false" ) ]
0 commit comments