@@ -387,5 +387,35 @@ def test_misconfiguration_error_with_sharded_model(tmp_path, strategy: str):
         trainer.fit(model)


+def test_swa_with_infinite_epochs_and_batchnorm(tmp_path):
+    """Test that SWA works correctly with max_epochs=-1 (infinite training) and BatchNorm."""
+    model = SwaTestModel(batchnorm=True)
+    swa_callback = StochasticWeightAveraging(swa_lrs=0.1, swa_epoch_start=2)
+
+    trainer = Trainer(
+        default_root_dir=tmp_path,
+        enable_progress_bar=False,
+        enable_model_summary=False,
+        max_epochs=-1,
+        max_steps=30,  # use max_steps as the stopping condition
+        limit_train_batches=5,
+        limit_val_batches=0,
+        callbacks=[swa_callback],
+        logger=False,
+    )
+    assert trainer.max_epochs == -1
+    assert trainer.fit_loop.max_epochs == -1
+
+    trainer.fit(model)
+    assert trainer.current_epoch >= 5
+    assert trainer.global_step == 30
+    assert trainer.max_epochs == -1
+
+    # Verify SWA was actually applied (update_parameters should have been called).
+    # SWA starts at epoch 2, so with 6 epochs (0-5) we should have 4 updates (epochs 2, 3, 4, 5).
+    assert swa_callback.n_averaged is not None
+    assert swa_callback.n_averaged > 0, "SWA should have updated parameters"
+
+
 def _backward_patch(trainer: Trainer) -> AbstractContextManager:
     return mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward)
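For context, a minimal sketch of the update counting the new test asserts on, written against torch.optim.swa_utils directly (the AveragedModel machinery underlying Lightning's SWA callback); the toy model and loop below are illustrative, not part of the PR:

import torch
from torch.optim.swa_utils import AveragedModel

model = torch.nn.Linear(4, 2)
swa_model = AveragedModel(model)

# 30 max_steps / 5 batches per epoch = 6 epochs (0-5); with swa_epoch_start=2,
# the averaged weights are updated at epochs 2, 3, 4 and 5 -> 4 updates.
for epoch in range(6):
    if epoch >= 2:
        swa_model.update_parameters(model)

# n_averaged is a tensor buffer incremented once per update_parameters() call;
# it plays the role of swa_callback.n_averaged in the test above.
assert swa_model.n_averaged == 4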