@@ -47,6 +47,19 @@ def configure_optimizers(self) -> None:
4747 return torch .optim .SGD (self .layer .parameters (), lr = 0.1 )
4848
4949
class LargeTestModel(BoringModel):
    """A model whose layer is created lazily in ``configure_model``.

    Used to verify that weight-averaging callbacks work with models whose
    parameters do not exist until the strategy invokes ``configure_model``
    (e.g. sharded/delayed initialization with FSDP).
    """

    def __init__(self):
        super().__init__()
        # The layer is intentionally not built here; it is materialized in
        # configure_model() so the strategy controls when parameters exist.
        self.layer = None

    def configure_model(self):
        """Build the network; called by the strategy before fitting starts."""
        # NOTE(review): removed a leftover debug print ("XXX configure_model")
        # that polluted test output.
        self.layer = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 2))

    def configure_optimizers(self):
        """Return a plain SGD optimizer over all model parameters."""
        return torch.optim.SGD(self.parameters(), lr=0.01)
5063class EMAAveragingFunction :
5164 """EMA averaging function.
5265
@@ -252,8 +265,26 @@ def test_swa(tmp_path):
252265 _train (model , dataset , tmp_path , SWATestCallback ())
253266
254267
@pytest.mark.parametrize(
    ("strategy", "accelerator", "devices"),
    [
        ("auto", "cpu", 1),
        pytest.param("auto", "gpu", 1, marks=RunIf(min_cuda_gpus=1)),
        pytest.param("fsdp", "gpu", 1, marks=RunIf(min_cuda_gpus=1)),
        pytest.param("ddp", "gpu", 2, marks=RunIf(min_cuda_gpus=2)),
        pytest.param("fsdp", "gpu", 2, marks=RunIf(min_cuda_gpus=2)),
    ],
)
def test_ema_configure_model(tmp_path, strategy, accelerator, devices):
    """EMA averaging handles models built lazily via ``configure_model``."""
    callback = EMATestCallback()
    _train(
        LargeTestModel(),
        RandomDataset(32, 32),
        tmp_path,
        callback,
        strategy=strategy,
        accelerator=accelerator,
        devices=devices,
    )
    # The averaged copy must contain the lazily created layer.
    assert isinstance(callback._average_model.module.layer, nn.Sequential)
285+
255286def _train (
256- model : TestModel ,
287+ model : BoringModel ,
257288 dataset : Dataset ,
258289 tmp_path : str ,
259290 callback : WeightAveraging ,
@@ -262,7 +293,7 @@ def _train(
262293 devices : int = 1 ,
263294 checkpoint_path : Optional [str ] = None ,
264295 will_crash : bool = False ,
265- ) -> TestModel :
296+ ) -> None :
266297 deterministic = accelerator == "cpu"
267298 trainer = Trainer (
268299 accelerator = accelerator ,
0 commit comments