
Commit 51a2375

Experimenting with a custom MixNet-XL and MixNet-XXL definition
1 parent 9816ca3 commit 51a2375

File tree: 1 file changed, +76 -9 lines changed

timm/models/gen_efficientnet.py

Lines changed: 76 additions & 9 deletions
@@ -138,6 +138,8 @@ def _cfg(url='', **kwargs):
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_m-4647fc68.pth'),
     'mixnet_l': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_l-5a9a2ed8.pth'),
+    'mixnet_xl': _cfg(),
+    'mixnet_xxl': _cfg(),
     'tf_mixnet_s': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_s-89d3354b.pth'),
     'tf_mixnet_m': _cfg(
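
Both new entries call _cfg() with no arguments. Given the def _cfg(url='', **kwargs) signature visible in the hunk header, that registers the variants with an empty url, i.e. no pretrained checkpoint is published for them yet. A minimal illustration of the consequence (hypothetical snippet, run inside this module):

cfg = _cfg()
assert cfg['url'] == ''  # nothing for load_pretrained() to fetch, so pretrained=True will fail
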
@@ -312,21 +314,59 @@ def _decode_block_str(block_str, depth_multiplier=1.0):
     else:
         assert False, 'Unknown block type (%s)' % block_type
 
-    # return a list of block args expanded by num_repeat and
-    # scaled by depth_multiplier
-    num_repeat = int(math.ceil(num_repeat * depth_multiplier))
-    return [deepcopy(block_args) for _ in range(num_repeat)]
+    return block_args, num_repeat
 
 
-def _decode_arch_def(arch_def, depth_multiplier=1.0):
+def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'):
+    """ Per-stage depth scaling
+    Scales the block repeats in each stage. This depth scaling impl maintains
+    compatibility with the EfficientNet scaling method, while allowing sensible
+    scaling for other models that may have multiple block arg definitions in each stage.
+    """
+
+    # We scale the total repeat count for each stage, there may be multiple
+    # block arg defs per stage so we need to sum.
+    num_repeat = sum(repeats)
+    if depth_trunc == 'round':
+        # Truncating to int by rounding allows stages with few repeats to remain
+        # proportionally smaller for longer. This is a good choice when stage definitions
+        # include single repeat stages that we'd prefer to keep that way as long as possible
+        num_repeat_scaled = max(1, round(num_repeat * depth_multiplier))
+    else:
+        # The default for EfficientNet truncates repeats to int via 'ceil'.
+        # Any multiplier > 1.0 will result in an increased depth for every stage.
+        num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier))
+
+    # Proportionally distribute repeat count scaling to each block definition in the stage.
+    # Allocation is done in reverse as it results in the first block being less likely to be scaled.
+    # The first block makes less sense to repeat in most of the arch definitions.
+    repeats_scaled = []
+    for r in repeats[::-1]:
+        rs = max(1, round((r / num_repeat * num_repeat_scaled)))
+        repeats_scaled.append(rs)
+        num_repeat -= r
+        num_repeat_scaled -= rs
+    repeats_scaled = repeats_scaled[::-1]
+
+    # Apply the calculated scaling to each block arg in the stage
+    sa_scaled = []
+    for ba, rep in zip(stack_args, repeats_scaled):
+        sa_scaled.extend([deepcopy(ba) for _ in range(rep)])
+    return sa_scaled
+
+
+def _decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil'):
     arch_args = []
     for stack_idx, block_strings in enumerate(arch_def):
         assert isinstance(block_strings, list)
         stack_args = []
+        repeats = []
         for block_str in block_strings:
             assert isinstance(block_str, str)
-            stack_args.extend(_decode_block_str(block_str, depth_multiplier))
-        arch_args.append(stack_args)
+            ba, rep = _decode_block_str(block_str)
+            stack_args.append(ba)
+            repeats.append(rep)
+        arch_args.append(_scale_stage_depth(stack_args, repeats, depth_multiplier, depth_trunc))
     return arch_args
 
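To make the allocation concrete, here is a standalone sketch of the repeat math above (same arithmetic as _scale_stage_depth, with the block-arg copying omitted). It shows how depth_trunc='round' keeps a single-repeat first block at one repeat in a case where 'ceil' would deepen the stage:

import math

def scale_repeats(repeats, depth_multiplier=1.0, depth_trunc='ceil'):
    # Scale the stage's total repeat count, then redistribute it across the
    # block definitions in reverse so the first block is scaled last.
    num_repeat = sum(repeats)
    if depth_trunc == 'round':
        num_repeat_scaled = max(1, round(num_repeat * depth_multiplier))
    else:
        num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier))
    repeats_scaled = []
    for r in repeats[::-1]:
        rs = max(1, round(r / num_repeat * num_repeat_scaled))
        repeats_scaled.append(rs)
        num_repeat -= r
        num_repeat_scaled -= rs
    return repeats_scaled[::-1]

print(scale_repeats([1, 2], 1.5, 'ceil'))   # -> [2, 3]: total 4.5 ceils to 5, first block deepens
print(scale_repeats([1, 2], 1.5, 'round'))  # -> [1, 3]: total 4.5 rounds to 4, first block stays at 1
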
@@ -1261,7 +1301,7 @@ def _gen_mixnet_s(channel_multiplier=1.0, num_classes=1000, **kwargs):
     return model
 
 
-def _gen_mixnet_m(channel_multiplier=1.0, num_classes=1000, **kwargs):
+def _gen_mixnet_m(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=1000, **kwargs):
     """Creates a MixNet Medium-Large model.
 
     Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet
@@ -1283,7 +1323,7 @@ def _gen_mixnet_m(channel_multiplier=1.0, num_classes=1000, **kwargs):
         # 7x7
     ]
     model = GenEfficientNet(
-        _decode_arch_def(arch_def),
+        _decode_arch_def(arch_def, depth_multiplier=depth_multiplier, depth_trunc='round'),
         num_classes=num_classes,
         stem_size=24,
         num_features=1536,
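
For context on where the per-definition repeats originate: each arch_def entry is a compact block string, and its r token carries the repeat count that _decode_block_str now returns alongside the block args. A toy parser for just that token (illustrative only; the real decoder also handles block type, kernel sizes, stride, expansion, channels, and more):

def parse_repeat(block_str):
    # e.g. 'ir_r3_k3_s2_e6_c40' -> 3 (an inverted-residual block repeated 3 times)
    for tok in block_str.split('_'):
        if tok[:1] == 'r' and tok[1:].isdigit():
            return int(tok[1:])
    return 1

print(parse_repeat('ir_r3_k3_s2_e6_c40'))  # -> 3
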
@@ -1876,6 +1916,33 @@ def mixnet_l(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     return model
 
 
+@register_model
+def mixnet_xl(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+    """Creates a MixNet Extra-Large model.
+    """
+    default_cfg = default_cfgs['mixnet_xl']
+    #kwargs['drop_connect_rate'] = 0.2
+    model = _gen_mixnet_m(
+        channel_multiplier=1.6, depth_multiplier=1.2, num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfg
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes, in_chans)
+    return model
+
+
+@register_model
+def mixnet_xxl(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+    """Creates a MixNet Double Extra Large model.
+    """
+    default_cfg = default_cfgs['mixnet_xxl']
+    model = _gen_mixnet_m(
+        channel_multiplier=2.4, depth_multiplier=1.3, num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfg
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes, in_chans)
+    return model
+
+
 @register_model
 def tf_mixnet_s(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     """Creates a MixNet Small model. Tensorflow compatible variant