Skip to content

Commit 2667080

Browse files
committed
feat: add ResNet-26 and ResNet-152 pretrained models
- Introduced configurations for ResNet-26 and ResNet-152 with pretrained weights. - Updated `README` and example documentation to include new models and their weight descriptors.
1 parent ace7d96 commit 2667080

File tree

3 files changed

+73
-32
lines changed

3 files changed

+73
-32
lines changed

README.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,14 +78,16 @@ let model: ResNet<Wgpu> = prefab
7878
### Example [resnet_finetune](examples/resnet_finetune) - Pretrained ResNet finetuning example.
7979

8080
```terminaloutput
81-
$ cargo run --release -p resnet_finetune -- --pretrained list
8281
Available pretrained models:
8382
* "resnet18"
8483
ResNetContractConfig { layers: [2, 2, 2, 2], num_classes: 1000, stem_width: 64, output_stride: 32, bottleneck_policy: None, normalization: Batch(BatchNormConfig { num_features: 0, epsilon: 1e-5, momentum: 0.1 }), activation: Relu }
8584
- "resnet18.tv_in1k": TorchVision ResNet-18
8685
- "resnet18.a1_in1k": RSB Paper ResNet-18 a1
8786
- "resnet18.a2_in1k": RSB Paper ResNet-18 a2
8887
- "resnet18.a3_in1k": RSB Paper ResNet-18 a3
88+
* "resnet26"
89+
ResNetContractConfig { layers: [2, 2, 2, 2], num_classes: 1000, stem_width: 64, output_stride: 32, bottleneck_policy: Some(BottleneckPolicyConfig { pinch_factor: 4 }), normalization: Batch(BatchNormConfig { num_features: 0, epsilon: 1e-5, momentum: 0.1 }), activation: Relu }
90+
- "resnet26.bt_in1k": ResNet-26 pretrained on ImageNet
8991
* "resnet34"
9092
ResNetContractConfig { layers: [3, 4, 6, 3], num_classes: 1000, stem_width: 64, output_stride: 32, bottleneck_policy: None, normalization: Batch(BatchNormConfig { num_features: 0, epsilon: 1e-5, momentum: 0.1 }), activation: Relu }
9193
- "resnet34.tv_in1k": TorchVision ResNet-34
@@ -100,6 +102,9 @@ ResNetContractConfig { layers: [3, 4, 6, 3], num_classes: 1000, stem_width: 64,
100102
ResNetContractConfig { layers: [3, 4, 23, 3], num_classes: 1000, stem_width: 64, output_stride: 32, bottleneck_policy: Some(BottleneckPolicyConfig { pinch_factor: 4 }), normalization: Batch(BatchNormConfig { num_features: 0, epsilon: 1e-5, momentum: 0.1 }), activation: Relu }
101103
- "resnet101.tv_in1k": TorchVision ResNet-101
102104
- "resnet101.a1_in1k": ResNet-101 pretrained on ImageNet
105+
* "resnet152"
106+
ResNetContractConfig { layers: [3, 8, 36, 3], num_classes: 1000, stem_width: 64, output_stride: 32, bottleneck_policy: Some(BottleneckPolicyConfig { pinch_factor: 4 }), normalization: Batch(BatchNormConfig { num_features: 0, epsilon: 1e-5, momentum: 0.1 }), activation: Relu }
107+
- "resnet152.tv_in1k": TorchVision ResNet-152
103108
```
104109

105110
### [bimm-contracts](https://github.com/crutcher/bimm-contracts) - a crate for static shape contracts for tensors.

crates/bimm/src/models/resnet/pretrained.rs

