Commit 71df906

Cleanup finetune example code (some).

1 parent 6a18c55

File tree

1 file changed: +113, -125 lines changed

  • examples/resnet-finetune/src

examples/resnet-finetune/src/main.rs

Lines changed: 113 additions & 125 deletions
@@ -9,7 +9,7 @@ use crate::data::{ClassificationBatch, ClassificationBatcher};
 use crate::dataset::{CLASSES, PlanetLoader, download};
 use bimm::cache::disk::DiskCacheConfig;
 use bimm::compat::activation_wrapper::ActivationConfig;
-use bimm::models::resnet::{PREFAB_RESNET_MAP, ResNet};
+use bimm::models::resnet::{PREFAB_RESNET_MAP, ResNet, ResNetContractConfig};
 use burn::backend::{Autodiff, Cuda};
 use burn::config::Config;
 use burn::data::dataloader::DataLoaderBuilder;
@@ -62,6 +62,14 @@ $ --drop-path-prob=0.15 --drop-block-prob=0.25 --learning-rate=1e-5
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 pub struct Args {
+    /// Random seed for reproducibility.
+    #[arg(short, long, default_value = "0")]
+    seed: u64,
+
+    /// Train percentage.
+    #[arg(long, default_value = "70")]
+    pub train_percentage: u8,
+
     /// Directory to save the artifacts.
     #[arg(long, default_value = "/tmp/resnet-finetune")]
     artifact_dir: String,
@@ -72,7 +80,7 @@ pub struct Args {
 
     /// Number of workers for data loading.
     #[arg(long, default_value = "2")]
-    num_workers: Option<usize>,
+    num_workers: usize,
 
     /// Number of epochs to train the model.
     #[arg(long, default_value = "60")]
@@ -101,10 +109,11 @@ pub struct Args {
     /// Early stopping patience
     #[arg(long, default_value = "10")]
     patience: usize,
-}
 
-#[allow(dead_code)]
-const ARTIFACT_DIR: &str = "/tmp/resnet-finetune";
+    /// Optimizer Weight decay.
+    #[arg(long, default_value_t = 5e-4)]
+    pub weight_decay: f32,
+}
 
 fn main() {
     let args = Args::parse();
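
A note on the reworked Args: with clap's derive API, a field that carries a default_value is always populated, so num_workers can drop its Option wrapper, and default_value_t = 5e-4 supplies a typed default instead of a string that must round-trip through FromStr. A minimal, self-contained sketch of the behavior (the Demo struct is hypothetical, mirroring the new fields):

use clap::Parser;

// Hypothetical mirror of the new Args fields, for illustration only.
#[derive(Parser, Debug)]
struct Demo {
    /// Random seed for reproducibility.
    #[arg(short, long, default_value = "0")]
    seed: u64,

    /// Train percentage.
    #[arg(long, default_value = "70")]
    train_percentage: u8,

    /// Number of workers for data loading.
    #[arg(long, default_value = "2")]
    num_workers: usize,

    /// Optimizer weight decay.
    #[arg(long, default_value_t = 5e-4)]
    weight_decay: f32,
}

fn main() {
    // No flags given: every field falls back to its default; nothing is Option.
    let d = Demo::parse_from(["demo"]);
    assert_eq!((d.seed, d.train_percentage, d.num_workers), (0, 70, 2));
    assert_eq!(d.weight_decay, 5e-4);

    // u8 parsing rejects out-of-range percentages at the CLI boundary.
    assert!(Demo::try_parse_from(["demo", "--train-percentage", "300"]).is_err());
}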
@@ -115,101 +124,16 @@ fn main() {
     train::<Autodiff<Cuda>>(&args, &device);
 }
 
-#[allow(dead_code)]
-fn run<B: Backend>(
-    args: &Args,
-    device: &B::Device,
-) {
-    train::<Autodiff<B>>(args, device);
-    // infer::<B>(ARTIFACT_DIR, device, 0.5);
-}
-
-pub trait MultiLabelClassification<B: Backend> {
-    fn forward_classification(
-        &self,
-        images: Tensor<B, 4>,
-        targets: Tensor<B, 2, Int>,
-    ) -> MultiLabelClassificationOutput<B>;
-}
-
-impl<B: Backend> MultiLabelClassification<B> for ResNet<B> {
-    fn forward_classification(
-        &self,
-        images: Tensor<B, 4>,
-        targets: Tensor<B, 2, Int>,
-    ) -> MultiLabelClassificationOutput<B> {
-        let output = self.forward(images);
-        let loss = BinaryCrossEntropyLossConfig::new()
-            .with_logits(true)
-            .init(&output.device())
-            .forward(output.clone(), targets.clone());
-
-        MultiLabelClassificationOutput::new(loss, output, targets)
-    }
-}
-
-impl<B: AutodiffBackend> TrainStep<ClassificationBatch<B>, MultiLabelClassificationOutput<B>>
-    for ResNet<B>
-{
-    fn step(
-        &self,
-        batch: ClassificationBatch<B>,
-    ) -> TrainOutput<MultiLabelClassificationOutput<B>> {
-        let item = self.forward_classification(batch.images, batch.targets);
-
-        TrainOutput::new(self, item.loss.backward(), item)
-    }
-}
-
-impl<B: Backend> ValidStep<ClassificationBatch<B>, MultiLabelClassificationOutput<B>>
-    for ResNet<B>
-{
-    fn step(
-        &self,
-        batch: ClassificationBatch<B>,
-    ) -> MultiLabelClassificationOutput<B> {
-        self.forward_classification(batch.images, batch.targets)
-    }
-}
-
-#[derive(Config)]
-pub struct TrainingConfig {
-    #[config(default = 5)]
-    pub num_epochs: usize,
-
-    #[config(default = 24)]
-    pub batch_size: usize,
-
-    #[config(default = 4)]
-    pub num_workers: usize,
-
-    #[config(default = 42)]
-    pub seed: u64,
-
-    #[config(default = 1e-3)]
-    pub learning_rate: f64,
-
-    #[config(default = 5e-5)]
-    pub weight_decay: f32,
-
-    #[config(default = 70)]
-    pub train_percentage: u8,
-
-    pub num_classes: usize,
-}
-
-fn create_artifact_dir(artifact_dir: &str) {
-    // Remove existing artifacts before to get an accurate learner summary
-    std::fs::remove_dir_all(artifact_dir).ok();
-    std::fs::create_dir_all(artifact_dir).ok();
-}
-
 pub fn train<B: AutodiffBackend>(
     args: &Args,
     device: &B::Device,
 ) -> anyhow::Result<()> {
-    let artifact_dir = args.artifact_dir.as_ref();
-    create_artifact_dir(artifact_dir);
+    // Remove existing artifacts before to get an accurate learner summary
+    let artifact_dir: &str = args.artifact_dir.as_ref();
+    std::fs::remove_dir_all(artifact_dir)?;
+    std::fs::create_dir_all(artifact_dir)?;
+
+    B::seed(args.seed);
 
     let disk_cache = DiskCacheConfig::default();
 
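One behavioral consequence of this hunk worth flagging: the old create_artifact_dir swallowed filesystem errors with .ok(), while the new ? operator propagates them through anyhow::Result. std::fs::remove_dir_all returns a NotFound error when the directory does not exist yet, so the very first run against a fresh artifact_dir will now bail where it previously succeeded. A sketch of a middle ground (reset_artifact_dir is a hypothetical helper, not part of the commit) that keeps error propagation but tolerates the missing-directory case:

use std::io::ErrorKind;

// Hypothetical helper: propagate real I/O errors, but treat "nothing to
// remove" as success so a fresh run does not fail.
fn reset_artifact_dir(dir: &str) -> anyhow::Result<()> {
    match std::fs::remove_dir_all(dir) {
        Ok(()) => {}
        Err(e) if e.kind() == ErrorKind::NotFound => {} // first run: dir absent
        Err(e) => return Err(e.into()),
    }
    std::fs::create_dir_all(dir)?;
    Ok(())
}
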
@@ -220,54 +144,70 @@ pub fn train<B: AutodiffBackend>(
         .fetch_weights(&disk_cache)
         .expect("Failed to fetch pretrained weights");
 
-    let resnet_config = prefab
-        .to_config()
-        .with_activation(ActivationConfig::Gelu)
-        .to_structure();
+    let resnet_config = prefab.to_config().with_activation(ActivationConfig::Gelu);
 
     let model: ResNet<B> = resnet_config
+        .clone()
+        .to_structure()
         .init(device)
         .load_pytorch_weights(weights)
         .expect("Failed to load pretrained weights")
         .with_classes(CLASSES.len())
         .with_stochastic_drop_block(args.drop_block_prob)
         .with_stochastic_path_depth(args.drop_path_prob);
 
-    // Config
-    let training_config = TrainingConfig::new(CLASSES.len())
-        .with_learning_rate(args.learning_rate)
-        .with_num_epochs(args.num_epochs)
-        .with_batch_size(args.batch_size);
-
     let optimizer = AdamConfig::new()
-        .with_weight_decay(Some(WeightDecayConfig::new(training_config.weight_decay)))
+        .with_weight_decay(Some(WeightDecayConfig::new(args.weight_decay)))
        .init();
 
-    training_config
-        .save(format!("{artifact_dir}/config.json"))
-        .expect("Config should be saved successfully");
-
-    B::seed(training_config.seed);
+    #[derive(Config)]
+    struct LogConfig {
+        seed: u64,
+        train_percentage: u8,
+        batch_size: usize,
+        num_epochs: usize,
+        resnet_prefab: String,
+        resnet_pretrained: String,
+        drop_block_prob: f64,
+        drop_path_prob: f64,
+        learning_rate: f64,
+        patience: usize,
+        weight_decay: f32,
+        resnet: ResNetContractConfig,
+    }
+    LogConfig {
+        seed: args.seed,
+        train_percentage: args.train_percentage,
+        batch_size: args.batch_size,
+        num_epochs: args.num_epochs,
+        resnet_prefab: args.resnet_prefab.clone(),
+        resnet_pretrained: args.resnet_pretrained.clone(),
+        drop_block_prob: args.drop_block_prob,
+        drop_path_prob: args.drop_path_prob,
+        learning_rate: args.learning_rate,
+        patience: args.patience,
+        weight_decay: args.weight_decay,
+        resnet: resnet_config,
+    }
+    .save(format!("{artifact_dir}/config.json"))
+    .expect("Config should be saved successfully");
 
     // Dataloaders
     let batcher_train = ClassificationBatcher::<B>::new(device.clone());
     let batcher_valid = ClassificationBatcher::<B::InnerBackend>::new(device.clone());
 
-    let (train, valid) = ImageFolderDataset::planet_train_val_split(
-        training_config.train_percentage,
-        training_config.seed,
-    )
-    .unwrap();
+    let (train, valid) =
+        ImageFolderDataset::planet_train_val_split(args.train_percentage, args.seed).unwrap();
 
     let dataloader_train = DataLoaderBuilder::new(batcher_train)
-        .batch_size(training_config.batch_size)
-        .shuffle(training_config.seed)
-        .num_workers(training_config.num_workers)
-        .build(ShuffledDataset::with_seed(train, training_config.seed));
+        .batch_size(args.batch_size)
+        .shuffle(args.seed)
+        .num_workers(args.num_workers)
+        .build(ShuffledDataset::with_seed(train, args.seed));
 
     let dataloader_test = DataLoaderBuilder::new(batcher_valid)
-        .batch_size(training_config.batch_size)
-        .num_workers(training_config.num_workers)
+        .batch_size(args.batch_size)
+        .num_workers(args.num_workers)
         .build(valid);
 
     // Learner config
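
The function-local LogConfig is the pivot of this cleanup: the deleted TrainingConfig kept a second set of hyperparameter defaults that could drift from the CLI flags, whereas now the parsed Args are the single source of truth and LogConfig exists only to snapshot one run into config.json. Keeping resnet_config at the contract level (and calling .clone().to_structure() only at init) is what lets the ResNet configuration ride along in that snapshot. A minimal sketch of the same snapshot pattern without burn's Config derive (RunLog and log_run_settings are hypothetical names, using serde):

use serde::Serialize;

// Hypothetical: serialize a run's settings to JSON, scoped to the one
// function that needs them, so nothing else can treat the log as config.
fn log_run_settings(artifact_dir: &str, seed: u64, batch_size: usize) -> anyhow::Result<()> {
    #[derive(Serialize)]
    struct RunLog {
        seed: u64,
        batch_size: usize,
    }

    let json = serde_json::to_string_pretty(&RunLog { seed, batch_size })?;
    std::fs::write(format!("{artifact_dir}/config.json"), json)?;
    Ok(())
}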
@@ -287,9 +227,9 @@ pub fn train<B: AutodiffBackend>(
             },
         ))
         .devices(vec![device.clone()])
-        .num_epochs(training_config.num_epochs)
+        .num_epochs(args.num_epochs)
         .summary()
-        .build(model, optimizer, training_config.learning_rate);
+        .build(model, optimizer, args.learning_rate);
 
     // Training
     let now = Instant::now();
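
For readers skimming past the learner setup: the patience flag threaded through this config drives early stopping, which halts training once the tracked validation metric has gone that many epochs without improving. A generic sketch of the bookkeeping involved (plain Rust, not burn's MetricEarlyStoppingStrategy API):

// Generic early stopping on a loss-like metric (lower is better).
struct EarlyStopping {
    patience: usize,
    best: f64,
    since_best: usize,
}

impl EarlyStopping {
    fn new(patience: usize) -> Self {
        Self { patience, best: f64::INFINITY, since_best: 0 }
    }

    /// Feed one epoch's validation loss; returns true when training should stop.
    fn should_stop(&mut self, val_loss: f64) -> bool {
        if val_loss < self.best {
            self.best = val_loss;
            self.since_best = 0;
        } else {
            self.since_best += 1;
        }
        self.since_best >= self.patience
    }
}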
@@ -303,3 +243,51 @@ pub fn train<B: AutodiffBackend>(
 
     Ok(())
 }
+
+pub trait MultiLabelClassification<B: Backend> {
+    fn forward_classification(
+        &self,
+        images: Tensor<B, 4>,
+        targets: Tensor<B, 2, Int>,
+    ) -> MultiLabelClassificationOutput<B>;
+}
+
+impl<B: Backend> MultiLabelClassification<B> for ResNet<B> {
+    fn forward_classification(
+        &self,
+        images: Tensor<B, 4>,
+        targets: Tensor<B, 2, Int>,
+    ) -> MultiLabelClassificationOutput<B> {
+        let output = self.forward(images);
+        let loss = BinaryCrossEntropyLossConfig::new()
+            .with_logits(true)
+            .init(&output.device())
+            .forward(output.clone(), targets.clone());
+
+        MultiLabelClassificationOutput::new(loss, output, targets)
+    }
+}
+
+impl<B: AutodiffBackend> TrainStep<ClassificationBatch<B>, MultiLabelClassificationOutput<B>>
+    for ResNet<B>
+{
+    fn step(
+        &self,
+        batch: ClassificationBatch<B>,
+    ) -> TrainOutput<MultiLabelClassificationOutput<B>> {
+        let item = self.forward_classification(batch.images, batch.targets);
+
+        TrainOutput::new(self, item.loss.backward(), item)
+    }
+}
+
+impl<B: Backend> ValidStep<ClassificationBatch<B>, MultiLabelClassificationOutput<B>>
+    for ResNet<B>
+{
+    fn step(
+        &self,
+        batch: ClassificationBatch<B>,
+    ) -> MultiLabelClassificationOutput<B> {
+        self.forward_classification(batch.images, batch.targets)
+    }
+}
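
The relocated trait is unchanged, and it is the multi-label heart of the example: BinaryCrossEntropyLossConfig with with_logits(true) treats each of the CLASSES outputs as an independent sigmoid decision, rather than a softmax over mutually exclusive labels. A dependency-free sketch of the quantity being computed (illustrative only; production implementations fold the sigmoid into the loss in a numerically stable form):

// Mean binary cross-entropy over per-class logits: each class is an
// independent yes/no, which is what makes the task multi-label.
fn bce_with_logits(logits: &[f64], targets: &[f64]) -> f64 {
    assert_eq!(logits.len(), targets.len());
    logits
        .iter()
        .zip(targets)
        .map(|(&z, &y)| {
            let p = 1.0 / (1.0 + (-z).exp()); // sigmoid applied inside the loss
            -(y * p.ln() + (1.0 - y) * (1.0 - p).ln())
        })
        .sum::<f64>()
        / logits.len() as f64
}

fn main() {
    // Three classes: the first and third present, the second absent.
    let loss = bce_with_logits(&[2.0, -1.5, 0.3], &[1.0, 0.0, 1.0]);
    println!("mean BCE: {loss:.4}");
}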
