@@ -88,6 +88,14 @@ def _cfg(url='', **kwargs):
8888 url = '' , input_size = (3 , 528 , 528 ), pool_size = (17 , 17 ), crop_pct = 0.942 ),
8989 'efficientnet_b7' : _cfg (
9090 url = '' , input_size = (3 , 600 , 600 ), pool_size = (19 , 19 ), crop_pct = 0.949 ),
91+ 'efficientnet_es' : _cfg (
92+ url = '' ),
93+ 'efficientnet_em' : _cfg (
94+ url = '' ,
95+ input_size = (3 , 240 , 240 ), pool_size = (8 , 8 ), crop_pct = 0.882 ),
96+ 'efficientnet_el' : _cfg (
97+ url = '' ,
98+ input_size = (3 , 300 , 300 ), pool_size = (10 , 10 ), crop_pct = 0.904 ),
9199 'tf_efficientnet_b0' : _cfg (
92100 url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b0_aa-827b6e33.pth' ,
93101 input_size = (3 , 224 , 224 )),
@@ -112,6 +120,18 @@ def _cfg(url='', **kwargs):
112120 'tf_efficientnet_b7' : _cfg (
113121 url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b7_aa-076e3472.pth' ,
114122 input_size = (3 , 600 , 600 ), pool_size = (19 , 19 ), crop_pct = 0.949 ),
123+ 'tf_efficientnet_es' : _cfg (
124+ url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_es-ca1afbfe.pth' ,
125+ mean = (0.5 , 0.5 , 0.5 ), std = (0.5 , 0.5 , 0.5 ),
126+ input_size = (3 , 224 , 224 ), ),
127+ 'tf_efficientnet_em' : _cfg (
128+ url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_em-e78cfe58.pth' ,
129+ mean = (0.5 , 0.5 , 0.5 ), std = (0.5 , 0.5 , 0.5 ),
130+ input_size = (3 , 240 , 240 ), pool_size = (8 , 8 ), crop_pct = 0.882 ),
131+ 'tf_efficientnet_el' : _cfg (
132+ url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_el-5143854e.pth' ,
133+ mean = (0.5 , 0.5 , 0.5 ), std = (0.5 , 0.5 , 0.5 ),
134+ input_size = (3 , 300 , 300 ), pool_size = (10 , 10 ), crop_pct = 0.904 ),
115135 'mixnet_s' : _cfg (
116136 url = 'https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mixnet_s-a907afbc.pth' ),
117137 'mixnet_m' : _cfg (
@@ -239,6 +259,7 @@ def _decode_block_str(block_str, depth_multiplier=1.0):
239259 act_fn = options ['n' ] if 'n' in options else None
240260 exp_kernel_size = _parse_ksize (options ['a' ]) if 'a' in options else 1
241261 pw_kernel_size = _parse_ksize (options ['p' ]) if 'p' in options else 1
262+ fake_in_chs = int (options ['fc' ]) if 'fc' in options else 0 # FIXME hack to deal with in_chs issue in TPU def
242263
243264 num_repeat = int (options ['r' ])
244265 # each type of block has different valid arguments, fill accordingly
@@ -267,6 +288,19 @@ def _decode_block_str(block_str, depth_multiplier=1.0):
267288 pw_act = block_type == 'dsa' ,
268289 noskip = block_type == 'dsa' or noskip ,
269290 )
291+ elif block_type == 'er' :
292+ block_args = dict (
293+ block_type = block_type ,
294+ exp_kernel_size = _parse_ksize (options ['k' ]),
295+ pw_kernel_size = pw_kernel_size ,
296+ out_chs = int (options ['c' ]),
297+ exp_ratio = float (options ['e' ]),
298+ fake_in_chs = fake_in_chs ,
299+ se_ratio = float (options ['se' ]) if 'se' in options else None ,
300+ stride = int (options ['s' ]),
301+ act_fn = act_fn ,
302+ noskip = noskip ,
303+ )
270304 elif block_type == 'cn' :
271305 block_args = dict (
272306 block_type = block_type ,
@@ -356,6 +390,9 @@ def _make_block(self, ba):
356390 bt = ba .pop ('block_type' )
357391 ba ['in_chs' ] = self .in_chs
358392 ba ['out_chs' ] = self ._round_channels (ba ['out_chs' ])
393+ if 'fake_in_chs' in ba and ba ['fake_in_chs' ]:
394+ # FIXME this is a hack to work around mismatch in origin impl input filters
395+ ba ['fake_in_chs' ] = self ._round_channels (ba ['fake_in_chs' ])
359396 ba ['bn_args' ] = self .bn_args
360397 ba ['pad_type' ] = self .pad_type
361398 # block act fn overrides the model default
@@ -373,6 +410,13 @@ def _make_block(self, ba):
373410 if self .verbose :
374411 logging .info (' DepthwiseSeparable {}, Args: {}' .format (self .block_idx , str (ba )))
375412 block = DepthwiseSeparableConv (** ba )
413+ elif bt == 'er' :
414+ ba ['drop_connect_rate' ] = self .drop_connect_rate * self .block_idx / self .block_count
415+ ba ['se_gate_fn' ] = self .se_gate_fn
416+ ba ['se_reduce_mid' ] = self .se_reduce_mid
417+ if self .verbose :
418+ logging .info (' EdgeResidual {}, Args: {}' .format (self .block_idx , str (ba )))
419+ block = EdgeResidual (** ba )
376420 elif bt == 'cn' :
377421 if self .verbose :
378422 logging .info (' ConvBnAct {}, Args: {}' .format (self .block_idx , str (ba )))
@@ -519,10 +563,62 @@ def forward(self, x):
519563 return x
520564
521565
class EdgeResidual(nn.Module):
    """ Residual block with expansion convolution followed by pointwise-linear w/ stride.

    Used for the 'er' blocks of the EfficientNet-EdgeTPU models: a single kxk expansion
    conv replaces the usual 1x1-expand + depthwise pair, followed by a 1x1 linear
    projection (the stride, when > 1, is applied in the projection conv).

    Args:
        in_chs (int): number of input channels.
        out_chs (int): number of output channels.
        exp_kernel_size (int): kernel size of the expansion convolution.
        exp_ratio (float): channel expansion ratio for the mid (expanded) width.
        fake_in_chs (int): if > 0, overrides in_chs for the mid-width calculation only
            (works around a stem/in_chs mismatch in the original TPU definition).
        stride (int): stride applied in the pointwise-linear projection.
        pad_type (str): padding type passed through to select_conv2d ('' or 'same').
        act_fn: activation applied after the expansion conv.
        noskip (bool): disable the residual connection even when shapes allow it.
        pw_kernel_size (int): kernel size of the pointwise-linear projection.
        se_ratio (float or None): squeeze-excite reduction ratio; None or 0 disables SE.
        se_reduce_mid (bool): base SE reduction on mid channels instead of input channels.
        se_gate_fn: gating function for the SE module.
        bn_args (dict): kwargs for nn.BatchNorm2d.
        drop_connect_rate (float): stochastic-depth rate applied on the residual path.
    """

    def __init__(self, in_chs, out_chs, exp_kernel_size=3, exp_ratio=1.0, fake_in_chs=0,
                 stride=1, pad_type='', act_fn=F.relu, noskip=False, pw_kernel_size=1,
                 se_ratio=0., se_reduce_mid=False, se_gate_fn=sigmoid,
                 bn_args=_BN_ARGS_PT, drop_connect_rate=0.):
        super(EdgeResidual, self).__init__()
        # fake_in_chs (when set) determines the expanded width instead of the true in_chs
        mid_chs = int(fake_in_chs * exp_ratio) if fake_in_chs > 0 else int(in_chs * exp_ratio)
        self.has_se = se_ratio is not None and se_ratio > 0.
        # residual is only valid when channel count and spatial size are preserved
        self.has_residual = (in_chs == out_chs and stride == 1) and not noskip
        self.act_fn = act_fn
        self.drop_connect_rate = drop_connect_rate

        # Expansion convolution
        self.conv_exp = select_conv2d(in_chs, mid_chs, exp_kernel_size, padding=pad_type)
        self.bn1 = nn.BatchNorm2d(mid_chs, **bn_args)

        # Squeeze-and-excitation
        if self.has_se:
            se_base_chs = mid_chs if se_reduce_mid else in_chs
            self.se = SqueezeExcite(
                mid_chs, reduce_chs=max(1, int(se_base_chs * se_ratio)), act_fn=act_fn, gate_fn=se_gate_fn)

        # Point-wise linear projection (stride applied here, not in the expansion conv)
        self.conv_pwl = select_conv2d(mid_chs, out_chs, pw_kernel_size, stride=stride, padding=pad_type)
        self.bn2 = nn.BatchNorm2d(out_chs, **bn_args)

    def forward(self, x):
        residual = x

        # Expansion convolution
        x = self.conv_exp(x)
        x = self.bn1(x)
        x = self.act_fn(x, inplace=True)

        # Squeeze-and-excitation
        if self.has_se:
            x = self.se(x)

        # Point-wise linear projection — no activation (linear bottleneck)
        x = self.conv_pwl(x)
        x = self.bn2(x)

        if self.has_residual:
            if self.drop_connect_rate > 0.:
                x = drop_connect(x, self.training, self.drop_connect_rate)
            x += residual

        return x
616+
617+
522618class DepthwiseSeparableConv (nn .Module ):
523619 """ DepthwiseSeparable block
524620 Used for DS convs in MobileNet-V1 and in the place of IR blocks with an expansion
525- factor of 1.0. This is an alternative to having a IR with optional first pw conv.
621+ factor of 1.0. This is an alternative to having a IR with an optional first pw conv.
526622 """
527623 def __init__ (self , in_chs , out_chs , dw_kernel_size = 3 ,
528624 stride = 1 , pad_type = '' , act_fn = F .relu , noskip = False ,
@@ -1092,7 +1188,6 @@ def _gen_efficientnet(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=
10921188 ['ir_r4_k5_s2_e6_c192_se0.25' ],
10931189 ['ir_r1_k3_s1_e6_c320_se0.25' ],
10941190 ]
1095- # NOTE: other models in the family didn't scale the feature count
10961191 num_features = _round_channels (1280 , channel_multiplier , 8 , None )
10971192 model = GenEfficientNet (
10981193 _decode_arch_def (arch_def , depth_multiplier ),
@@ -1107,6 +1202,31 @@ def _gen_efficientnet(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=
11071202 return model
11081203
11091204
def _gen_efficientnet_edge(channel_multiplier=1.0, depth_multiplier=1.0, num_classes=1000, **kwargs):
    """Creates an EfficientNet-EdgeTPU model (es/em/el variants).

    The first three stages use EdgeResidual ('er') blocks; the remaining stages use
    inverted-residual ('ir') blocks. ReLU is used throughout rather than swish.

    Args:
        channel_multiplier: width scaling applied to per-stage channel counts.
        depth_multiplier: depth scaling applied to per-stage repeat counts.
        num_classes: number of classifier output classes.
        **kwargs: forwarded to GenEfficientNet (bn_eps/bn_momentum are extracted
            into bn_args by _resolve_bn_args).
    """
    arch_def = [
        # NOTE `fc` is present to override a mismatch between stem channels and in chs not
        # present in other models
        ['er_r1_k3_s1_e4_c24_fc24_noskip'],
        ['er_r2_k3_s2_e8_c32'],
        ['er_r4_k3_s2_e8_c48'],
        ['ir_r5_k5_s2_e8_c96'],
        ['ir_r4_k5_s1_e8_c144'],
        ['ir_r2_k5_s2_e8_c192'],
    ]
    # head features scale with the channel multiplier (no min/divisor clamp beyond 8)
    num_features = _round_channels(1280, channel_multiplier, 8, None)
    model = GenEfficientNet(
        _decode_arch_def(arch_def, depth_multiplier),
        num_classes=num_classes,
        stem_size=32,
        channel_multiplier=channel_multiplier,
        num_features=num_features,
        bn_args=_resolve_bn_args(kwargs),
        act_fn=F.relu,
        **kwargs
    )
    return model
1228+
1229+
11101230def _gen_mixnet_s (channel_multiplier = 1.0 , num_classes = 1000 , ** kwargs ):
11111231 """Creates a MixNet Small model.
11121232
@@ -1481,7 +1601,6 @@ def efficientnet_b5(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
14811601 return model
14821602
14831603
1484-
14851604@register_model
14861605def efficientnet_b6 (pretrained = False , num_classes = 1000 , in_chans = 3 , ** kwargs ):
14871606 """ EfficientNet-B6 """
@@ -1512,6 +1631,45 @@ def efficientnet_b7(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
15121631 return model
15131632
15141633
@register_model
def efficientnet_es(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    """ EfficientNet-Edge Small (width 1.0, depth 1.0). """
    default_cfg = default_cfgs['efficientnet_es']
    model = _gen_efficientnet_edge(
        channel_multiplier=1.0, depth_multiplier=1.0,
        num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
1645+
1646+
@register_model
def efficientnet_em(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    """ EfficientNet-Edge Medium (width 1.0, depth 1.1). """
    default_cfg = default_cfgs['efficientnet_em']
    model = _gen_efficientnet_edge(
        channel_multiplier=1.0, depth_multiplier=1.1,
        num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
1658+
1659+
@register_model
def efficientnet_el(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    """ EfficientNet-Edge Large (width 1.2, depth 1.4). """
    default_cfg = default_cfgs['efficientnet_el']
    model = _gen_efficientnet_edge(
        channel_multiplier=1.2, depth_multiplier=1.4,
        num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
1671+
1672+
15151673@register_model
15161674def tf_efficientnet_b0 (pretrained = False , num_classes = 1000 , in_chans = 3 , ** kwargs ):
15171675 """ EfficientNet-B0. Tensorflow compatible variant """
@@ -1634,6 +1792,51 @@ def tf_efficientnet_b7(pretrained=False, num_classes=1000, in_chans=3, **kwargs)
16341792 return model
16351793
16361794
@register_model
def tf_efficientnet_es(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    """ EfficientNet-Edge Small. Tensorflow compatible variant.

    Uses TF-default BN epsilon and 'same' padding to match the reference weights.
    """
    default_cfg = default_cfgs['tf_efficientnet_es']
    kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet_edge(
        channel_multiplier=1.0, depth_multiplier=1.0,
        num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
1808+
1809+
@register_model
def tf_efficientnet_em(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    """ EfficientNet-Edge Medium. Tensorflow compatible variant.

    Uses TF-default BN epsilon and 'same' padding to match the reference weights.
    """
    default_cfg = default_cfgs['tf_efficientnet_em']
    kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet_edge(
        channel_multiplier=1.0, depth_multiplier=1.1,
        num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
1823+
1824+
@register_model
def tf_efficientnet_el(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
    """ EfficientNet-Edge Large. Tensorflow compatible variant.

    Uses TF-default BN epsilon and 'same' padding to match the reference weights.
    """
    default_cfg = default_cfgs['tf_efficientnet_el']
    kwargs['bn_eps'] = _BN_EPS_TF_DEFAULT
    kwargs['pad_type'] = 'same'
    model = _gen_efficientnet_edge(
        channel_multiplier=1.2, depth_multiplier=1.4,
        num_classes=num_classes, in_chans=in_chans, **kwargs)
    model.default_cfg = default_cfg
    if pretrained:
        load_pretrained(model, default_cfg, num_classes, in_chans)
    return model
1838+
1839+
16371840@register_model
16381841def mixnet_s (pretrained = False , num_classes = 1000 , in_chans = 3 , ** kwargs ):
16391842 """Creates a MixNet Small model.
0 commit comments