Skip to content

Commit f3a080b

Browse files
committed
feat: make ResNet-finetune args public and update default settings
- Updated all configuration arguments in `Args` struct to `pub` for external accessibility. - Adjusted default values for `grads_accumulation` and `learning_rate` to improve performance.
1 parent 6a2a99f commit f3a080b

File tree

1 file changed

+14
-14
lines changed
  • examples/resnet_finetune/src

1 file changed

+14
-14
lines changed

examples/resnet_finetune/src/main.rs

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -52,60 +52,60 @@ tracel-ai/models reference:
5252
pub struct Args {
5353
/// Random seed for reproducibility.
5454
#[arg(short, long, default_value = "0")]
55-
seed: u64,
55+
pub seed: u64,
5656

5757
/// Train percentage.
5858
#[arg(long, default_value = "70")]
5959
pub train_percentage: u8,
6060

6161
/// Directory to save the artifacts.
6262
#[arg(long, default_value = "/tmp/resnet_finetune")]
63-
artifact_dir: String,
63+
pub artifact_dir: String,
6464

6565
/// Batch size for processing
6666
#[arg(short, long, default_value_t = 24)]
67-
batch_size: usize,
67+
pub batch_size: usize,
6868

6969
/// Grads accumulation size for processing
70-
#[arg(short, long, default_value_t = 24)]
71-
grads_accumulation: usize,
70+
#[arg(short, long, default_value_t = 8)]
71+
pub grads_accumulation: usize,
7272

7373
/// Category smoothing factor for training.
7474
#[arg(long, default_value = "0.1")]
75-
smoothing: Option<f32>,
75+
pub smoothing: Option<f32>,
7676

7777
/// Number of workers for data loading.
7878
#[arg(long, default_value = "4")]
79-
num_workers: usize,
79+
pub num_workers: usize,
8080

8181
/// Number of epochs to train the model.
8282
#[arg(long, default_value = "200")]
83-
num_epochs: usize,
83+
pub num_epochs: usize,
8484

8585
/// Pretrained Resnet Model.
8686
/// Use "list" to list all available pretrained models.
8787
#[arg(long, default_value = "resnet50.tv_in1k")]
88-
pretrained: String,
88+
pub pretrained: String,
8989

9090
/// Freeze the body layers during training.
9191
#[arg(long, default_value = "false")]
92-
freeze_layers: bool,
92+
pub freeze_layers: bool,
9393

9494
/// Drop Block Prob
9595
#[arg(long, default_value = "0.2")]
96-
drop_block_prob: f64,
96+
pub drop_block_prob: f64,
9797

9898
/// Drop Path Prob
9999
#[arg(long, default_value = "0.05")]
100-
stochastic_depth_prob: f64,
100+
pub stochastic_depth_prob: f64,
101101

102102
/// Learning rate
103-
#[arg(long, default_value_t = 5e-3)]
103+
#[arg(long, default_value_t = 5e-4)]
104104
pub learning_rate: f64,
105105

106106
/// Early stopping patience
107107
#[arg(long, default_value_t = 20)]
108-
patience: usize,
108+
pub patience: usize,
109109

110110
/// Enable cautious weight decay.
111111
#[arg(long, default_value = "false")]

0 commit comments

Comments
 (0)