Lines changed: 61 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,23 @@ pub static PREFAB_RESNET_MAP: StaticPreFabMap<ResNetContractConfig> = StaticPreF
7979
],
8080
}),
8181
},
82+
&StaticPreFabConfig {
83+
name: "resnet26",
84+
description: "ResNet-26 [2, 2, 2, 2] Bottleneck",
85+
builder: || ResNetContractConfig::new(vec![2, 2, 2, 2], 1000).with_bottleneck(true),
86+
87+
weights: Some(&StaticPretrainedWeightsMap {
88+
items: &[&StaticPretrainedWeightsDescriptor {
89+
name: "bt_in1k",
90+
description: "ResNet-26 pretrained on ImageNet",
91+
license: None,
92+
origin: None,
93+
urls: &[
94+
"https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet26-9aa10e23.pth",
95+
],
96+
}],
97+
}),
98+
},
8299
&StaticPreFabConfig {
83100
name: "resnet34",
84101
description: "ResNet-34 [3, 4, 6, 3] BasicBlocks",
@@ -146,38 +163,38 @@ pub static PREFAB_RESNET_MAP: StaticPreFabMap<ResNetContractConfig> = StaticPreF
146163
builder: || ResNetContractConfig::new(vec![3, 4, 6, 3], 1000).with_bottleneck(true),
147164

148165
weights: Some(&StaticPretrainedWeightsMap {
149-
items: &[
150-
&StaticPretrainedWeightsDescriptor {
151-
name: "tv_in1k",
152-
description: "TorchVision ResNet-50",
153-
license: Some("bsd-3-clause"),
154-
origin: Some("https://github.com/pytorch/vision"),
155-
urls: &["https://download.pytorch.org/models/resnet50-0676ba61.pth"],
156-
},
157-
/*
158-
// ERROR: Some<Downsample> stub cannot be applied to None
159-
&StaticPretrainedWeightsDescriptor {
160-
name: "tv_in2k",
161-
description: "ResNet-50 pretrained on ImageNet",
162-
license: Some("bsd-3-clause"),
163-
origin: Some("https://github.com/pytorch/vision"),
164-
urls: &["https://download.pytorch.org/models/resnet50-11ad3fa6.pth"],
165-
},
166-
&StaticPretrainedWeightsDescriptor {
167-
name: "a1_in1k",
168-
description: "ResNet-50 pretrained on ImageNet",
169-
license: Some("bsd-3-clause"),
170-
origin: Some(
171-
"https://github.com/huggingface/pytorch-image-models/releases",
172-
),
173-
urls: &[
174-
"https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_a1_0-14fe96d1.pth",
175-
],
176-
},
177-
*/
178-
],
166+
items: &[&StaticPretrainedWeightsDescriptor {
167+
name: "tv_in1k",
168+
description: "TorchVision ResNet-50",
169+
license: Some("bsd-3-clause"),
170+
origin: Some("https://github.com/pytorch/vision"),
171+
urls: &["https://download.pytorch.org/models/resnet50-0676ba61.pth"],
172+
}],
179173
}),
180174
},
175+
/*
176+
&StaticPreFabConfig {
177+
name: "resnet50_gn",
178+
description: "ResNet-50 [3, 4, 6, 3] Bottleneck with GroupNorm",
179+
builder: || {
180+
ResNetContractConfig::new(vec![3, 4, 6, 3], 1000)
181+
.with_normalization(NormalizationConfig::Group(GroupNormConfig::new(32, 0)))
182+
.with_bottleneck(true)
183+
},
184+
185+
weights: Some(&StaticPretrainedWeightsMap {
186+
items: &[&StaticPretrainedWeightsDescriptor {
187+
name: "a1h_in1k",
188+
description: "ResNet-50 with GroupNorm pretrained on ImageNet",
189+
license: None,
190+
origin: None,
191+
urls: &[
192+
"https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet50_gn_a1h2-8fe6c4d0.pth",
193+
],
194+
}],
195+
}),
196+
},
197+
*/
181198
&StaticPreFabConfig {
182199
name: "resnet101",
183200
description: "ResNet-101 [3, 4, 23, 3] Bottleneck",
@@ -205,5 +222,19 @@ pub static PREFAB_RESNET_MAP: StaticPreFabMap<ResNetContractConfig> = StaticPreF
205222
],
206223
}),
207224
},
225+
&StaticPreFabConfig {
226+
name: "resnet152",
227+
description: "ResNet-152 [3, 8, 36, 3] Bottleneck",
228+
builder: || ResNetContractConfig::new(vec![3, 8, 36, 3], 1000).with_bottleneck(true),
229+
weights: Some(&StaticPretrainedWeightsMap {
230+
items: &[&StaticPretrainedWeightsDescriptor {
231+
name: "tv_in1k",
232+
description: "TorchVision ResNet-152",
233+
license: Some("bsd-3-clause"),
234+
origin: Some("https://github.com/pytorch/vision"),
235+
urls: &["https://download.pytorch.org/models/resnet152-394f9c45.pth"],
236+
}],
237+
}),
238+
},
208239
],
209240
};

