Commit 83441bc

feat: add activation function replacement and update ResNet-finetune defaults

- Introduced a `--replace-activation` argument to enable activation function replacement (GELU, PReLU, LeakyReLU) in the ResNet configuration.
- Updated the default values of `num_epochs`, `drop_block_prob`, `learning_rate`, and `weight_decay` for better training stability.
- Added a `patience` argument for specifying the early-stopping tolerance.
- Included the `serde` and `clap` features needed for argument handling and serialization.

1 parent b2f91ff commit 83441bc

File tree

  • examples/resnet_finetune/src/main.rs

1 file changed: 86 additions, 51 deletions
```diff
@@ -16,7 +16,9 @@ use burn::data::dataset::transform::ShuffledDataset;
 use burn::data::dataset::vision::ImageFolderDataset;
 use burn::lr_scheduler::cosine::CosineAnnealingLrSchedulerConfig;
 use burn::module::Module;
+use burn::nn::activation::ActivationConfig;
 use burn::nn::loss::BinaryCrossEntropyLossConfig;
+use burn::nn::{LeakyReluConfig, PReluConfig};
 use burn::optim::AdamWConfig;
 use burn::prelude::{Int, Tensor};
 use burn::record::CompactRecorder;
@@ -33,8 +35,9 @@ use burn::train::{
     LearnerBuilder, LearningStrategy, MetricEarlyStoppingStrategy, MultiLabelClassificationOutput,
     StoppingCondition, TrainOutput, TrainStep, ValidStep,
 };
-use clap::{Parser, arg};
+use clap::{Parser, ValueEnum, arg};
 use core::clone::Clone;
+use serde::{Deserialize, Serialize};
 use std::time::Instant;
 /*
 tracel-ai/models reference:
@@ -46,6 +49,13 @@ tracel-ai/models reference:
 | Valid | Loss | 0.168 | 3 | 0.512 | 1 |
 */
 
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, ValueEnum)]
+pub enum ReplaceActivationOption {
+    Gelu,
+    PRelu,
+    LeakyRelu,
+}
+
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 pub struct Args {
@@ -78,40 +88,44 @@ pub struct Args {
     pub num_workers: usize,
 
     /// Number of epochs to train the model.
-    #[arg(long, default_value = "200")]
+    #[arg(long, default_value = "60")]
     pub num_epochs: usize,
 
+    /// Early stopping patience
+    #[arg(long, default_value_t = 20)]
+    pub patience: usize,
+
     /// Pretrained Resnet Model.
     /// Use "list" to list all available pretrained models.
     #[arg(long, default_value = "resnet50.tv_in1k")]
     pub pretrained: String,
 
+    /// Replace activation function?
+    #[arg(long, default_value = "None")]
+    pub replace_activation: Option<ReplaceActivationOption>,
+
     /// Freeze the body layers during training.
     #[arg(long, default_value = "false")]
     pub freeze_layers: bool,
 
     /// Drop Block Prob
-    #[arg(long, default_value = "0.15")]
+    #[arg(long, default_value = "0.2")]
     pub drop_block_prob: f64,
 
     /// Drop Path Prob
     #[arg(long, default_value = "0.05")]
     pub stochastic_depth_prob: f64,
 
     /// Learning rate
-    #[arg(long, default_value_t = 5e-4)]
+    #[arg(long, default_value_t = 5e-5)]
     pub learning_rate: f64,
 
-    /// Early stopping patience
-    #[arg(long, default_value_t = 20)]
-    pub patience: usize,
-
     /// Enable cautious weight decay.
     #[arg(long, default_value = "false")]
     pub cautious_weight_decay: bool,
 
     /// Optimizer Weight decay.
-    #[arg(long, default_value_t = 5e-3)]
+    #[arg(long, default_value_t = 2e-2)]
     pub weight_decay: f32,
 }
```
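Because `ReplaceActivationOption` derives clap's `ValueEnum`, the accepted values for `--replace-activation` come from the variant names, converted to kebab-case by default: `Gelu` → `gelu`, `PRelu` → `p-relu`, `LeakyRelu` → `leaky-relu`. (One caveat worth double-checking: `default_value = "None"` on an `Option` field is fed through the same value parser when the flag is omitted, and `None` is not a variant; clap already treats `Option<T>` fields as optional without a default.) A minimal sketch of the expected parsing behavior, assuming clap 4's derive defaults; the test name and argv values are illustrative only:

```rust
#[cfg(test)]
mod tests {
    use super::*;
    use clap::Parser;

    #[test]
    fn replace_activation_parses_kebab_case() {
        // "resnet_finetune" stands in for argv[0]; the remaining strings are
        // exactly what a user would type on the command line.
        let args = Args::parse_from([
            "resnet_finetune",
            "--replace-activation",
            "leaky-relu",
            "--patience",
            "10",
        ]);
        assert_eq!(args.replace_activation, Some(ReplaceActivationOption::LeakyRelu));
        assert_eq!(args.patience, 10);
    }
}
```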

```diff
@@ -212,7 +226,23 @@ pub fn train<B: AutodiffBackend>(args: &Args) -> anyhow::Result<()> {
         .fetch_weights(&disk_cache)
         .expect("Failed to fetch pretrained weights");
 
-    let resnet_config = prefab.to_config();
+    let mut resnet_config = prefab.to_config();
+
+    if let Some(option) = &args.replace_activation {
+        match option {
+            ReplaceActivationOption::Gelu => {
+                resnet_config = resnet_config.with_activation(ActivationConfig::Gelu);
+            }
+            ReplaceActivationOption::PRelu => {
+                resnet_config =
+                    resnet_config.with_activation(ActivationConfig::PRelu(PReluConfig::new()));
+            }
+            ReplaceActivationOption::LeakyRelu => {
+                resnet_config = resnet_config
+                    .with_activation(ActivationConfig::LeakyRelu(LeakyReluConfig::new()));
+            }
+        }
+    }
 
     let mut model: ResNet<B> = resnet_config
         .clone()
```
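Every arm of the new `match` funnels into the same `with_activation` call, so an equivalent, slightly more compact form, sketched here with only names already present in the diff, maps the option to an `ActivationConfig` first and applies it once:

```rust
if let Some(option) = &args.replace_activation {
    // Map the CLI choice to the corresponding Burn activation config.
    let activation = match option {
        ReplaceActivationOption::Gelu => ActivationConfig::Gelu,
        ReplaceActivationOption::PRelu => ActivationConfig::PRelu(PReluConfig::new()),
        ReplaceActivationOption::LeakyRelu => {
            ActivationConfig::LeakyRelu(LeakyReluConfig::new())
        }
    };
    // Apply it to the ResNet config in a single place.
    resnet_config = resnet_config.with_activation(activation);
}
```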
```diff
@@ -283,50 +313,55 @@
         .init()
         .expect("Failed to initialize learning rate scheduler");
 
-    // Learner config
-    let learner = LearnerBuilder::new(artifact_dir)
-        .metric_train_numeric(HammingScore::new())
-        .metric_valid_numeric(HammingScore::new())
-        .metric_train_numeric(LossMetric::new())
-        .metric_valid_numeric(LossMetric::new())
-        .metric_train(CudaMetric::new())
-        .metric_valid(CudaMetric::new())
-        .metric_train_numeric(CpuUse::new())
-        .metric_valid_numeric(CpuUse::new())
-        .metric_train_numeric(CpuMemory::new())
-        .metric_valid_numeric(CpuMemory::new())
-        .metric_train_numeric(LearningRateMetric::new())
-        .with_file_checkpointer(CompactRecorder::new())
-        .early_stopping(MetricEarlyStoppingStrategy::new(
-            &LossMetric::<B>::new(),
-            Aggregate::Mean,
-            Direction::Lowest,
-            Split::Valid,
-            StoppingCondition::NoImprovementSince {
-                n_epochs: args.patience,
-            },
-        ))
-        .learning_strategy(LearningStrategy::SingleDevice(device.clone()))
-        .grads_accumulation(args.grads_accumulation)
-        .num_epochs(args.num_epochs)
-        .summary()
-        /*
-        .renderer(CustomRenderer {})
-        .with_application_logger(None)
-        */
-        .build(host, optimizer, lr_scheduler);
-
-    // Training
-    let now = Instant::now();
-    let model_trained = learner.fit(dataloader_train, dataloader_test);
+    let now: Instant;
+    {
+        // Learner config
+        let learner = LearnerBuilder::new(artifact_dir)
+            .metric_train_numeric(HammingScore::new())
+            .metric_valid_numeric(HammingScore::new())
+            .metric_train_numeric(LossMetric::new())
+            .metric_valid_numeric(LossMetric::new())
+            .metric_train(CudaMetric::new())
+            .metric_valid(CudaMetric::new())
+            .metric_train_numeric(CpuUse::new())
+            .metric_valid_numeric(CpuUse::new())
+            .metric_train_numeric(CpuMemory::new())
+            .metric_valid_numeric(CpuMemory::new())
+            .metric_train_numeric(LearningRateMetric::new())
+            .with_file_checkpointer(CompactRecorder::new())
+            .early_stopping(MetricEarlyStoppingStrategy::new(
+                &LossMetric::<B>::new(),
+                Aggregate::Mean,
+                Direction::Lowest,
+                Split::Valid,
+                StoppingCondition::NoImprovementSince {
+                    n_epochs: args.patience,
+                },
+            ))
+            .learning_strategy(LearningStrategy::SingleDevice(device.clone()))
+            .grads_accumulation(args.grads_accumulation)
+            .num_epochs(args.num_epochs)
+            .summary()
+            /*
+            .renderer(CustomRenderer {})
+            .with_application_logger(None)
+            */
+            .build(host, optimizer, lr_scheduler);
+
+        // Training
+        now = Instant::now();
+        let model_trained = learner.fit(dataloader_train, dataloader_test);
+
+        model_trained
+            .model
+            .resnet
+            .save_file(format!("{artifact_dir}/model"), &CompactRecorder::new())
+            .expect("Trained model should be saved successfully");
+    }
     let elapsed = now.elapsed().as_secs();
     println!("Training completed in {}m{}s", (elapsed / 60), elapsed % 60);
 
-    model_trained
-        .model
-        .resnet
-        .save_file(format!("{artifact_dir}/model"), &CompactRecorder::new())
-        .expect("Trained model should be saved successfully");
+    println!("{:#?}", args);
 
     Ok(())
 }
```
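The new block scope around the learner appears intended to drop the learner and the trained model (releasing their resources) before the final timing and `args` printouts, which is why the model save moved inside the block. `now` is declared outside and assigned exactly once inside, so the deferred initialization stays valid Rust. The same pattern in miniature, with a `Vec` standing in for the learner:

```rust
use std::time::Instant;

fn main() {
    let now: Instant; // deferred initialization: assigned exactly once below
    {
        let big = vec![0u8; 1 << 20]; // stand-in for the learner / trained model
        now = Instant::now();
        println!("working with {} bytes", big.len());
    } // `big` is dropped here, before we report the elapsed time
    println!("elapsed: {:?}", now.elapsed());
}
```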
