@@ -25,6 +25,7 @@
 from torch.optim.swa_utils import SWALR
 from torch.utils.data import DataLoader
 
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_6
 from lightning.pytorch import Trainer
 from lightning.pytorch.callbacks import StochasticWeightAveraging
 from lightning.pytorch.demos.boring_classes import BoringModel, RandomDataset, RandomIterableDataset
@@ -173,8 +174,9 @@ def train_with_swa(
         devices=devices,
     )
 
+    weights_only = False if _TORCH_GREATER_EQUAL_2_6 else None
     with _backward_patch(trainer):
-        trainer.fit(model)
+        trainer.fit(model, weights_only=weights_only)
 
     # check the model is the expected
     assert trainer.lightning_module == model
@@ -307,8 +309,9 @@ def _swa_resume_training_from_checkpoint(tmp_path, model, resume_model, ddp=Fals
     }
     trainer = Trainer(callbacks=SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1), **trainer_kwargs)
 
+    weights_only = False if _TORCH_GREATER_EQUAL_2_6 else None
     with _backward_patch(trainer), pytest.raises(Exception, match="SWA crash test"):
-        trainer.fit(model)
+        trainer.fit(model, weights_only=weights_only)
 
     checkpoint_dir = Path(tmp_path) / "checkpoints"
     checkpoint_files = os.listdir(checkpoint_dir)
@@ -318,7 +321,7 @@ def _swa_resume_training_from_checkpoint(tmp_path, model, resume_model, ddp=Fals
     trainer = Trainer(callbacks=SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1), **trainer_kwargs)
 
     with _backward_patch(trainer):
-        trainer.fit(resume_model, ckpt_path=ckpt_path)
+        trainer.fit(resume_model, ckpt_path=ckpt_path, weights_only=weights_only)
 
 
 class CustomSchedulerModel(SwaTestModel):
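
For readers skimming the hunks, here is a minimal sketch of the gate this change adds; the rationale is inferred rather than stated in the diff. Starting with PyTorch 2.6, torch.load defaults to weights_only=True, which refuses to unpickle non-tensor objects, so checkpoints that carry such state (as the SWA test checkpoints presumably do) need an explicit False, while None on older torch versions leaves the previous default untouched.

from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_6

# torch >= 2.6: torch.load defaults to weights_only=True, so restoring a
# checkpoint that contains pickled scheduler/callback state needs an explicit
# weights_only=False. torch < 2.6: None keeps the old default behaviour.
weights_only = False if _TORCH_GREATER_EQUAL_2_6 else None

# The tests forward this flag through the fit calls shown in the hunks above, e.g.:
# trainer.fit(model, weights_only=weights_only)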