
Commit 7b63752

Fix bottleneck support. (#99)
* ResNet50 WIP

* feat: add label smoothing and refactor ResNet-finetune model

  - Introduced optional label smoothing for training via `--smoothing` argument.
  - Refactored `ResNet` into a new `Host` structure to include smoothing and support advanced configurations.
  - Adjusted training and validation flows to utilize `Host` wrapping the original `ResNet` model.

* chore: clean up outdated comments in ResNet-finetune example

* refactor: reorganize ResNet block APIs and improve method clarity

  - Standardized comments in `BasicBlock` and `BottleneckBlock` to use consistent `meta-API` notation.
  - Merged duplicate methods (`pinch_factor`, `effective_first_dilation`, `out_planes`) for cleaner implementation.
  - Updated method usages to directly reference the respective properties for improved readability and efficiency.

* refactor: rename and standardize Reduction APIs across ResNet blocks

  - Replaced `reduction_factor` with `reduce_first` for clarity and consistency.
  - Streamlined related methods and comments in `BasicBlock`, `BottleneckBlock`, and `LayerBlock`.
  - Added debug print utilities across ResNet modules to help with diagnostics.
  - Updated CNA APIs with `map_forward` and adjusted method order for improved readability.
  - Enhanced ResNet-finetune example with `--freeze_layers` option and a custom metrics renderer.
  - Adjusted default cardinality and pretrained weights handling in ResNet configurations.

* refactor: rename `resnet-finetune` to `resnet_finetune` and update related documentation

  - Standardized naming across the example directory, README, and `Cargo.toml`.
  - Enhanced pretrained model listing with detailed configurations and descriptions.

* chore: bump version to 0.19.2 across workspace and dependencies

  - Updated `Cargo.toml` and `Cargo.lock` files to reflect the new version.
1 parent 0de0908 commit 7b63752
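The label-smoothing change above blends the one-hot target with a uniform distribution over classes. A minimal, backend-free sketch of the standard technique follows; it is illustrative only and is not the commit's actual `Host` implementation:

```rust
/// Sketch of label smoothing (not the `Host` implementation from this
/// commit). With smoothing factor `eps` over `num_classes` classes, the
/// one-hot target becomes `1 - eps + eps / K` on the true class and
/// `eps / K` on every other class.
fn smooth_one_hot(class: usize, num_classes: usize, eps: f32) -> Vec<f32> {
    let off = eps / num_classes as f32;
    let mut target = vec![off; num_classes];
    target[class] += 1.0 - eps;
    target
}

fn main() {
    // e.g. `--smoothing 0.1` over 4 classes:
    // the true class gets 0.925, the others 0.025 each.
    let t = smooth_one_hot(2, 4, 0.1);
    assert!((t.iter().sum::<f32>() - 1.0).abs() < 1e-6);
    println!("{t:?}"); // [0.025, 0.025, 0.925, 0.025]
}
```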

File tree

26 files changed: +1036 −389 lines


Cargo.lock

Lines changed: 10 additions & 8 deletions
Some generated files are not rendered by default.

Cargo.toml

Lines changed: 2 additions & 2 deletions
```diff
@@ -10,7 +10,7 @@ members = [
 
 [workspace.package]
 edition = "2024"
-version = "0.19.1"
+version = "0.19.2"
 repository = "https://github.com/crutcher/bimm"
 license = "MIT"
 rust-version = "1.88.0"
@@ -32,7 +32,7 @@ dirs = "^6.0.0"
 # > cudarc = { workspace = true, features = ["nvrtc"], optional = true }
 cudarc = "0.16.6"
 
-bimm-contracts = "^0.19.0"
+bimm-contracts = "^0.19.1"
 
 # Burn coupled-dependencies
 globwalk = "^0.9.1"
```

README.md

Lines changed: 27 additions & 0 deletions
````diff
@@ -75,6 +75,33 @@ let model: ResNet<Wgpu> = prefab
     .with_stochastic_path_depth(0.1);
 ```
 
+### Example [resnet_finetune](examples/resnet_finetune) - Pretrained ResNet finetuning example.
+
+```terminaloutput
+$ cargo run --release -p resnet_finetune -- --pretrained list
+Available pretrained models:
+* "resnet18"
+  ResNetContractConfig { layers: [2, 2, 2, 2], num_classes: 1000, stem_width: 64, output_stride: 32, bottleneck_policy: None, normalization: Batch(BatchNormConfig { num_features: 0, epsilon: 1e-5, momentum: 0.1 }), activation: Relu }
+  - "resnet18.tv_in1k": TorchVision ResNet-18
+  - "resnet18.a1_in1k": RSB Paper ResNet-18 a1
+  - "resnet18.a2_in1k": RSB Paper ResNet-18 a2
+  - "resnet18.a3_in1k": RSB Paper ResNet-18 a3
+* "resnet34"
+  ResNetContractConfig { layers: [3, 4, 6, 3], num_classes: 1000, stem_width: 64, output_stride: 32, bottleneck_policy: None, normalization: Batch(BatchNormConfig { num_features: 0, epsilon: 1e-5, momentum: 0.1 }), activation: Relu }
+  - "resnet34.tv_in1k": TorchVision ResNet-34
+  - "resnet34.a1_in1k": RSB Paper ResNet-32 a1
+  - "resnet34.a2_in1k": RSB Paper ResNet-32 a2
+  - "resnet34.a3_in1k": RSB Paper ResNet-32 a3
+  - "resnet34.bt_in1k": ResNet-34 pretrained on ImageNet
+* "resnet50"
+  ResNetContractConfig { layers: [3, 4, 6, 3], num_classes: 1000, stem_width: 64, output_stride: 32, bottleneck_policy: Some(BottleneckPolicyConfig { pinch_factor: 4 }), normalization: Batch(BatchNormConfig { num_features: 0, epsilon: 1e-5, momentum: 0.1 }), activation: Relu }
+  - "resnet50.tv_in1k": TorchVision ResNet-50
+* "resnet101"
+  ResNetContractConfig { layers: [3, 4, 23, 3], num_classes: 1000, stem_width: 64, output_stride: 32, bottleneck_policy: Some(BottleneckPolicyConfig { pinch_factor: 4 }), normalization: Batch(BatchNormConfig { num_features: 0, epsilon: 1e-5, momentum: 0.1 }), activation: Relu }
+  - "resnet101.tv_in1k": TorchVision ResNet-101
+  - "resnet101.a1_in1k": ResNet-101 pretrained on ImageNet
+```
+
 ### [bimm-contracts](https://github.com/crutcher/bimm-contracts) - a crate for static shape contracts for tensors.
 
 [![Crates.io Version](https://img.shields.io/crates/v/bimm-contracts)](https://crates.io/crates/bimm-contracts)
````
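As a usage sketch, the new example's flags from this commit could be combined as below. The flag names (`--pretrained`, `--smoothing`, `--freeze_layers`) come from the commit message and README diff, but the argument values and shapes are illustrative assumptions, not copied from the example's actual help output:

```terminaloutput
# Hypothetical invocation; the `--smoothing` and `--freeze_layers`
# argument forms are assumed, not taken from the example's help text.
$ cargo run --release -p resnet_finetune -- \
    --pretrained resnet50.tv_in1k \
    --smoothing 0.1 \
    --freeze_layers 3
```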

crates/bimm-firehose-image/Cargo.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -12,7 +12,7 @@ description = "bimm-firehose image processing support"
 workspace = true
 
 [dependencies]
-bimm-firehose = { version = "0.19.1", path = "../bimm-firehose" }
+bimm-firehose = { version = "0.19.2", path = "../bimm-firehose" }
 
 burn = { workspace = true, features = ["dataset", "vision", "ndarray"] }
 serde = { workspace = true, features = ["derive"] }
```

crates/bimm/Cargo.toml

Lines changed: 2 additions & 1 deletion
```diff
@@ -21,12 +21,13 @@ burn-import = { workspace = true }
 
 crc = { workspace = true }
 serde = { workspace = true, features = ["derive"] }
+serde_json = { workspace = true }
 
 anyhow = { workspace = true }
 num-traits = { workspace = true }
 
 [dev-dependencies]
-burn = { workspace = true, features = ["autodiff"] }
+burn = { workspace = true, features = ["autodiff", "wgpu"] }
 hamcrest = { workspace = true }
 indoc = { workspace = true }
 criterion = { workspace = true }
```

crates/bimm/src/layers/blocks/cna.rs

Lines changed: 21 additions & 20 deletions
```diff
@@ -57,12 +57,12 @@ pub trait CNA2dMeta {
     /// Number of input channels.
     fn in_channels(&self) -> usize;
 
-    /// Number of groups.
-    fn groups(&self) -> usize;
-
     /// Number of output channels.
     fn out_channels(&self) -> usize;
 
+    /// Number of groups.
+    fn groups(&self) -> usize;
+
     /// Get the stride.
     fn stride(&self) -> [usize; 2];
 }
@@ -88,14 +88,14 @@ impl CNA2dMeta for CNA2dConfig {
         self.conv.channels[0]
     }
 
-    fn groups(&self) -> usize {
-        self.conv.groups
-    }
-
     fn out_channels(&self) -> usize {
         self.conv.channels[1]
     }
 
+    fn groups(&self) -> usize {
+        self.conv.groups
+    }
+
     fn stride(&self) -> [usize; 2] {
         self.conv.stride
     }
@@ -153,15 +153,15 @@ pub struct CNA2d<B: Backend> {
 
 impl<B: Backend> CNA2dMeta for CNA2d<B> {
     fn in_channels(&self) -> usize {
-        self.conv.weight.shape().dims[1] * self.groups()
+        self.conv.weight.dims()[1] * self.groups()
     }
 
-    fn groups(&self) -> usize {
-        self.conv.groups
+    fn out_channels(&self) -> usize {
+        self.conv.weight.dims()[0]
     }
 
-    fn out_channels(&self) -> usize {
-        self.conv.weight.shape().dims[0]
+    fn groups(&self) -> usize {
+        self.conv.groups
     }
 
     fn stride(&self) -> [usize; 2] {
@@ -193,17 +193,17 @@ impl<B: Backend> CNA2d<B> {
         &self,
         input: Tensor<B, 4>,
     ) -> Tensor<B, 4> {
-        self.hook_forward(input, |x| x)
+        self.map_forward(input, |x| x)
     }
 
-    /// Hooked Forward Pass.
+    /// Mapping Forward Pass.
     ///
-    /// Applies the hook after normalization but before activation.
+    /// Applies the callback fn after normalization but before activation.
     ///
     /// ```rust,ignore
     /// let x = self.conv.forward(input);
     /// let x = self.norm.forward(x);
-    /// let x = hook(x);
+    /// let x = f(x);
     /// let x = self.act.forward(x);
     /// return x
     /// ```
@@ -212,14 +212,15 @@ impl<B: Backend> CNA2d<B> {
     ///
     /// - `input`: \
     ///   ``[batch, in_channels, in_height=out_height*stride, in_width=out_width*stride]``.
+    /// - `f`: a callback endofunction, from/to ``[batch, in_channels, out_height, out_width]``.
     ///
     /// # Returns
     ///
     /// ``[batch, out_channels, out_height, out_width]``
-    pub fn hook_forward<F>(
+    pub fn map_forward<F>(
         &self,
         input: Tensor<B, 4>,
-        hook: F,
+        f: F,
     ) -> Tensor<B, 4>
     where
         F: FnOnce(Tensor<B, 4>) -> Tensor<B, 4>,
@@ -254,7 +255,7 @@ impl<B: Backend> CNA2d<B> {
 
         let x = self.norm.forward(x);
 
-        let x = hook(x);
+        let x = f(x);
 
         let x = self.act.forward(x);
 
@@ -345,7 +346,7 @@ mod tests {
         {
             let hook = |x| x * 2.0;
 
-            let output = layer.hook_forward(input.clone(), hook);
+            let output = layer.map_forward(input.clone(), hook);
             let expected = {
                 let x = layer.conv.forward(input.clone());
                 let x = layer.norm.forward(x);
```
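The `map_forward` rename above makes the hook point explicit: the callback runs between normalization and activation (conv → norm → f → act). A short usage sketch follows, assuming the module path `bimm::layers::blocks::cna` mirrors the file location (the actual public path may differ):

```rust
use bimm::layers::blocks::cna::CNA2d; // path assumed from the file location
use burn::prelude::*;

/// Sketch: scale the normalized pre-activation before the activation
/// fires, via the renamed `map_forward` hook point.
fn scaled_forward<B: Backend>(layer: &CNA2d<B>, input: Tensor<B, 4>) -> Tensor<B, 4> {
    // `|x| x` would reproduce plain `forward`; scaling by 2.0 mirrors
    // the hook used in the updated test above.
    layer.map_forward(input, |x| x * 2.0)
}
```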
