Skip to content

Commit bab85b2

Browse files
committed
feat: enhance ResNet examples with gradient accumulation and updated defaults
- Added a `grads_accumulation` argument to support gradient accumulation for better memory management.
- Updated default values for the learning rate, LR decay, batch size, drop-block probability, and drop-path probability.
- Refactored `LogConfig` for streamlined documentation and improved modularity.
- Improved the stochastic-depth initialization logic to enforce consistency within ResNet layers.
1 parent ac00550 commit bab85b2

File tree

4 files changed

+56
-42
lines changed

4 files changed

+56
-42
lines changed

crates/bimm/src/models/resnet/layer_block.rs

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -243,16 +243,13 @@ impl LayerBlockStructureConfig {
243243
O: Into<Option<DropBlockOptions>>,
244244
{
245245
let options = options.into();
246-
self.map_blocks(&mut |_, block| block.with_drop_block(options.clone()))
247-
}
248-
249-
/// Update the drop path probability.
250-
pub fn with_drop_path_prob(
251-
self,
252-
prob: f64,
253-
) -> Self {
254-
let prob = expect_probability(prob);
255-
self.map_blocks(&mut |_, block| block.with_drop_path_prob(prob))
246+
self.map_blocks(&mut |idx, block| {
247+
if idx == 0 {
248+
block.with_drop_block(None)
249+
} else {
250+
block.with_drop_block(options.clone())
251+
}
252+
})
256253
}
257254
}
258255

