5555from lightning .pytorch .strategies .launchers import _MultiProcessingLauncher , _SubprocessScriptLauncher
5656from lightning .pytorch .trainer .states import RunningStage , TrainerFn
5757from lightning .pytorch .utilities .exceptions import MisconfigurationException
58- from lightning .pytorch .utilities .imports import _OMEGACONF_AVAILABLE , _TORCH_EQUAL_2_8
58+ from lightning .pytorch .utilities .imports import _OMEGACONF_AVAILABLE , _TORCH_GREATER_EQUAL_2_8
5959from tests_pytorch .conftest import mock_cuda_count , mock_mps_count
6060from tests_pytorch .helpers .datamodules import ClassifDataModule
6161from tests_pytorch .helpers .runif import RunIf
@@ -1730,7 +1730,12 @@ def test_exception_when_lightning_module_is_not_set_on_trainer(fn):
17301730
17311731@RunIf (min_cuda_gpus = 1 )
17321732# FixMe: the memory raises to 1024 from expected 512
1733- @pytest .mark .xfail (AssertionError , strict = True , condition = _TORCH_EQUAL_2_8 , reason = "temporarily disabled for torch 2.8" )
1733+ @pytest .mark .xfail (
1734+ AssertionError ,
1735+ strict = True ,
1736+ condition = _TORCH_GREATER_EQUAL_2_8 ,
1737+ reason = "temporarily disabled for torch >= 2.8" ,
1738+ )
17341739def test_multiple_trainer_constant_memory_allocated (tmp_path ):
17351740 """This tests ensures calling the trainer several times reset the memory back to 0."""
17361741
@@ -1750,7 +1755,6 @@ def on_train_epoch_start(self, trainer, *_):
17501755 def current_memory ():
17511756 # before measuring the memory force release any leftover allocations, including CUDA tensors
17521757 gc .collect ()
1753- torch .cuda .empty_cache ()
17541758 return torch .cuda .memory_allocated (0 )
17551759
17561760 model = TestModel ()
0 commit comments