
Commit 3bef524

Finish with HRNet, weights and models updated. Improve consistency in model classifier/global pool treatment.

1 parent 3ceeedc

19 files changed: +729 −769 lines

clean_checkpoint.py
Lines changed: 13 additions & 5 deletions

@@ -8,12 +8,15 @@
 parser = argparse.ArgumentParser(description='PyTorch ImageNet Validation')
 parser.add_argument('--checkpoint', default='', type=str, metavar='PATH',
                     help='path to latest checkpoint (default: none)')
-parser.add_argument('--output', default='./cleaned.pth', type=str, metavar='PATH',
+parser.add_argument('--output', default='', type=str, metavar='PATH',
                     help='output path')
 parser.add_argument('--use-ema', dest='use_ema', action='store_true',
                     help='use ema version of weights if present')
 
 
+_TEMP_NAME = './_checkpoint.pth'
+
+
 def main():
     args = parser.parse_args()
 
@@ -40,13 +43,18 @@ def main():
             new_state_dict[name] = v
         print("=> Loaded state_dict from '{}'".format(args.checkpoint))
 
-        torch.save(new_state_dict, args.output)
-        with open(args.output, 'rb') as f:
+        torch.save(new_state_dict, _TEMP_NAME)
+        with open(_TEMP_NAME, 'rb') as f:
             sha_hash = hashlib.sha256(f.read()).hexdigest()
 
-        checkpoint_base = os.path.splitext(args.checkpoint)[0]
+        if args.output:
+            checkpoint_root, checkpoint_base = os.path.split(args.output)
+            checkpoint_base = os.path.splitext(checkpoint_base)[0]
+        else:
+            checkpoint_root = ''
+            checkpoint_base = os.path.splitext(args.checkpoint)[0]
         final_filename = '-'.join([checkpoint_base, sha_hash[:8]]) + '.pth'
-        shutil.move(args.output, final_filename)
+        shutil.move(_TEMP_NAME, os.path.join(checkpoint_root, final_filename))
         print("=> Saved state_dict to '{}, SHA256: {}'".format(final_filename, sha_hash))
     else:
         print("Error: Checkpoint ({}) doesn't exist".format(args.checkpoint))

results/results-inv2-matched-frequency.csv
Lines changed: 152 additions & 94 deletions
(large diff, not rendered)

sotabench.py
Lines changed: 11 additions & 0 deletions

@@ -294,6 +294,17 @@ def _entry(model_name, paper_model_name, paper_arxiv_id, batch_size=BATCH_SIZE,
     _entry('res2next50', 'Res2NeXt-50', '1904.01169'),
     _entry('dla60_res2net', 'Res2Net-DLA-60', '1904.01169'),
     _entry('dla60_res2next', 'Res2NeXt-DLA-60', '1904.01169'),
+
+    ## HRNet official impl weights
+    _entry('hrnet_w18_small', 'HRNet-W18-C-Small-V1', '1908.07919'),
+    _entry('hrnet_w18_small_v2', 'HRNet-W18-C-Small-V2', '1908.07919'),
+    _entry('hrnet_w18', 'HRNet-W18-C', '1908.07919'),
+    _entry('hrnet_w30', 'HRNet-W30-C', '1908.07919'),
+    _entry('hrnet_w32', 'HRNet-W32-C', '1908.07919'),
+    _entry('hrnet_w40', 'HRNet-W40-C', '1908.07919'),
+    _entry('hrnet_w44', 'HRNet-W44-C', '1908.07919'),
+    _entry('hrnet_w48', 'HRNet-W48-C', '1908.07919'),
+    _entry('hrnet_w64', 'HRNet-W64-C', '1908.07919'),
 ]
 
 for m in model_list:
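A hedged usage sketch (model names as registered in the _entry list above; pretrained weight availability assumed per this commit):

import timm

# The HRNet classification models benchmarked above load like any other timm model.
model = timm.create_model('hrnet_w18', pretrained=True)
model.eval()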

timm/models/densenet.py
Lines changed: 14 additions & 15 deletions

@@ -10,7 +10,7 @@
 
 from .registry import register_model
 from .helpers import load_pretrained
-from .adaptive_avgmax_pool import select_adaptive_pool2d
+from .adaptive_avgmax_pool import SelectAdaptivePool2d
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 import re
 
@@ -88,8 +88,8 @@ class DenseNet(nn.Module):
     def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
                  num_init_features=64, bn_size=4, drop_rate=0,
                  num_classes=1000, in_chans=3, global_pool='avg'):
-        self.global_pool = global_pool
         self.num_classes = num_classes
+        self.drop_rate = drop_rate
         super(DenseNet, self).__init__()
 
         # First convolution
@@ -117,32 +117,31 @@ def __init__(self, growth_rate=32, block_config=(6, 12, 24, 16),
         self.features.add_module('norm5', nn.BatchNorm2d(num_features))
 
         # Linear layer
-        self.classifier = nn.Linear(num_features, num_classes)
-
         self.num_features = num_features
+        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
+        self.classifier = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
 
     def get_classifier(self):
         return self.classifier
 
     def reset_classifier(self, num_classes, global_pool='avg'):
-        self.global_pool = global_pool
         self.num_classes = num_classes
-        del self.classifier
-        if num_classes:
-            self.classifier = nn.Linear(self.num_features, num_classes)
-        else:
-            self.classifier = None
+        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
+        self.classifier = nn.Linear(
+            self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None
 
-    def forward_features(self, x, pool=True):
+    def forward_features(self, x):
         x = self.features(x)
         x = F.relu(x, inplace=True)
-        if pool:
-            x = select_adaptive_pool2d(x, self.global_pool)
-            x = x.view(x.size(0), -1)
         return x
 
     def forward(self, x):
-        return self.classifier(self.forward_features(x, pool=True))
+        x = self.forward_features(x)
+        x = self.global_pool(x).flatten(1)
+        if self.drop_rate > 0.:
+            x = F.dropout(x, p=self.drop_rate, training=self.training)
+        x = self.classifier(x)
+        return x
 
 
 def _filter_pretrained(state_dict):
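The refactor above standardizes one head pattern across models: pool with a SelectAdaptivePool2d module, flatten, optional dropout, then a classifier sized by feat_mult(). A minimal sketch, assuming the timm.models.adaptive_avgmax_pool module path from this commit:

import torch
import torch.nn as nn
from timm.models.adaptive_avgmax_pool import SelectAdaptivePool2d

pool = SelectAdaptivePool2d(pool_type='avg')  # concatenating pool types report feat_mult() > 1
classifier = nn.Linear(1024 * pool.feat_mult(), 1000)

features = torch.randn(2, 1024, 7, 7)           # example NCHW feature map
logits = classifier(pool(features).flatten(1))  # -> shape (2, 1000)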

timm/models/dla.py
Lines changed: 6 additions & 10 deletions

@@ -276,8 +276,7 @@ def __init__(self, levels, channels, num_classes=1000, in_chans=3, cardinality=1
 
         self.num_features = channels[-1]
         self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
-        self.fc = nn.Conv2d(self.num_features * self.global_pool.feat_mult(), num_classes,
-                            kernel_size=1, stride=1, padding=0, bias=True)
+        self.fc = nn.Conv2d(self.num_features * self.global_pool.feat_mult(), num_classes, 1, bias=True)
 
         for m in self.modules():
             if isinstance(m, nn.Conv2d):
@@ -302,33 +301,30 @@ def get_classifier(self):
         return self.fc
 
     def reset_classifier(self, num_classes, global_pool='avg'):
-        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
         self.num_classes = num_classes
-        del self.fc
+        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
         if num_classes:
-            self.fc = nn.Linear(self.num_features * self.global_pool.feat_mult(), num_classes)
+            self.fc = nn.Conv2d(self.num_features * self.global_pool.feat_mult(), num_classes, 1, bias=True)
         else:
             self.fc = None
 
-    def forward_features(self, x, pool=True):
+    def forward_features(self, x):
         x = self.base_layer(x)
         x = self.level0(x)
         x = self.level1(x)
         x = self.level2(x)
         x = self.level3(x)
         x = self.level4(x)
         x = self.level5(x)
-        if pool:
-            x = self.global_pool(x)
         return x
 
     def forward(self, x):
         x = self.forward_features(x)
+        x = self.global_pool(x)
         if self.drop_rate > 0.:
            x = F.dropout(x, p=self.drop_rate, training=self.training)
         x = self.fc(x)
-        x = x.flatten(1)
-        return x
+        return x.flatten(1)
 
 
 @register_model

timm/models/dpn.py
Lines changed: 11 additions & 12 deletions

@@ -16,7 +16,7 @@
 
 from .registry import register_model
 from .helpers import load_pretrained
-from .adaptive_avgmax_pool import select_adaptive_pool2d
+from .adaptive_avgmax_pool import SelectAdaptivePool2d
 from timm.data import IMAGENET_DPN_MEAN, IMAGENET_DPN_STD
 
 
@@ -160,7 +160,6 @@ def __init__(self, small=False, num_init_features=64, k_r=96, groups=32,
         super(DPN, self).__init__()
         self.num_classes = num_classes
         self.drop_rate = drop_rate
-        self.global_pool = global_pool
         self.b = b
         bw_factor = 1 if small else 4
 
@@ -218,32 +217,32 @@ def __init__(self, small=False, num_init_features=64, k_r=96, groups=32,
         self.features = nn.Sequential(blocks)
 
         # Using 1x1 conv for the FC layer to allow the extra pooling scheme
-        self.classifier = nn.Conv2d(in_chs, num_classes, kernel_size=1, bias=True)
+        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
+        self.classifier = nn.Conv2d(
+            self.num_features * self.global_pool.feat_mult(), num_classes, kernel_size=1, bias=True)
 
     def get_classifier(self):
         return self.classifier
 
     def reset_classifier(self, num_classes, global_pool='avg'):
         self.num_classes = num_classes
-        self.global_pool = global_pool
-        del self.classifier
+        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
         if num_classes:
-            self.classifier = nn.Conv2d(self.num_features, num_classes, kernel_size=1, bias=True)
+            self.classifier = nn.Conv2d(
+                self.num_features * self.global_pool.feat_mult(), num_classes, kernel_size=1, bias=True)
         else:
             self.classifier = None
 
-    def forward_features(self, x, pool=True):
-        x = self.features(x)
-        if pool:
-            x = select_adaptive_pool2d(x, pool_type=self.global_pool)
-        return x
+    def forward_features(self, x):
        return self.features(x)
 
     def forward(self, x):
         x = self.forward_features(x)
+        x = self.global_pool(x)
         if self.drop_rate > 0.:
             x = F.dropout(x, p=self.drop_rate, training=self.training)
         out = self.classifier(x)
-        return out.view(out.size(0), -1)
+        return out.flatten(1)
 
 
 @register_model
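The 1x1 Conv2d classifier kept here (per the comment about the extra pooling scheme) is worth a note: on a globally pooled (N, C, 1, 1) map, a 1x1 convolution is mathematically identical to a Linear layer, while still applying convolutionally to larger feature maps for test-time pooling. A quick illustration (channel counts are illustrative only):

import torch
import torch.nn as nn

pooled = torch.randn(2, 2688, 1, 1)                   # (N, C, 1, 1) after global pooling
fc = nn.Conv2d(2688, 1000, kernel_size=1, bias=True)  # acts like nn.Linear(2688, 1000) here
logits = fc(pooled).flatten(1)                        # -> shape (2, 1000)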

timm/models/efficientnet.py
Lines changed: 7 additions & 21 deletions

@@ -211,8 +211,7 @@ class EfficientNet(nn.Module):
     def __init__(self, block_args, num_classes=1000, num_features=1280, in_chans=3, stem_size=32,
                  channel_multiplier=1.0, channel_divisor=8, channel_min=None,
                  pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0.,
-                 se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None,
-                 global_pool='avg', weight_init='goog'):
+                 se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, global_pool='avg'):
         super(EfficientNet, self).__init__()
         norm_kwargs = norm_kwargs or {}
 
@@ -245,11 +244,7 @@ def __init__(self, block_args, num_classes=1000, num_features=1280, in_chans=3,
         # Classifier
         self.classifier = nn.Linear(self.num_features * self.global_pool.feat_mult(), self.num_classes)
 
-        for m in self.modules():
-            if weight_init == 'goog':
-                efficientnet_init_goog(m)
-            else:
-                efficientnet_init_default(m)
+        efficientnet_init_weights(self)
 
     def as_sequential(self):
         layers = [self.conv_stem, self.bn1, self.act1]
@@ -262,14 +257,10 @@ def get_classifier(self):
         return self.classifier
 
     def reset_classifier(self, num_classes, global_pool='avg'):
-        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
         self.num_classes = num_classes
-        del self.classifier
-        if num_classes:
-            self.classifier = nn.Linear(
-                self.num_features * self.global_pool.feat_mult(), num_classes)
-        else:
-            self.classifier = None
+        self.global_pool = SelectAdaptivePool2d(pool_type=global_pool)
+        self.classifier = nn.Linear(
+            self.num_features * self.global_pool.feat_mult(), num_classes) if num_classes else None
 
     def forward_features(self, x):
         x = self.conv_stem(x)
@@ -300,7 +291,7 @@ class EfficientNetFeatures(nn.Module):
     def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='pre_pwl',
                  in_chans=3, stem_size=32, channel_multiplier=1.0, channel_divisor=8, channel_min=None,
                  output_stride=32, pad_type='', act_layer=nn.ReLU, drop_rate=0., drop_connect_rate=0.,
-                 se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None, weight_init='goog'):
+                 se_kwargs=None, norm_layer=nn.BatchNorm2d, norm_kwargs=None):
         super(EfficientNetFeatures, self).__init__()
         norm_kwargs = norm_kwargs or {}
 
@@ -326,12 +317,7 @@ def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='pre_pwl',
         self.feature_info = builder.features  # builder provides info about feature channels for each block
         self._in_chs = builder.in_chs
 
-        for m in self.modules():
-            if weight_init == 'goog':
-                efficientnet_init_goog(m)
-            else:
-                efficientnet_init_default(m)
-
+        efficientnet_init_weights(self)
         if _DEBUG:
             for k, v in self.feature_info.items():
                 print('Feature idx: {}: Name: {}, Channels: {}'.format(k, v['name'], v['num_chs']))
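A hedged example of the now-consistent classifier reset (model name assumed to be registered at this commit):

import timm

model = timm.create_model('efficientnet_b0')
model.reset_classifier(10)  # rebuilds the global pool and a 10-class Linear head
model.reset_classifier(0)   # num_classes=0 -> classifier becomes None (pure feature extractor)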

timm/models/efficientnet_builder.py
Lines changed: 14 additions & 4 deletions

@@ -358,9 +358,13 @@ def __call__(self, in_chs, model_block_args):
         return stages
 
 
-def efficientnet_init_goog(m, n=''):
-    # weight init as per Tensorflow Official impl
-    # https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py
+def _init_weight_goog(m, n=''):
+    """ Weight initialization as per Tensorflow official implementations.
+
+    Handles layers in EfficientNet, EfficientNet-CondConv, MixNet, MnasNet, MobileNetV3, etc:
+    * https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py
+    * https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
+    """
     if isinstance(m, CondConv2d):
         fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
         init_weight_fn = get_condconv_initializer(
@@ -386,7 +390,8 @@ def efficientnet_init_goog(m, n=''):
             m.bias.data.zero_()
 
 
-def efficientnet_init_default(m, n=''):
+def _init_weight_default(m, n=''):
+    """ Basic ResNet (Kaiming) style weight init """
     if isinstance(m, CondConv2d):
         init_fn = get_condconv_initializer(partial(
             nn.init.kaiming_normal_, mode='fan_out', nonlinearity='relu'), m.num_experts, m.weight_shape)
@@ -400,3 +405,8 @@ def efficientnet_init_default(m, n=''):
         nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='linear')
 
 
+def efficientnet_init_weights(model: nn.Module, init_fn=None):
+    init_fn = init_fn or _init_weight_goog
+    for n, m in model.named_modules():
+        init_fn(m, n)
+
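Usage of the new consolidated entry point, sketched under the assumption that _init_weight_default is importable alongside it and that 'efficientnet_b0' is a registered model name:

import timm
from timm.models.efficientnet_builder import efficientnet_init_weights, _init_weight_default

model = timm.create_model('efficientnet_b0', pretrained=False)
efficientnet_init_weights(model)                                # default: TF/Google-style (_init_weight_goog)
efficientnet_init_weights(model, init_fn=_init_weight_default)  # Kaiming-style alternative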
