
Commit 9811e22

Fix regression in models with 1001-class pretrained weights. Improve batchnorm arg and BatchNormAct layer handling in several models.
1 parent: aaa715b

15 files changed: 157 additions, 147 deletions


tests/test_models.py

Lines changed: 9 additions & 3 deletions
@@ -83,7 +83,6 @@ def test_model_default_cfgs(model_name, batch_size):
     cfg = model.default_cfg
 
     classifier = cfg['classifier']
-    first_conv = cfg['first_conv']
     pool_size = cfg['pool_size']
     input_size = model.default_cfg['input_size']
 
@@ -111,9 +110,16 @@ def test_model_default_cfgs(model_name, batch_size):
         # FIXME mobilenetv3 forward_features vs removed pooling differ
         assert outputs.shape[-1] == pool_size[-1] and outputs.shape[-2] == pool_size[-2]
 
-    # check classifier and first convolution names match those in default_cfg
+    # check classifier name matches default_cfg
     assert classifier + ".weight" in state_dict.keys(), f'{classifier} not in model params'
-    assert first_conv + ".weight" in state_dict.keys(), f'{first_conv} not in model params'
+
+    # check first conv(s) names match default_cfg
+    first_conv = cfg['first_conv']
+    if isinstance(first_conv, str):
+        first_conv = (first_conv,)
+    assert isinstance(first_conv, (tuple, list))
+    for fc in first_conv:
+        assert fc + ".weight" in state_dict.keys(), f'{fc} not in model params'
 
 
 if 'GITHUB_ACTIONS' not in os.environ:
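
Note (not part of the diff): the normalization in the test above exists because default_cfg['first_conv'] may now be either a single layer name or a tuple of names. A minimal sketch with hypothetical cfg values:

# Hypothetical default_cfg fragments; 'first_conv' may be a str or a tuple of layer names.
cfg_single = {'classifier': 'fc', 'first_conv': 'conv1'}
cfg_multi = {'classifier': 'fc', 'first_conv': ('stem.conv1', 'stem.conv2')}

def first_convs(cfg):
    fc = cfg['first_conv']
    return (fc,) if isinstance(fc, str) else tuple(fc)

assert first_convs(cfg_single) == ('conv1',)
assert first_convs(cfg_multi) == ('stem.conv1', 'stem.conv2')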

timm/models/dpn.py

Lines changed: 5 additions & 3 deletions
@@ -7,6 +7,7 @@
 Hacked together by / Copyright 2020 Ross Wightman
 """
 from collections import OrderedDict
+from functools import partial
 from typing import Tuple
 
 import torch
@@ -173,12 +174,14 @@ def __init__(self, small=False, num_init_features=64, k_r=96, groups=32,
         self.drop_rate = drop_rate
         self.b = b
         assert output_stride == 32  # FIXME look into dilation support
+        norm_layer = partial(BatchNormAct2d, eps=.001)
+        fc_norm_layer = partial(BatchNormAct2d, eps=.001, act_layer=fc_act, inplace=False)
         bw_factor = 1 if small else 4
         blocks = OrderedDict()
 
         # conv1
         blocks['conv1_1'] = ConvBnAct(
-            in_chans, num_init_features, kernel_size=3 if small else 7, stride=2, norm_kwargs=dict(eps=.001))
+            in_chans, num_init_features, kernel_size=3 if small else 7, stride=2, norm_layer=norm_layer)
         blocks['conv1_pool'] = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
         self.feature_info = [dict(num_chs=num_init_features, reduction=2, module='features.conv1_1')]
 
@@ -226,8 +229,7 @@ def __init__(self, small=False, num_init_features=64, k_r=96, groups=32,
         in_chs += inc
         self.feature_info += [dict(num_chs=in_chs, reduction=32, module=f'features.conv5_{k_sec[3]}')]
 
-        def _fc_norm(f, eps): return BatchNormAct2d(f, eps=eps, act_layer=fc_act, inplace=False)
-        blocks['conv5_bn_ac'] = CatBnAct(in_chs, norm_layer=_fc_norm)
+        blocks['conv5_bn_ac'] = CatBnAct(in_chs, norm_layer=fc_norm_layer)
 
         self.num_features = in_chs
         self.features = nn.Sequential(blocks)
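
Note (not part of the diff): the DPN change replaces norm_kwargs dicts and the local _fc_norm closure with norm layers whose extra arguments are bound once via functools.partial. A minimal sketch of the pattern, using nn.BatchNorm2d as a stand-in for timm's BatchNormAct2d:

from functools import partial

import torch
import torch.nn as nn

# Bind non-default args once; downstream code only ever calls norm_layer(num_features).
norm_layer = partial(nn.BatchNorm2d, eps=.001)

bn = norm_layer(64)           # equivalent to nn.BatchNorm2d(64, eps=.001)
x = torch.randn(2, 64, 8, 8)
print(bn(x).shape, bn.eps)    # torch.Size([2, 64, 8, 8]) 0.001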

timm/models/gluon_xception.py

Lines changed: 25 additions & 41 deletions
@@ -42,10 +42,8 @@
 
 
 class SeparableConv2d(nn.Module):
-    def __init__(self, inplanes, planes, kernel_size=3, stride=1,
-                 dilation=1, bias=False, norm_layer=None, norm_kwargs=None):
+    def __init__(self, inplanes, planes, kernel_size=3, stride=1, dilation=1, bias=False, norm_layer=None):
         super(SeparableConv2d, self).__init__()
-        norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
         self.kernel_size = kernel_size
         self.dilation = dilation
 
@@ -54,7 +52,7 @@ def __init__(self, inplanes, planes, kernel_size=3, stride=1,
         self.conv_dw = nn.Conv2d(
             inplanes, inplanes, kernel_size, stride=stride,
             padding=padding, dilation=dilation, groups=inplanes, bias=bias)
-        self.bn = norm_layer(num_features=inplanes, **norm_kwargs)
+        self.bn = norm_layer(num_features=inplanes)
         # pointwise convolution
         self.conv_pw = nn.Conv2d(inplanes, planes, kernel_size=1, bias=bias)
 
@@ -66,10 +64,8 @@ def forward(self, x):
 
 
 class Block(nn.Module):
-    def __init__(self, inplanes, planes, stride=1, dilation=1, start_with_relu=True,
-                 norm_layer=None, norm_kwargs=None, ):
+    def __init__(self, inplanes, planes, stride=1, dilation=1, start_with_relu=True, norm_layer=None):
         super(Block, self).__init__()
-        norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
         if isinstance(planes, (list, tuple)):
             assert len(planes) == 3
         else:
@@ -80,17 +76,16 @@ def __init__(self, inplanes, planes, stride=1, dilation=1, start_with_relu=True,
             self.skip = nn.Sequential()
             self.skip.add_module('conv1', nn.Conv2d(
                 inplanes, outplanes, 1, stride=stride, bias=False)),
-            self.skip.add_module('bn1', norm_layer(num_features=outplanes, **norm_kwargs))
+            self.skip.add_module('bn1', norm_layer(num_features=outplanes))
         else:
             self.skip = None
 
         rep = OrderedDict()
         for i in range(3):
             rep['act%d' % (i + 1)] = nn.ReLU(inplace=True)
             rep['conv%d' % (i + 1)] = SeparableConv2d(
-                inplanes, planes[i], 3, stride=stride if i == 2 else 1, dilation=dilation,
-                norm_layer=norm_layer, norm_kwargs=norm_kwargs)
-            rep['bn%d' % (i + 1)] = norm_layer(planes[i], **norm_kwargs)
+                inplanes, planes[i], 3, stride=stride if i == 2 else 1, dilation=dilation, norm_layer=norm_layer)
+            rep['bn%d' % (i + 1)] = norm_layer(planes[i])
             inplanes = planes[i]
 
         if not start_with_relu:
@@ -115,74 +110,63 @@ class Xception65(nn.Module):
     """
 
     def __init__(self, num_classes=1000, in_chans=3, output_stride=32, norm_layer=nn.BatchNorm2d,
-                 norm_kwargs=None, drop_rate=0., global_pool='avg'):
+                 drop_rate=0., global_pool='avg'):
         super(Xception65, self).__init__()
         self.num_classes = num_classes
         self.drop_rate = drop_rate
-        norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
         if output_stride == 32:
             entry_block3_stride = 2
             exit_block20_stride = 2
-            middle_block_dilation = 1
-            exit_block_dilations = (1, 1)
+            middle_dilation = 1
+            exit_dilation = (1, 1)
         elif output_stride == 16:
             entry_block3_stride = 2
             exit_block20_stride = 1
-            middle_block_dilation = 1
-            exit_block_dilations = (1, 2)
+            middle_dilation = 1
+            exit_dilation = (1, 2)
         elif output_stride == 8:
             entry_block3_stride = 1
             exit_block20_stride = 1
-            middle_block_dilation = 2
-            exit_block_dilations = (2, 4)
+            middle_dilation = 2
+            exit_dilation = (2, 4)
         else:
             raise NotImplementedError
 
         # Entry flow
         self.conv1 = nn.Conv2d(in_chans, 32, kernel_size=3, stride=2, padding=1, bias=False)
-        self.bn1 = norm_layer(num_features=32, **norm_kwargs)
+        self.bn1 = norm_layer(num_features=32)
         self.act1 = nn.ReLU(inplace=True)
 
         self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1, bias=False)
         self.bn2 = norm_layer(num_features=64)
         self.act2 = nn.ReLU(inplace=True)
 
-        self.block1 = Block(
-            64, 128, stride=2, start_with_relu=False, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
+        self.block1 = Block(64, 128, stride=2, start_with_relu=False, norm_layer=norm_layer)
         self.block1_act = nn.ReLU(inplace=True)
-        self.block2 = Block(
-            128, 256, stride=2, start_with_relu=False, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
-        self.block3 = Block(
-            256, 728, stride=entry_block3_stride, norm_layer=norm_layer, norm_kwargs=norm_kwargs)
+        self.block2 = Block(128, 256, stride=2, start_with_relu=False, norm_layer=norm_layer)
+        self.block3 = Block(256, 728, stride=entry_block3_stride, norm_layer=norm_layer)
 
         # Middle flow
         self.mid = nn.Sequential(OrderedDict([('block%d' % i, Block(
-            728, 728, stride=1, dilation=middle_block_dilation,
-            norm_layer=norm_layer, norm_kwargs=norm_kwargs)) for i in range(4, 20)]))
+            728, 728, stride=1, dilation=middle_dilation, norm_layer=norm_layer)) for i in range(4, 20)]))
 
         # Exit flow
         self.block20 = Block(
-            728, (728, 1024, 1024), stride=exit_block20_stride, dilation=exit_block_dilations[0],
-            norm_layer=norm_layer, norm_kwargs=norm_kwargs)
+            728, (728, 1024, 1024), stride=exit_block20_stride, dilation=exit_dilation[0], norm_layer=norm_layer)
         self.block20_act = nn.ReLU(inplace=True)
 
-        self.conv3 = SeparableConv2d(
-            1024, 1536, 3, stride=1, dilation=exit_block_dilations[1],
-            norm_layer=norm_layer, norm_kwargs=norm_kwargs)
-        self.bn3 = norm_layer(num_features=1536, **norm_kwargs)
+        self.conv3 = SeparableConv2d(1024, 1536, 3, stride=1, dilation=exit_dilation[1], norm_layer=norm_layer)
+        self.bn3 = norm_layer(num_features=1536)
         self.act3 = nn.ReLU(inplace=True)
 
-        self.conv4 = SeparableConv2d(
-            1536, 1536, 3, stride=1, dilation=exit_block_dilations[1],
-            norm_layer=norm_layer, norm_kwargs=norm_kwargs)
-        self.bn4 = norm_layer(num_features=1536, **norm_kwargs)
+        self.conv4 = SeparableConv2d(1536, 1536, 3, stride=1, dilation=exit_dilation[1], norm_layer=norm_layer)
+        self.bn4 = norm_layer(num_features=1536)
         self.act4 = nn.ReLU(inplace=True)
 
         self.num_features = 2048
         self.conv5 = SeparableConv2d(
-            1536, self.num_features, 3, stride=1, dilation=exit_block_dilations[1],
-            norm_layer=norm_layer, norm_kwargs=norm_kwargs)
-        self.bn5 = norm_layer(num_features=self.num_features, **norm_kwargs)
+            1536, self.num_features, 3, stride=1, dilation=exit_dilation[1], norm_layer=norm_layer)
+        self.bn5 = norm_layer(num_features=self.num_features)
         self.act5 = nn.ReLU(inplace=True)
         self.feature_info = [
             dict(num_chs=64, reduction=2, module='act2'),
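
Note (not part of the diff): with norm_kwargs removed, any customization of the Xception65 norm layers is expressed as a ready-to-call norm_layer. A hedged usage sketch, assuming a timm install that includes this commit and the import path timm.models.gluon_xception:

from functools import partial

import torch.nn as nn
from timm.models.gluon_xception import Xception65

# norm_layer must accept num_features as its only required argument.
model = Xception65(norm_layer=partial(nn.BatchNorm2d, eps=1e-3))
print(model.bn1.eps)  # 0.001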

timm/models/helpers.py

Lines changed: 49 additions & 45 deletions
@@ -148,6 +148,31 @@ def load_custom_pretrained(model, cfg=None, load_fn=None, progress=False, check_
         _logger.warning("Valid function to load pretrained weights is not available, using random initialization.")
 
 
+def adapt_input_conv(in_chans, conv_weight):
+    conv_type = conv_weight.dtype
+    conv_weight = conv_weight.float()  # Some weights are in torch.half, ensure it's float for sum on CPU
+    O, I, J, K = conv_weight.shape
+    if in_chans == 1:
+        if I > 3:
+            assert conv_weight.shape[1] % 3 == 0
+            # For models with space2depth stems
+            conv_weight = conv_weight.reshape(O, I // 3, 3, J, K)
+            conv_weight = conv_weight.sum(dim=2, keepdim=False)
+        else:
+            conv_weight = conv_weight.sum(dim=1, keepdim=True)
+    elif in_chans != 3:
+        if I != 3:
+            raise NotImplementedError('Weight format not supported by conversion.')
+        else:
+            # NOTE this strategy should be better than random init, but there could be other combinations of
+            # the original RGB input layer weights that'd work better for specific cases.
+            repeat = int(math.ceil(in_chans / 3))
+            conv_weight = conv_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
+            conv_weight *= (3 / float(in_chans))
+    conv_weight = conv_weight.to(conv_type)
+    return conv_weight
+
+
 def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=None, strict=True, progress=False):
     if cfg is None:
         cfg = getattr(model, 'default_cfg')
@@ -159,56 +184,35 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=Non
     if filter_fn is not None:
         state_dict = filter_fn(state_dict)
 
-    if in_chans == 1:
-        conv1_name = cfg['first_conv']
-        _logger.info('Converting first conv (%s) pretrained weights from 3 to 1 channel' % conv1_name)
-        conv1_weight = state_dict[conv1_name + '.weight']
-        # Some weights are in torch.half, ensure it's float for sum on CPU
-        conv1_type = conv1_weight.dtype
-        conv1_weight = conv1_weight.float()
-        O, I, J, K = conv1_weight.shape
-        if I > 3:
-            assert conv1_weight.shape[1] % 3 == 0
-            # For models with space2depth stems
-            conv1_weight = conv1_weight.reshape(O, I // 3, 3, J, K)
-            conv1_weight = conv1_weight.sum(dim=2, keepdim=False)
-        else:
-            conv1_weight = conv1_weight.sum(dim=1, keepdim=True)
-        conv1_weight = conv1_weight.to(conv1_type)
-        state_dict[conv1_name + '.weight'] = conv1_weight
-    elif in_chans != 3:
-        conv1_name = cfg['first_conv']
-        conv1_weight = state_dict[conv1_name + '.weight']
-        conv1_type = conv1_weight.dtype
-        conv1_weight = conv1_weight.float()
-        O, I, J, K = conv1_weight.shape
-        if I != 3:
-            _logger.warning('Deleting first conv (%s) from pretrained weights.' % conv1_name)
-            del state_dict[conv1_name + '.weight']
-            strict = False
-        else:
-            # NOTE this strategy should be better than random init, but there could be other combinations of
-            # the original RGB input layer weights that'd work better for specific cases.
-            _logger.info('Repeating first conv (%s) weights in channel dim.' % conv1_name)
-            repeat = int(math.ceil(in_chans / 3))
-            conv1_weight = conv1_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
-            conv1_weight *= (3 / float(in_chans))
-            conv1_weight = conv1_weight.to(conv1_type)
-            state_dict[conv1_name + '.weight'] = conv1_weight
+    input_convs = cfg.get('first_conv', None)
+    if input_convs is not None:
+        if isinstance(input_convs, str):
+            input_convs = (input_convs,)
+        for input_conv_name in input_convs:
+            weight_name = input_conv_name + '.weight'
+            try:
+                state_dict[weight_name] = adapt_input_conv(in_chans, state_dict[weight_name])
+                _logger.info(
+                    f'Converted input conv {input_conv_name} pretrained weights from 3 to {in_chans} channel(s)')
+            except NotImplementedError as e:
+                del state_dict[weight_name]
+                strict = False
+                _logger.warning(
+                    f'Unable to convert pretrained {input_conv_name} weights, using random init for this layer.')
 
     classifier_name = cfg['classifier']
-    if num_classes == 1000 and cfg['num_classes'] == 1001:
-        # FIXME this special case is problematic as number of pretrained weight sources increases
-        # special case for imagenet trained models with extra background class in pretrained weights
-        classifier_weight = state_dict[classifier_name + '.weight']
-        state_dict[classifier_name + '.weight'] = classifier_weight[1:]
-        classifier_bias = state_dict[classifier_name + '.bias']
-        state_dict[classifier_name + '.bias'] = classifier_bias[1:]
-    elif num_classes != cfg['num_classes']:
-        # completely discard fully connected for all other differences between pretrained and created model
+    label_offset = cfg.get('label_offset', 0)
+    if num_classes != cfg['num_classes']:
+        # completely discard fully connected if model num_classes doesn't match pretrained weights
        del state_dict[classifier_name + '.weight']
        del state_dict[classifier_name + '.bias']
        strict = False
+    elif label_offset > 0:
+        # special case for pretrained weights with an extra background class in pretrained weights
+        classifier_weight = state_dict[classifier_name + '.weight']
+        state_dict[classifier_name + '.weight'] = classifier_weight[label_offset:]
+        classifier_bias = state_dict[classifier_name + '.bias']
+        state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:]
 
     model.load_state_dict(state_dict, strict=strict)
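
Note (not part of the diff): a small sketch of the new helper in isolation, assuming the import path timm.models.helpers. adapt_input_conv sums the RGB kernels for 1-channel inputs and tiles/rescales them for other channel counts, returning weights in the original dtype:

import torch
from timm.models.helpers import adapt_input_conv  # added in this commit

w_rgb = torch.randn(64, 3, 7, 7)      # stem conv weight from a 3-channel pretrained model

w_gray = adapt_input_conv(1, w_rgb)   # RGB kernels summed -> (64, 1, 7, 7)
w_4ch = adapt_input_conv(4, w_rgb)    # kernels tiled, sliced to 4, scaled by 3/4 -> (64, 4, 7, 7)
print(w_gray.shape, w_4ch.shape)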

timm/models/inception_resnet_v2.py

Lines changed: 5 additions & 3 deletions
@@ -17,18 +17,20 @@
     # ported from http://download.tensorflow.org/models/inception_resnet_v2_2016_08_30.tar.gz
     'inception_resnet_v2': {
         'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/inception_resnet_v2-940b1cd6.pth',
-        'num_classes': 1001, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
+        'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
         'crop_pct': 0.8975, 'interpolation': 'bicubic',
         'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
         'first_conv': 'conv2d_1a.conv', 'classifier': 'classif',
+        'label_offset': 1,  # 1001 classes in pretrained weights
     },
     # ported from http://download.tensorflow.org/models/ens_adv_inception_resnet_v2_2017_08_18.tar.gz
     'ens_adv_inception_resnet_v2': {
         'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/ens_adv_inception_resnet_v2-2592a550.pth',
-        'num_classes': 1001, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
+        'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
         'crop_pct': 0.8975, 'interpolation': 'bicubic',
         'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
         'first_conv': 'conv2d_1a.conv', 'classifier': 'classif',
+        'label_offset': 1,  # 1001 classes in pretrained weights
     }
 }
 
@@ -222,7 +224,7 @@ def forward(self, x):
 
 
 class InceptionResnetV2(nn.Module):
-    def __init__(self, num_classes=1001, in_chans=3, drop_rate=0., output_stride=32, global_pool='avg'):
+    def __init__(self, num_classes=1000, in_chans=3, drop_rate=0., output_stride=32, global_pool='avg'):
         super(InceptionResnetV2, self).__init__()
         self.drop_rate = drop_rate
         self.num_classes = num_classes
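
Note (not part of the diff): what label_offset=1 means for load_pretrained, with hypothetical tensor values. The TF-ported checkpoints carry 1001 classifier rows where index 0 is a background class; the loader now drops the first label_offset rows so the weights fit the 1000-class model:

import torch

# Hypothetical classifier tensors from a 1001-class TF-ported checkpoint.
w = torch.randn(1001, 1536)
b = torch.randn(1001)

label_offset = 1                 # row 0 is the extra background class
w, b = w[label_offset:], b[label_offset:]
print(w.shape, b.shape)          # torch.Size([1000, 1536]) torch.Size([1000])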

timm/models/inception_v4.py

Lines changed: 3 additions & 2 deletions
@@ -16,10 +16,11 @@
 default_cfgs = {
     'inception_v4': {
         'url': 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-cadene/inceptionv4-8e4777a0.pth',
-        'num_classes': 1001, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
+        'num_classes': 1000, 'input_size': (3, 299, 299), 'pool_size': (8, 8),
         'crop_pct': 0.875, 'interpolation': 'bicubic',
         'mean': IMAGENET_INCEPTION_MEAN, 'std': IMAGENET_INCEPTION_STD,
         'first_conv': 'features.0.conv', 'classifier': 'last_linear',
+        'label_offset': 1,  # 1001 classes in pretrained weights
     }
 }
 
@@ -241,7 +242,7 @@ def forward(self, x):
 
 
 class InceptionV4(nn.Module):
-    def __init__(self, num_classes=1001, in_chans=3, output_stride=32, drop_rate=0., global_pool='avg'):
+    def __init__(self, num_classes=1000, in_chans=3, output_stride=32, drop_rate=0., global_pool='avg'):
         super(InceptionV4, self).__init__()
         assert output_stride == 32
         self.drop_rate = drop_rate
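
Note (not part of the diff): a hedged end-to-end check of the regression fix (downloads pretrained weights; assumes a timm build containing this commit). The model is created with 1000 classes by default and the 1001-class checkpoint is trimmed via label_offset while loading:

import timm

model = timm.create_model('inception_v4', pretrained=True)
print(model.default_cfg['num_classes'])   # 1000
print(model.last_linear.out_features)     # 1000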

timm/models/layers/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@
 from .create_act import create_act_layer, get_act_layer, get_act_fn
 from .create_attn import get_attn, create_attn
 from .create_conv2d import create_conv2d
-from .create_norm_act import create_norm_act, get_norm_act_layer
+from .create_norm_act import get_norm_act_layer, create_norm_act, convert_norm_act
 from .drop import DropBlock2d, DropPath, drop_block_2d, drop_path
 from .eca import EcaModule, CecaModule
 from .evo_norm import EvoNormBatch2d, EvoNormSample2d
