Skip to content

Commit 59186ec

Browse files
committed
add testing
1 parent c63e9f7 commit 59186ec

File tree

1 file changed

+30
-0
lines changed

1 file changed

+30
-0
lines changed

tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,5 +387,35 @@ def test_misconfiguration_error_with_sharded_model(tmp_path, strategy: str):
387387
trainer.fit(model)
388388

389389

390+
def test_swa_with_infinite_epochs_and_batchnorm(tmp_path):
391+
"""Test that SWA works correctly with max_epochs=-1 (infinite training) and BatchNorm."""
392+
model = SwaTestModel(batchnorm=True)
393+
swa_callback = StochasticWeightAveraging(swa_lrs=0.1, swa_epoch_start=2)
394+
395+
trainer = Trainer(
396+
default_root_dir=tmp_path,
397+
enable_progress_bar=False,
398+
enable_model_summary=False,
399+
max_epochs=-1,
400+
max_steps=30, # Use max_steps as stopping condition
401+
limit_train_batches=5,
402+
limit_val_batches=0,
403+
callbacks=[swa_callback],
404+
logger=False,
405+
)
406+
assert trainer.max_epochs == -1
407+
assert trainer.fit_loop.max_epochs == -1
408+
409+
trainer.fit(model)
410+
assert trainer.current_epoch >= 5
411+
assert trainer.global_step == 30
412+
assert trainer.max_epochs == -1
413+
414+
# Verify SWA was actually applied (update_parameters should have been called)
415+
# SWA starts at epoch 2, so with 6 epochs (0-5), we should have 4 updates (epochs 2, 3, 4, 5)
416+
assert swa_callback.n_averaged is not None
417+
assert swa_callback.n_averaged > 0, "SWA should have updated parameters"
418+
419+
390420
def _backward_patch(trainer: Trainer) -> AbstractContextManager:
391421
return mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward)

0 commit comments

Comments
 (0)