@@ -138,6 +138,8 @@ def _cfg(url='', **kwargs):
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_m-4647fc68.pth'),
     'mixnet_l': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_l-5a9a2ed8.pth'),
+    'mixnet_xl': _cfg(),
+    'mixnet_xxl': _cfg(),
     'tf_mixnet_s': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_mixnet_s-89d3354b.pth'),
     'tf_mixnet_m': _cfg(
@@ -312,21 +314,59 @@ def _decode_block_str(block_str, depth_multiplier=1.0):
     else:
         assert False, 'Unknown block type (%s)' % block_type
 
-    # return a list of block args expanded by num_repeat and
-    # scaled by depth_multiplier
-    num_repeat = int(math.ceil(num_repeat * depth_multiplier))
-    return [deepcopy(block_args) for _ in range(num_repeat)]
+    return block_args, num_repeat
 
 
-def _decode_arch_def(arch_def, depth_multiplier=1.0):
+def _scale_stage_depth(stack_args, repeats, depth_multiplier=1.0, depth_trunc='ceil'):
+    """ Per-stage depth scaling
+    Scales the block repeats in each stage. This depth scaling impl maintains
+    compatibility with the EfficientNet scaling method, while allowing sensible
+    scaling for other models that may have multiple block arg definitions in each stage.
+    """
+
+    # We scale the total repeat count for each stage; there may be multiple
+    # block arg defs per stage, so we need to sum.
+    num_repeat = sum(repeats)
+    if depth_trunc == 'round':
+        # Truncating to int by rounding allows stages with few repeats to remain
+        # proportionally smaller for longer. This is a good choice when stage definitions
+        # include single-repeat stages that we'd prefer to keep that way as long as possible.
+        num_repeat_scaled = max(1, round(num_repeat * depth_multiplier))
+    else:
+        # The default for EfficientNet truncates repeats to int via 'ceil'.
+        # Any multiplier > 1.0 will result in an increased depth for every stage.
+        num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier))
+
+    # Proportionally distribute the scaled repeat count to each block definition in the stage.
+    # Allocation is done in reverse, as that makes the first block the least likely to be scaled;
+    # the first block makes less sense to repeat in most of the arch definitions.
+    repeats_scaled = []
+    for r in repeats[::-1]:
+        rs = max(1, round(r / num_repeat * num_repeat_scaled))
+        repeats_scaled.append(rs)
+        num_repeat -= r
+        num_repeat_scaled -= rs
+    repeats_scaled = repeats_scaled[::-1]
+
+    # Apply the calculated scaling to each block arg in the stage
+    sa_scaled = []
+    for ba, rep in zip(stack_args, repeats_scaled):
+        sa_scaled.extend([deepcopy(ba) for _ in range(rep)])
+    return sa_scaled
+
+
+def _decode_arch_def(arch_def, depth_multiplier=1.0, depth_trunc='ceil'):
     arch_args = []
     for stack_idx, block_strings in enumerate(arch_def):
         assert isinstance(block_strings, list)
         stack_args = []
+        repeats = []
         for block_str in block_strings:
             assert isinstance(block_str, str)
-            stack_args.extend(_decode_block_str(block_str, depth_multiplier))
-        arch_args.append(stack_args)
+            ba, rep = _decode_block_str(block_str)
+            stack_args.append(ba)
+            repeats.append(rep)
+        arch_args.append(_scale_stage_depth(stack_args, repeats, depth_multiplier, depth_trunc))
     return arch_args
 
 
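The reverse, proportional allocation is the subtle part of this change. Below is a minimal stand-alone sketch of just that allocation logic, useful for checking its behavior by hand; the repeat counts and multiplier are made-up illustration values, not taken from any model definition in this commit:

```python
import math

def scale_repeats(repeats, depth_multiplier=1.0, depth_trunc='ceil'):
    # Stand-alone copy of the repeat-allocation logic in _scale_stage_depth.
    num_repeat = sum(repeats)
    if depth_trunc == 'round':
        num_repeat_scaled = max(1, round(num_repeat * depth_multiplier))
    else:
        num_repeat_scaled = int(math.ceil(num_repeat * depth_multiplier))
    repeats_scaled = []
    for r in repeats[::-1]:  # reverse: extra repeats land on later blocks first
        rs = max(1, round(r / num_repeat * num_repeat_scaled))
        repeats_scaled.append(rs)
        num_repeat -= r
        num_repeat_scaled -= rs
    return repeats_scaled[::-1]

# Hypothetical stage: a single-repeat lead block followed by a 3-repeat block.
# Note: Python 3 round() uses banker's rounding, so round(4.5) == 4 below.
print(scale_repeats([1, 3], depth_multiplier=1.3, depth_trunc='round'))  # [1, 4]
print(scale_repeats([1, 3], depth_multiplier=1.3, depth_trunc='ceil'))   # [2, 4]
```

With 'ceil' the stage budget grows from 4 to 6 repeats and the overflow reaches the first block; with 'round' it grows to 5 and the single-repeat first block survives unchanged.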
@@ -1261,7 +1301,7 @@ def _gen_mixnet_s(channel_multiplier=1.0, num_classes=1000, **kwargs):
     return model
 
 
-def _gen_mixnet_m(channel_multiplier=1.0, num_classes=1000, **kwargs):
+def _gen_mixnet_m(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=1000, **kwargs):
     """Creates a MixNet Medium-Large model.
 
     Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/mnasnet/mixnet
@@ -1283,7 +1323,7 @@ def _gen_mixnet_m(channel_multiplier=1.0, num_classes=1000, **kwargs):
         # 7x7
     ]
     model = GenEfficientNet(
-        _decode_arch_def(arch_def),
+        _decode_arch_def(arch_def, depth_multiplier=depth_multiplier, depth_trunc='round'),
         num_classes=num_classes,
         stem_size=24,
         num_features=1536,
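Passing `depth_trunc='round'` here (rather than the EfficientNet default of 'ceil') matters because the MixNet-M template contains single-repeat stages. A quick arithmetic check of the two modes on a hypothetical one-repeat stage, using the 1.2 multiplier that `mixnet_xl` passes in below:

```python
import math

depth_multiplier = 1.2  # the mixnet_xl value introduced by this commit
repeats = 1             # a hypothetical single-repeat stage

print(int(math.ceil(repeats * depth_multiplier)))  # 2 -> 'ceil' doubles the stage
print(max(1, round(repeats * depth_multiplier)))   # 1 -> 'round' keeps it single
```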
@@ -1876,6 +1916,33 @@ def mixnet_l(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     return model
 
 
+@register_model
+def mixnet_xl(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+    """Creates a MixNet Extra-Large model.
+    """
+    default_cfg = default_cfgs['mixnet_xl']
+    #kwargs['drop_connect_rate'] = 0.2
+    model = _gen_mixnet_m(
+        channel_multiplier=1.6, depth_multiplier=1.2, num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfg
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes, in_chans)
+    return model
+
+
+@register_model
+def mixnet_xxl(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+    """Creates a MixNet Double Extra-Large model.
+    """
+    default_cfg = default_cfgs['mixnet_xxl']
+    model = _gen_mixnet_m(
+        channel_multiplier=2.4, depth_multiplier=1.3, num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfg
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes, in_chans)
+    return model
+
+
 @register_model
 def tf_mixnet_s(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     """Creates a MixNet Small model. Tensorflow compatible variant