crates/bimm/src/models/resnet/resnet_model.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -214,13 +214,13 @@ impl ResNetStructureConfig {
214214
) -> Self {
215215
let drop_path_rate = expect_probability(drop_path_rate);
216216

217-
let net_num_blocks = self.layers.iter().map(|b| b.len()).sum::<usize>();
217+
let net_num_blocks = self.layers.iter().map(|b| b.len()).sum::<usize>() - self.layers.len();
218218
let mut net_block_idx = 0;
219-
let mut update_drop_path = |_idx: usize, block: ResidualBlockStructureConfig| {
219+
let mut update_drop_path = |idx: usize, block: ResidualBlockStructureConfig| {
220220
// stochastic depth linear decay rule
221221
let block_dpr = drop_path_rate * (net_block_idx as f64) / ((net_num_blocks - 1) as f64);
222222
net_block_idx += 1;
223-
if block_dpr > 0.0 {
223+
if idx != 0 && block_dpr > 0.0 {
224224
block.with_drop_path_prob(block_dpr)
225225
} else {
226226
block

examples/resnet-finetune/src/main.rs

Lines changed: 37 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,15 @@ pub struct Args {
7575
artifact_dir: String,
7676

7777
/// Batch size for processing
78-
#[arg(short, long, default_value_t = 32)]
78+
#[arg(short, long, default_value_t = 24)]
7979
batch_size: usize,
8080

81+
/// Grads accumulation size for processing
82+
#[arg(short, long, default_value_t = 8)]
83+
grads_accumulation: usize,
84+
8185
/// Number of workers for data loading.
82-
#[arg(long, default_value = "2")]
86+
#[arg(long, default_value = "4")]
8387
num_workers: usize,
8488

8589
/// Number of epochs to train the model.
@@ -103,55 +107,62 @@ pub struct Args {
103107
drop_path_prob: f64,
104108

105109
/// Learning rate
106-
#[arg(long, default_value = "1e-5")]
110+
#[arg(long, default_value_t = 5e-5)]
107111
pub learning_rate: f64,
108112

109113
/// Early stopping patience
110-
#[arg(long, default_value = "10")]
114+
#[arg(long, default_value_t = 10)]
111115
patience: usize,
112116

113117
/// Optimizer Weight decay.
114118
#[arg(long, default_value_t = 5e-4)]
115119
pub weight_decay: f32,
116120
}
117121

118-
/// Log config.
119-
///
120-
/// Only exists for logging.
121-
#[derive(Config, Debug)]
122122
#[allow(clippy::too_many_arguments)]
123-
pub struct LogConfig {
124-
seed: u64,
125-
train_percentage: u8,
126-
batch_size: usize,
127-
num_epochs: usize,
128-
resnet_prefab: String,
129-
resnet_pretrained: String,
130-
drop_block_prob: f64,
131-
drop_path_prob: f64,
132-
learning_rate: f64,
133-
patience: usize,
134-
weight_decay: f32,
135-
resnet: ResNetContractConfig,
123+
mod local {
124+
use bimm::models::resnet::ResNetContractConfig;
125+
use burn::config::Config;
126+
127+
/// Log config.
128+
///
129+
/// Only exists for logging.
130+
#[derive(Config, Debug)]
131+
pub struct LogConfig {
132+
pub seed: u64,
133+
pub train_percentage: u8,
134+
pub batch_size: usize,
135+
pub num_epochs: usize,
136+
pub resnet_prefab: String,
137+
pub resnet_pretrained: String,
138+
pub drop_block_prob: f64,
139+
pub drop_path_prob: f64,
140+
pub learning_rate: f64,
141+
pub patience: usize,
142+
pub weight_decay: f32,
143+
pub resnet: ResNetContractConfig,
144+
}
136145
}
146+
use local::*;
137147

138-
fn main() {
148+
fn main() -> anyhow::Result<()> {
139149
let args = Args::parse();
140150

141151
let _source_tree = download();
142152

143153
let device = Default::default();
144-
train::<Autodiff<Cuda>>(&args, &device);
154+
train::<Autodiff<Cuda>>(&args, &device)
145155
}
146156

157+
#[must_use]
147158
pub fn train<B: AutodiffBackend>(
148159
args: &Args,
149160
device: &B::Device,
150161
) -> anyhow::Result<()> {
151162
// Remove existing artifacts before to get an accurate learner summary
152163
let artifact_dir: &str = args.artifact_dir.as_ref();
153-
std::fs::remove_dir_all(artifact_dir)?;
154-
std::fs::create_dir_all(artifact_dir)?;
164+
std::fs::remove_dir_all(artifact_dir);
165+
std::fs::create_dir_all(artifact_dir).expect("Failed to create artifacts directory");
155166

156167
B::seed(args.seed);
157168

@@ -232,6 +243,7 @@ pub fn train<B: AutodiffBackend>(
232243
},
233244
))
234245
.devices(vec![device.clone()])
246+
.grads_accumulation(args.grads_accumulation)
235247
.num_epochs(args.num_epochs)
236248
.summary()
237249
.build(model, optimizer, args.learning_rate);

examples/resnet_tiny/src/main.rs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,10 @@ pub struct Args {
6666
#[arg(short, long, default_value_t = 512)]
6767
batch_size: usize,
6868

69+
/// Grads accumulation size for processing
70+
#[arg(short, long, default_value_t = 8)]
71+
grads_accumulation: usize,
72+
6973
/// Number of workers for data loading.
7074
#[arg(long, default_value = "2")]
7175
num_workers: Option<usize>,
@@ -79,11 +83,11 @@ pub struct Args {
7983
drop_block_rate: f64,
8084

8185
/// Learning rate for the optimizer.
82-
#[arg(long, default_value = "1.0e-6")]
86+
#[arg(long, default_value = "1.0e-4")]
8387
learning_rate: f64,
8488

8589
/// Learning rate decay gamma.
86-
#[arg(long, default_value = "0.999975")]
90+
#[arg(long, default_value = "0.999997")]
8791
lr_gamma: f64,
8892

8993
/// Directory to save the artifacts.
@@ -107,11 +111,11 @@ pub struct Args {
107111
resnet_pretrained: Option<String>,
108112

109113
/// Drop Block Prob
110-
#[arg(long, default_value = "0.25")]
114+
#[arg(long, default_value = "0.20")]
111115
drop_block_prob: f64,
112116

113117
/// Drop Path Prob
114-
#[arg(long, default_value = "0.15")]
118+
#[arg(long, default_value = "0.0")]
115119
drop_path_prob: f64,
116120

117121
/// Early stopping patience
@@ -315,6 +319,7 @@ pub fn backend_main<B: AutodiffBackend>(
315319
},
316320
))
317321
.devices(devices.clone())
322+
.grads_accumulation(args.grads_accumulation)
318323
.num_epochs(args.num_epochs)
319324
.summary()
320325
.build(model, optim_config.init(), lr_scheduler);

0 commit comments

Comments (0)