Skip to content

Commit 6a18c55

Browse files
committed
docs: update ResNet examples for clarity and improved prefab usage
- Revised ResNet examples in documentation to align with updated prefab and weight-fetching methods. - Enhanced example readability with consistent formatting and commented steps for better guidance. - Updated references to new prefab and weight-loading APIs.
1 parent e76d07c commit 6a18c55

File tree

3 files changed

+36
-27
lines changed

3 files changed

+36
-27
lines changed

README.md

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -45,28 +45,30 @@ See the [CONTRIBUTING](CONTRIBUTING.md) guide for build and contribution instruc
4545

4646
#### Example
4747

48-
Example of building a pretrained ResNet-18 module:
48+
Example of building a pretrained model:
4949

5050
```rust,no_run
51-
use bimm::cache::fetch_model_weights;
52-
use bimm::models::resnet::{ResNet, ResNetAbstractConfig};
53-
use burn::backend::NdArray;
51+
use burn::backend::Wgpu;
52+
use bimm::cache::disk::DiskCacheConfig;
53+
use bimm::models::resnet::{PREFAB_RESNET_MAP, ResNet};
5454
5555
let device = Default::default();
5656
57-
let source =
58-
"https://download.pytorch.org/models/resnet18-f37072fd.pth";
59-
let source_classes = 1000;
60-
let weights_path= fetch_model_weights(source).unwrap();
57+
let prefab = PREFAB_RESNET_MAP.expect_lookup_prefab("resnet18");
6158
62-
let my_classes = 10;
59+
let weights = prefab
60+
.expect_lookup_pretrained_weights("tv_in1k")
61+
.fetch_weights(&DiskCacheConfig::default())
62+
.expect("Failed to fetch weights");
6363
64-
let model: ResNet<NdArray> = ResNetAbstractConfig::resnet18(source_classes)
64+
let model: ResNet<Wgpu> = prefab
65+
.to_config()
6566
.to_structure()
6667
.init(&device)
67-
.load_pytorch_weights(weights_path)
68-
.expect("Model should be loaded successfully")
69-
.with_classes(my_classes)
68+
.load_pytorch_weights(weights)
69+
.expect("Failed to load weights")
70+
// re-head the model to 10 classes:
71+
.with_classes(10)
7072
// Enable (drop_block_prob) stochastic block drops for training:
7173
.with_stochastic_drop_block(0.2)
7274
// Enable (drop_path_prob) stochastic depth for training:

crates/bimm/README.md

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -5,29 +5,30 @@
55

66
This is a Rust crate for image models, inspired by the Python `timm` package.
77

8-
Examples of loading pre-trained ResNet-18 model:
8+
Examples of loading pretrained model:
99

1010
```rust,no_run
11-
use bimm::cache::fetch_model_weights;
12-
use bimm::models::resnet::{ResNet, ResNetAbstractConfig};
1311
use burn::backend::Wgpu;
12+
use bimm::cache::disk::DiskCacheConfig;
13+
use bimm::models::resnet::{PREFAB_RESNET_MAP, ResNet};
1414
15-
type B = Wgpu;
1615
let device = Default::default();
1716
18-
let source =
19-
"https://download.pytorch.org/models/resnet18-f37072fd.pth";
20-
let source_classes = 1000;
21-
let weights_path= fetch_model_weights(source).unwrap();
17+
let prefab = PREFAB_RESNET_MAP.expect_lookup_prefab("resnet18");
2218
23-
let my_classes = 10;
19+
let weights = prefab
20+
.expect_lookup_pretrained_weights("tv_in1k")
21+
.fetch_weights(&DiskCacheConfig::default())
22+
.expect("Failed to fetch weights");
2423
25-
let model: ResNet<B> = ResNetAbstractConfig::resnet18(source_classes)
24+
let model: ResNet<Wgpu> = prefab
25+
.to_config()
2626
.to_structure()
2727
.init(&device)
28-
.load_pytorch_weights(weights_path)
29-
.expect("Model should be loaded successfully")
30-
.with_classes(my_classes)
28+
.load_pytorch_weights(weights)
29+
.expect("Failed to load weights")
30+
// re-head the model to 10 classes:
31+
.with_classes(10)
3132
// Enable (drop_block_prob) stochastic block drops for training:
3233
.with_stochastic_drop_block(0.2)
3334
// Enable (drop_path_prob) stochastic depth for training:

crates/bimm/src/models/resnet/mod.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
//!
66
//! ## Example
77
//!
8+
//! Examples of loading pretrained model:
9+
//!
810
//! ```rust,no_run
911
//! use burn::backend::NdArray;
1012
//! use bimm::cache::disk::DiskCacheConfig;
@@ -25,8 +27,12 @@
2527
//! .init(&device)
2628
//! .load_pytorch_weights(weights)
2729
//! .expect("Failed to load weights")
30+
//! // re-head the model to 10 classes:
2831
//! .with_classes(10)
29-
//! .with_stochastic_drop_block(0.2);
32+
//! // Enable (drop_block_prob) stochastic block drops for training:
33+
//! .with_stochastic_drop_block(0.2)
34+
//! // Enable (drop_path_prob) stochastic depth for training:
35+
//! .with_stochastic_path_depth(0.1);
3036
//! ```
3137
//!
3238
//! ## Configuration

0 commit comments

Comments
 (0)