from torchvision.ops.misc import FrozenBatchNorm2d

import timm
+import pytest
from timm.utils.model import freeze, unfreeze
+from timm.utils.model import ActivationStatsHook
+from timm.utils.model import extract_spp_stats

+from timm.utils.model import _freeze_unfreeze
+from timm.utils.model import avg_sq_ch_mean, avg_ch_var, avg_ch_var_residual
+from timm.utils.model import reparameterize_model
+from timm.utils.model import get_state_dict

def test_freeze_unfreeze():
    model = timm.create_model('resnet18')
@@ -54,4 +61,132 @@ def test_freeze_unfreeze():
    freeze(model.layer1[0], ['bn1'])
    assert isinstance(model.layer1[0].bn1, FrozenBatchNorm2d)
    unfreeze(model.layer1[0], ['bn1'])
-    assert isinstance(model.layer1[0].bn1, BatchNorm2d)
+    assert isinstance(model.layer1[0].bn1, BatchNorm2d)
+
+def test_activation_stats_hook_validation():
+    model = timm.create_model('resnet18')
+
+    def test_hook(model, input, output):
+        return output.mean().item()
+
+    # Test error case with mismatched lengths
+    with pytest.raises(ValueError, match="Please provide `hook_fns` for each `hook_fn_locs`"):
+        ActivationStatsHook(
+            model,
+            hook_fn_locs=['layer1.0.conv1', 'layer1.0.conv2'],
+            hook_fns=[test_hook]
+        )
+
+
+def test_extract_spp_stats():
+    model = timm.create_model('resnet18')
+
+    def test_hook(model, input, output):
+        return output.mean().item()
+
+    stats = extract_spp_stats(
+        model,
+        hook_fn_locs=['layer1.0.conv1'],
+        hook_fns=[test_hook],
+        input_shape=[2, 3, 32, 32]
+    )
+
+    assert isinstance(stats, dict)
+    assert test_hook.__name__ in stats
+    assert isinstance(stats[test_hook.__name__], list)
+    assert len(stats[test_hook.__name__]) > 0
+
+def test_freeze_unfreeze_bn_root():
+    import torch.nn as nn
+    from timm.layers import BatchNormAct2d
+
+    # Create batch norm layers
+    bn = nn.BatchNorm2d(10)
+    bn_act = BatchNormAct2d(10)
+
+    # Test with BatchNorm2d as root
+    with pytest.raises(AssertionError):
+        _freeze_unfreeze(bn, mode="freeze")
+
+    # Test with BatchNormAct2d as root
+    with pytest.raises(AssertionError):
+        _freeze_unfreeze(bn_act, mode="freeze")
+
+
+def test_activation_stats_functions():
+    import torch
+
+    # Create sample input tensor [batch, channels, height, width]
+    x = torch.randn(2, 3, 4, 4)
+
+    # Test avg_sq_ch_mean
+    result1 = avg_sq_ch_mean(None, None, x)
+    assert isinstance(result1, float)
+
+    # Test avg_ch_var
+    result2 = avg_ch_var(None, None, x)
+    assert isinstance(result2, float)
+
+    # Test avg_ch_var_residual
+    result3 = avg_ch_var_residual(None, None, x)
+    assert isinstance(result3, float)
+
+
+def test_reparameterize_model():
+    import torch.nn as nn
+
+    class FusableModule(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.conv = nn.Conv2d(3, 3, 1)
+
+        def fuse(self):
+            return nn.Identity()
+    class ModelWithFusable(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.fusable = FusableModule()
+            self.normal = nn.Linear(10, 10)
+
+    model = ModelWithFusable()
+
+    # Test with inplace=False (should create a copy)
+    new_model = reparameterize_model(model, inplace=False)
+    assert isinstance(new_model.fusable, nn.Identity)
+    assert isinstance(model.fusable, FusableModule)  # Original unchanged
+
+    # Test with inplace=True
+    reparameterize_model(model, inplace=True)
+    assert isinstance(model.fusable, nn.Identity)
+
+
+def test_get_state_dict_custom_unwrap():
+    import torch.nn as nn
+
+    class CustomModel(nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.linear = nn.Linear(10, 10)
+
+    model = CustomModel()
+
+    def custom_unwrap(m):
+        return m
+
+    state_dict = get_state_dict(model, unwrap_fn=custom_unwrap)
+    assert 'linear.weight' in state_dict
+    assert 'linear.bias' in state_dict
+
+
+def test_freeze_unfreeze_string_input():
+    model = timm.create_model('resnet18')
+
+    # Test with string input
+    _freeze_unfreeze(model, 'layer1', mode='freeze')
+    assert model.layer1[0].conv1.weight.requires_grad == False
+
+    # Test unfreezing with string input
+    _freeze_unfreeze(model, 'layer1', mode='unfreeze')
+    assert model.layer1[0].conv1.weight.requires_grad == True
+
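For reference, a minimal usage sketch (not part of the diff, using only functions imported in it): extract_spp_stats pairs each fnmatch-style location pattern with one hook function and records one value per matched module during a single forward pass on random input, which is the behavior the new tests assert with a toy hook.

import timm
from timm.utils.model import extract_spp_stats, avg_sq_ch_mean, avg_ch_var

model = timm.create_model('resnet18')

# Apply the built-in stat hooks to every first conv of every residual block.
stats = extract_spp_stats(
    model,
    hook_fn_locs=['layer?.?.conv1', 'layer?.?.conv1'],
    hook_fns=[avg_sq_ch_mean, avg_ch_var],
    input_shape=[2, 3, 32, 32],
)

# Each hook's __name__ keys a list with one entry per matched module.
assert set(stats) == {'avg_sq_ch_mean', 'avg_ch_var'}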