
Commit 42559fe

weights_only=False for torch>=2.6 in tests
1 parent d61b9ec commit 42559fe

3 files changed: 13 additions & 6 deletions
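Background for the change (not part of the diff): starting with torch 2.6, torch.load defaults to weights_only=True, which refuses to unpickle arbitrary Python objects such as the custom callback state and non-primitive hyperparameters these tests store in their checkpoints, so the tests now pass weights_only=False explicitly on newer torch versions. A minimal sketch of the underlying behavior, assuming a hypothetical checkpoint file "demo.ckpt" that carries pickled callback state:

import torch

# On torch>=2.6 the default is weights_only=True, which only allows tensors and a
# small set of safe types; checkpoints carrying pickled Python objects are rejected.
try:
    ckpt = torch.load("demo.ckpt")
except Exception:
    # Opting out restores full unpickling; only do this for checkpoints you trust.
    ckpt = torch.load("demo.ckpt", weights_only=False)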

tests/tests_pytorch/callbacks/test_callbacks.py

Lines changed: 3 additions & 1 deletion
@@ -18,6 +18,7 @@
 import pytest
 from lightning_utilities.test.warning import no_warning_call

+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_6
 from lightning.pytorch import Callback, Trainer
 from lightning.pytorch.callbacks import ModelCheckpoint
 from lightning.pytorch.demos.boring_classes import BoringModel
@@ -132,7 +133,8 @@ def test_resume_callback_state_saved_by_type_stateful(tmp_path):

     callback = OldStatefulCallback(state=222)
     trainer = Trainer(default_root_dir=tmp_path, max_steps=2, callbacks=[callback])
-    trainer.fit(model, ckpt_path=ckpt_path)
+    weights_only = False if _TORCH_GREATER_EQUAL_2_6 else None
+    trainer.fit(model, ckpt_path=ckpt_path, weights_only=weights_only)
     assert callback.state == 111

tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py

Lines changed: 6 additions & 3 deletions
@@ -25,6 +25,7 @@
 from torch.optim.swa_utils import SWALR
 from torch.utils.data import DataLoader

+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_6
 from lightning.pytorch import Trainer
 from lightning.pytorch.callbacks import StochasticWeightAveraging
 from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset
@@ -173,8 +174,9 @@ def train_with_swa(
         devices=devices,
     )

+    weights_only = False if _TORCH_GREATER_EQUAL_2_6 else None
     with _backward_patch(trainer):
-        trainer.fit(model)
+        trainer.fit(model, weights_only=weights_only)

     # check the model is the expected
     assert trainer.lightning_module == model
@@ -307,8 +309,9 @@ def _swa_resume_training_from_checkpoint(tmp_path, model, resume_model, ddp=Fals
     }
     trainer = Trainer(callbacks=SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1), **trainer_kwargs)

+    weights_only = False if _TORCH_GREATER_EQUAL_2_6 else None
     with _backward_patch(trainer), pytest.raises(Exception, match="SWA crash test"):
-        trainer.fit(model)
+        trainer.fit(model, weights_only=weights_only)

     checkpoint_dir = Path(tmp_path) / "checkpoints"
     checkpoint_files = os.listdir(checkpoint_dir)
@@ -318,7 +321,7 @@ def _swa_resume_training_from_checkpoint(tmp_path, model, resume_model, ddp=Fals
     trainer = Trainer(callbacks=SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1), **trainer_kwargs)

     with _backward_patch(trainer):
-        trainer.fit(resume_model, ckpt_path=ckpt_path)
+        trainer.fit(resume_model, ckpt_path=ckpt_path, weights_only=weights_only)


 class CustomSchedulerModel(SwaTestModel):

tests/tests_pytorch/models/test_hparams.py

Lines changed: 4 additions & 2 deletions
@@ -30,6 +30,7 @@
 from lightning_utilities.test.warning import no_warning_call
 from torch.utils.data import DataLoader

+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_6
 from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.callbacks import ModelCheckpoint
 from lightning.pytorch.core.datamodule import LightningDataModule
@@ -748,8 +749,9 @@ def test_model_with_fsspec_as_parameter(tmp_path):
     trainer = Trainer(
         default_root_dir=tmp_path, limit_train_batches=2, limit_val_batches=2, limit_test_batches=2, max_epochs=1
     )
-    trainer.fit(model)
-    trainer.test()
+    weights_only = False if _TORCH_GREATER_EQUAL_2_6 else None
+    trainer.fit(model, weights_only=weights_only)
+    trainer.test(weights_only=weights_only)


 @pytest.mark.xfail(
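All three files gate the argument the same way. If the _TORCH_GREATER_EQUAL_2_6 flag from lightning.fabric.utilities.imports were not available, an equivalent check could be derived locally; the following is a hypothetical sketch, not the definition Lightning uses:

import torch
from packaging.version import Version

# Hypothetical local equivalent of the imported flag; Lightning's own definition
# may be computed differently.
_TORCH_GREATER_EQUAL_2_6 = Version(torch.__version__).release >= (2, 6)

# None keeps the Trainer's default behavior on older torch; torch>=2.6 needs an
# explicit opt-out from the new weights_only=True loading default.
weights_only = False if _TORCH_GREATER_EQUAL_2_6 else None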
