@@ -28,11 +28,13 @@
 )
 from ._features import feature_take_indices
 from ._features_fx import register_notrace_module
-from ._manipulate import checkpoint_seq, checkpoint
+from ._manipulate import checkpoint_seq
 from ._registry import generate_default_cfgs, register_model
 
 __all__ = ['MobileNetV5', 'MobileNetV5Encoder']
 
+_GELU = partial(nn.GELU, approximate='tanh')
+
 
 @register_notrace_module
 class MobileNetV5MultiScaleFusionAdapter(nn.Module):
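Note (not part of the diff): the new `_GELU` alias pins PyTorch's tanh-approximate GELU, which differs slightly from the erf-based GELU that a bare `nn.GELU` uses. A minimal sketch of the difference, assuming only stock PyTorch:

```python
# Sketch only: compare exact vs. tanh-approximate GELU, to show why the
# constructor is pinned with functools.partial instead of passing nn.GELU.
from functools import partial

import torch
import torch.nn as nn

_GELU = partial(nn.GELU, approximate='tanh')

x = torch.linspace(-3, 3, 7)
print(nn.GELU()(x))   # exact (erf-based) GELU
print(_GELU()(x))     # tanh approximation: close, but not bit-identical
```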
@@ -68,7 +70,7 @@ def __init__(
         self.layer_scale_init_value = layer_scale_init_value
         self.noskip = noskip
 
-        act_layer = act_layer or nn.GELU
+        act_layer = act_layer or _GELU
         norm_layer = norm_layer or RmsNorm2d
         self.ffn = UniversalInvertedResidual(
             in_chs=self.in_channels,
@@ -167,7 +169,7 @@ def __init__(
             global_pool: Type of pooling to use for global pooling features of the FC head.
         """
         super().__init__()
-        act_layer = act_layer or nn.GELU
+        act_layer = act_layer or _GELU
         norm_layer = get_norm_layer(norm_layer) or RmsNorm2d
         norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
         se_layer = se_layer or SqueezeExcite
@@ -410,7 +412,7 @@ def __init__(
             block_args: BlockArgs,
             in_chans: int = 3,
             stem_size: int = 64,
-            stem_bias: bool = False,
+            stem_bias: bool = True,
             fix_stem: bool = False,
             pad_type: str = '',
             msfa_indices: Sequence[int] = (-2, -1),
@@ -426,7 +428,7 @@ def __init__(
             layer_scale_init_value: Optional[float] = None,
     ):
         super().__init__()
-        act_layer = act_layer or nn.GELU
+        act_layer = act_layer or _GELU
         norm_layer = get_norm_layer(norm_layer) or RmsNorm2d
         se_layer = se_layer or SqueezeExcite
         self.num_classes = 0  # Exists to satisfy ._hub module APIs.
@@ -526,6 +528,7 @@ def forward_intermediates(
         feat_idx = 0  # stem is index 0
         x = self.conv_stem(x)
         if feat_idx in take_indices:
+            print("conv_stem is captured")
             intermediates.append(x)
         if feat_idx in self.msfa_indices:
             msfa_intermediates.append(x)
@@ -537,9 +540,16 @@ def forward_intermediates(
 
         for blk in blocks:
             feat_idx += 1
-            x = blk(x)
-            if feat_idx in take_indices:
-                intermediates.append(x)
+            # DO NOT SUBMIT: Revert to only the else condition after verification.
+            if isinstance(blk, nn.Sequential):
+                for subblk in blk:
+                    x = subblk(x)
+                    if feat_idx in take_indices:
+                        intermediates.append(x)
+            else:
+                x = blk(x)
+                if feat_idx in take_indices:
+                    intermediates.append(x)
             if feat_idx in self.msfa_indices:
                 msfa_intermediates.append(x)
 
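Aside (not part of the diff): the temporary branch above captures a feature map after every sub-block when a stage is an `nn.Sequential`, instead of once per stage on the else path. A rough way to inspect what gets returned, assuming timm's usual `(final, intermediates)` return convention for `forward_intermediates`; the model name, indices, and input size here are illustrative assumptions:

```python
# Sketch only: print the shapes of the captured intermediate feature maps.
import timm
import torch

model = timm.create_model('mobilenetv5_300m_enc', pretrained=False).eval()
x = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    final, intermediates = model.forward_intermediates(x, indices=(0, 1, 2, 3))
for i, feat in enumerate(intermediates):
    print(i, tuple(feat.shape))
```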
@@ -777,7 +787,7 @@ def _gen_mobilenet_v5(
         fix_stem=channel_multiplier < 1.0,
         round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
         norm_layer=RmsNorm2d,
-        act_layer=nn.GELU,
+        act_layer=_GELU,
         layer_scale_init_value=1e-5,
     )
     model_kwargs = dict(model_kwargs, **kwargs)
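Aside (not part of the diff): `dict(model_kwargs, **kwargs)` merges the generator defaults set above with caller-supplied kwargs, and the caller's keys win. A tiny illustration with made-up values:

```python
# Sketch only: later keys override earlier ones in dict(defaults, **overrides),
# so a caller can still replace act_layer or other defaults per model variant.
defaults = dict(norm_layer='rmsnorm2d', act_layer='gelu', layer_scale_init_value=1e-5)
overrides = dict(act_layer='relu')
print(dict(defaults, **overrides))  # act_layer becomes 'relu', the rest keep defaults
```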