Skip to content

Commit 88aa132

Browse files
committed
feat: extend ResNet pretrained weights with new variants
- Added new pretrained weight descriptors for ResNet-18 and ResNet-34 (a1_in1k, a2_in1k, a3_in1k). - Introduced additional ResNet-34 variant (bt_in1k) with updated weight URLs. - Adjusted descriptions to maintain consistency across model definitions.
1 parent e98b573 commit 88aa132

File tree

1 file changed

+92
-14
lines changed

1 file changed

+92
-14
lines changed

crates/bimm/src/models/resnet/pretrained.rs

Lines changed: 92 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -44,20 +44,38 @@ pub static PREFAB_RESNET_MAP: StaticPreFabMap<ResNetContractConfig> = StaticPreF
4444
items: &[
4545
&StaticPretrainedWeightsDescriptor {
4646
name: "tv_in1k",
47-
description: "ResNet18 pretrained on ImageNet",
47+
description: "ResNet-18 pretrained on ImageNet",
4848
license: Some("bsd-3-clause"),
4949
origin: Some("https://github.com/pytorch/vision"),
5050
urls: &["https://download.pytorch.org/models/resnet18-f37072fd.pth"],
5151
},
5252
&StaticPretrainedWeightsDescriptor {
5353
name: "a1_in1k",
54-
description: "ResNet18 pretrained on ImageNet",
54+
description: "ResNet-18 pretrained on ImageNet",
5555
license: None,
5656
origin: Some("https://github.com/huggingface/pytorch-image-models"),
5757
urls: &[
5858
"https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet18_a1_0-d63eafa0.pth",
5959
],
6060
},
61+
&StaticPretrainedWeightsDescriptor {
62+
name: "a2_in1k",
63+
description: "ResNet-18 pretrained on ImageNet",
64+
license: None,
65+
origin: Some("https://github.com/huggingface/pytorch-image-models"),
66+
urls: &[
67+
"https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet18_a2_0-b61bd467.pth",
68+
],
69+
},
70+
&StaticPretrainedWeightsDescriptor {
71+
name: "a3_in1k",
72+
description: "ResNet-18 pretrained on ImageNet",
73+
license: None,
74+
origin: Some("https://github.com/huggingface/pytorch-image-models"),
75+
urls: &[
76+
"https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet18_a3_0-40c531c8.pth",
77+
],
78+
},
6179
],
6280
}),
6381
},
@@ -67,13 +85,59 @@ pub static PREFAB_RESNET_MAP: StaticPreFabMap<ResNetContractConfig> = StaticPreF
6785
builder: || ResNetContractConfig::new([3, 4, 6, 3], 1000),
6886

6987
weights: Some(&StaticPretrainedWeightsMap {
70-
items: &[&StaticPretrainedWeightsDescriptor {
71-
name: "tv_in1k",
72-
description: "ResNet-34 pretrained on ImageNet",
73-
license: Some("bsd-3-clause"),
74-
origin: Some("https://github.com/pytorch/vision"),
75-
urls: &["https://download.pytorch.org/models/resnet34-b627a593.pth"],
76-
}],
88+
items: &[
89+
&StaticPretrainedWeightsDescriptor {
90+
name: "tv_in1k",
91+
description: "ResNet-34 pretrained on ImageNet",
92+
license: Some("bsd-3-clause"),
93+
origin: Some("https://github.com/pytorch/vision"),
94+
urls: &["https://download.pytorch.org/models/resnet34-b627a593.pth"],
95+
},
96+
&StaticPretrainedWeightsDescriptor {
97+
name: "a1_in1k",
98+
description: "ResNet-34 pretrained on ImageNet",
99+
license: None,
100+
origin: Some(
101+
"https://github.com/huggingface/pytorch-image-models/releases",
102+
),
103+
urls: &[
104+
"https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet34_a1_0-46f8f793.pth",
105+
],
106+
},
107+
&StaticPretrainedWeightsDescriptor {
108+
name: "a2_in1k",
109+
description: "ResNet-34 pretrained on ImageNet",
110+
license: None,
111+
origin: Some(
112+
"https://github.com/huggingface/pytorch-image-models/releases",
113+
),
114+
urls: &[
115+
"https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet34_a2_0-82d47d71.pth",
116+
],
117+
},
118+
&StaticPretrainedWeightsDescriptor {
119+
name: "a3_in1k",
120+
description: "ResNet-34 pretrained on ImageNet",
121+
license: None,
122+
origin: Some(
123+
"https://github.com/huggingface/pytorch-image-models/releases",
124+
),
125+
urls: &[
126+
"https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet34_a3_0-a20cabb6.pth",
127+
],
128+
},
129+
&StaticPretrainedWeightsDescriptor {
130+
name: "bt_in1k",
131+
description: "ResNet-34 pretrained on ImageNet",
132+
license: None,
133+
origin: Some(
134+
"https://github.com/huggingface/pytorch-image-models/releases",
135+
),
136+
urls: &[
137+
"https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/resnet34-43635321.pth",
138+
],
139+
},
140+
],
77141
}),
78142
},
79143
&StaticPreFabConfig {
@@ -84,8 +148,7 @@ pub static PREFAB_RESNET_MAP: StaticPreFabMap<ResNetContractConfig> = StaticPreF
84148
weights: Some(&StaticPretrainedWeightsMap {
85149
items: &[
86150
/*
87-
FIXME: The loaded weights have a downsample that the config does not have.
88-
DeserializeError("Candle Tensor error: invalid Zip archive: Could not find central directory end")
151+
// ERROR: Some<Downsample> stub cannot be applied to None
89152
&StaticPretrainedWeightsDescriptor {
90153
name: "tv_in1k",
91154
description: "ResNet-50 pretrained on ImageNet",
@@ -100,15 +163,30 @@ pub static PREFAB_RESNET_MAP: StaticPreFabMap<ResNetContractConfig> = StaticPreF
100163
origin: Some("https://github.com/pytorch/vision"),
101164
urls: &["https://download.pytorch.org/models/resnet50-11ad3fa6.pth"],
102165
},
103-
*/
166+
*/
104167
],
105168
}),
106169
},
107170
&StaticPreFabConfig {
108171
name: "resnet101",
109172
description: "ResNet-101 [3, 4, 23, 3] Bottleneck",
110-
builder: || ResNetContractConfig::new([3, 4, 6, 3], 1000).with_bottleneck(true),
111-
weights: None,
173+
builder: || ResNetContractConfig::new([3, 4, 23, 3], 1000).with_bottleneck(true),
174+
weights: Some(&StaticPretrainedWeightsMap {
175+
items: &[
176+
/*
177+
// ERROR: Some<Downsample> stub cannot be applied to None
178+
&StaticPretrainedWeightsDescriptor {
179+
name: "a1_in1k",
180+
description: "ResNet-101 pretrained on ImageNet",
181+
license: None,
182+
origin: Some("https://github.com/huggingface/pytorch-image-models/releases"),
183+
urls: &[
184+
"https://github.com/huggingface/pytorch-image-models/releases/download/v0.1-rsb-weights/resnet101_a1_0-cdcb52a9.pth",
185+
],
186+
}
187+
*/
188+
],
189+
}),
112190
},
113191
],
114192
};

0 commit comments

Comments
 (0)