Skip to content

Commit 7f051b3

Browse files
committed
refactor: simplify artifact directory handling and update optimizer configuration
- Added `reset_artifact_dir` helper function in `resnet_finetune` for cleaner artifact management. - Replaced `AdamConfig` with `AdamWConfig` for optimizer setup. - Removed redundant imports and allowed directives for improved code clarity.
1 parent 25b3801 commit 7f051b3

File tree

2 files changed

+13
-11
lines changed

2 files changed

+13
-11
lines changed

crates/bimm/src/models/resnet/bottleneck_block.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -522,6 +522,7 @@ impl<B: Backend> BottleneckBlock<B> {
522522
#[cfg(test)]
523523
mod tests {
524524
use super::*;
525+
use bimm_contracts::assert_shape_contract;
525526
use burn::backend::NdArray;
526527
use burn::nn::activation::ActivationConfig;
527528

@@ -571,7 +572,6 @@ mod tests {
571572

572573
#[test]
573574
fn test_basic_block_forward_same_channels_no_downsample_autodiff() {
574-
use bimm_contracts::assert_shape_contract;
575575
use burn::backend::{Autodiff, Wgpu};
576576
type B = Autodiff<Wgpu>;
577577

@@ -603,7 +603,6 @@ mod tests {
603603

604604
#[test]
605605
fn test_basic_block_forward_downsample_drop_block_drop_path_autodiff() {
606-
use bimm_contracts::assert_shape_contract;
607606
use burn::backend::{Autodiff, Wgpu};
608607
type B = Autodiff<Wgpu>;
609608

examples/resnet_finetune/src/main.rs

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#![allow(dead_code, unused)]
1+
#![allow(dead_code)]
22
#![recursion_limit = "256"]
33

44
extern crate core;
@@ -8,18 +8,16 @@ mod dataset;
88
use crate::data::{ClassificationBatch, ClassificationBatcher};
99
use crate::dataset::{CLASSES, PlanetLoader, download};
1010
use bimm::cache::disk::DiskCacheConfig;
11-
use bimm::models::resnet::{PREFAB_RESNET_MAP, ResNet, ResNetContractConfig};
11+
use bimm::models::resnet::{PREFAB_RESNET_MAP, ResNet};
1212
use burn::backend::Autodiff;
1313
use burn::config::Config;
1414
use burn::data::dataloader::DataLoaderBuilder;
1515
use burn::data::dataset::transform::ShuffledDataset;
1616
use burn::data::dataset::vision::ImageFolderDataset;
1717
use burn::module::Module;
18-
use burn::nn::PReluConfig;
1918
use burn::nn::activation::ActivationConfig;
2019
use burn::nn::loss::BinaryCrossEntropyLossConfig;
21-
use burn::optim::AdamConfig;
22-
use burn::optim::decay::WeightDecayConfig;
20+
use burn::optim::AdamWConfig;
2321
use burn::prelude::{Int, Tensor};
2422
use burn::record::CompactRecorder;
2523
use burn::tensor::backend::{AutodiffBackend, Backend};
@@ -152,6 +150,11 @@ fn main() -> anyhow::Result<()> {
152150
return train::<Autodiff<burn::backend::Metal>>(&args);
153151
}
154152

153+
pub fn reset_artifact_dir(artifact_dir: &str) -> anyhow::Result<()> {
154+
std::fs::remove_dir_all(artifact_dir)?;
155+
std::fs::create_dir_all(artifact_dir).map_err(|e| anyhow::anyhow!(e))
156+
}
157+
155158
#[must_use]
156159
pub fn train<B: AutodiffBackend>(args: &Args) -> anyhow::Result<()> {
157160
let device: B::Device = Default::default();
@@ -189,8 +192,7 @@ pub fn train<B: AutodiffBackend>(args: &Args) -> anyhow::Result<()> {
189192

190193
// Remove existing artifacts before to get an accurate learner summary
191194
let artifact_dir: &str = args.artifact_dir.as_ref();
192-
std::fs::remove_dir_all(artifact_dir);
193-
std::fs::create_dir_all(artifact_dir).expect("Failed to create artifacts directory");
195+
reset_artifact_dir(artifact_dir)?;
194196

195197
B::seed(&device, args.seed);
196198

@@ -224,8 +226,9 @@ pub fn train<B: AutodiffBackend>(args: &Args) -> anyhow::Result<()> {
224226
resnet: model,
225227
};
226228

227-
let optimizer = AdamConfig::new()
228-
.with_weight_decay(Some(WeightDecayConfig::new(args.weight_decay)))
229+
let optimizer = AdamWConfig::new()
230+
.with_cautious_weight_decay(true)
231+
.with_weight_decay(args.weight_decay)
229232
.init();
230233

231234
LogConfig {

0 commit comments

Comments
 (0)