examples/resnet_finetune/README.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,16 @@ cargo run --release -p resnet_finetune
1616
This will list all available pretrained models:
1717

1818
```terminaloutput
19-
$ cargo run --release -p resnet_finetune -- --pretrained list
2019
Available pretrained models:
2120
* "resnet18"
2221
ResNetContractConfig { layers: [2, 2, 2, 2], num_classes: 1000, stem_width: 64, output_stride: 32, bottleneck_policy: None, normalization: Batch(BatchNormConfig { num_features: 0, epsilon: 1e-5, momentum: 0.1 }), activation: Relu }
2322
- "resnet18.tv_in1k": TorchVision ResNet-18
2423
- "resnet18.a1_in1k": RSB Paper ResNet-18 a1
2524
- "resnet18.a2_in1k": RSB Paper ResNet-18 a2
2625
- "resnet18.a3_in1k": RSB Paper ResNet-18 a3
26+
* "resnet26"
27+
ResNetContractConfig { layers: [2, 2, 2, 2], num_classes: 1000, stem_width: 64, output_stride: 32, bottleneck_policy: Some(BottleneckPolicyConfig { pinch_factor: 4 }), normalization: Batch(BatchNormConfig { num_features: 0, epsilon: 1e-5, momentum: 0.1 }), activation: Relu }
28+
- "resnet26.bt_in1k": ResNet-26 pretrained on ImageNet
2729
* "resnet34"
2830
ResNetContractConfig { layers: [3, 4, 6, 3], num_classes: 1000, stem_width: 64, output_stride: 32, bottleneck_policy: None, normalization: Batch(BatchNormConfig { num_features: 0, epsilon: 1e-5, momentum: 0.1 }), activation: Relu }
2931
- "resnet34.tv_in1k": TorchVision ResNet-34
@@ -38,4 +40,7 @@ ResNetContractConfig { layers: [3, 4, 6, 3], num_classes: 1000, stem_width: 64,
3840
ResNetContractConfig { layers: [3, 4, 23, 3], num_classes: 1000, stem_width: 64, output_stride: 32, bottleneck_policy: Some(BottleneckPolicyConfig { pinch_factor: 4 }), normalization: Batch(BatchNormConfig { num_features: 0, epsilon: 1e-5, momentum: 0.1 }), activation: Relu }
3941
- "resnet101.tv_in1k": TorchVision ResNet-101
4042
- "resnet101.a1_in1k": ResNet-101 pretrained on ImageNet
43+
* "resnet152"
44+
ResNetContractConfig { layers: [3, 8, 36, 3], num_classes: 1000, stem_width: 64, output_stride: 32, bottleneck_policy: Some(BottleneckPolicyConfig { pinch_factor: 4 }), normalization: Batch(BatchNormConfig { num_features: 0, epsilon: 1e-5, momentum: 0.1 }), activation: Relu }
45+
- "resnet152.tv_in1k": TorchVision ResNet-152
4146
```

0 commit comments

Comments
 (0)