Skip to content

Commit d39bad4

Browse files
committed
feat: add pretrained model listing and parsing support in ResNet-finetune example
- Enabled listing of available pretrained models via the `resnet_pretrained='list'` option. - Updated argument parsing to handle `resnet34:tv_in1k` format for model and weights separation. - Refactored weight fetching logic to support the new argument structure.
1 parent d87d119 commit d39bad4

File tree

1 file changed

+27
-10
lines changed
  • examples/resnet-finetune/src

1 file changed

+27
-10
lines changed

examples/resnet-finetune/src/main.rs

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -90,12 +90,9 @@ pub struct Args {
9090
#[arg(long, default_value = "60")]
9191
num_epochs: usize,
9292

93-
/// Resnet Model Config
94-
#[arg(long, default_value = "resnet34")]
95-
resnet_prefab: String,
96-
97-
/// Resnet Pretrained
98-
#[arg(long, default_value = "tv_in1k")]
93+
/// Pretrained Resnet Model.
94+
/// Use "list" to list all available pretrained models.
95+
#[arg(long, default_value = "resnet34.tv_in1k")]
9996
resnet_pretrained: String,
10097

10198
/// Drop Block Prob
@@ -164,6 +161,26 @@ fn main() -> anyhow::Result<()> {
164161
pub fn train<B: AutodiffBackend>(args: &Args) -> anyhow::Result<()> {
165162
let device: B::Device = Default::default();
166163

164+
// TODO: lift to clap parser.
165+
if args.resnet_pretrained == "list" {
166+
println!("Available pretrained models:");
167+
for prefab in PREFAB_RESNET_MAP.items {
168+
if let Some(weights) = prefab.weights {
169+
for item in weights.items {
170+
println!("- \"{}.{}\": {}", prefab.name, item.name, item.description);
171+
}
172+
}
173+
}
174+
return Ok(());
175+
}
176+
let [resnet_prefab, resnet_pretrained] = args
177+
.resnet_pretrained
178+
.splitn(2, ".")
179+
.map(|s| s.to_string())
180+
.collect::<Vec<String>>()
181+
.try_into()
182+
.unwrap();
183+
167184
// Remove existing artifacts before to get an accurate learner summary
168185
let artifact_dir: &str = args.artifact_dir.as_ref();
169186
std::fs::remove_dir_all(artifact_dir);
@@ -173,10 +190,10 @@ pub fn train<B: AutodiffBackend>(args: &Args) -> anyhow::Result<()> {
173190

174191
let disk_cache = DiskCacheConfig::default();
175192

176-
let prefab = PREFAB_RESNET_MAP.expect_lookup_prefab(&args.resnet_prefab);
193+
let prefab = PREFAB_RESNET_MAP.expect_lookup_prefab(&resnet_prefab);
177194

178195
let weights = prefab
179-
.expect_lookup_pretrained_weights(&args.resnet_pretrained)
196+
.expect_lookup_pretrained_weights(&resnet_pretrained)
180197
.fetch_weights(&disk_cache)
181198
.expect("Failed to fetch pretrained weights");
182199

@@ -201,8 +218,8 @@ pub fn train<B: AutodiffBackend>(args: &Args) -> anyhow::Result<()> {
201218
train_percentage: args.train_percentage,
202219
batch_size: args.batch_size,
203220
num_epochs: args.num_epochs,
204-
resnet_prefab: args.resnet_prefab.clone(),
205-
resnet_pretrained: args.resnet_pretrained.clone(),
221+
resnet_prefab: resnet_prefab.clone(),
222+
resnet_pretrained: resnet_pretrained.clone(),
206223
drop_block_prob: args.drop_block_prob,
207224
drop_path_prob: args.drop_path_prob,
208225
learning_rate: args.learning_rate,

0 commit comments

Comments
 (0)