Commit 0e9d81b

refactor: update torch version checks to use greater than or equal comparisons
1 parent: f8efc0f

3 files changed, +10 −6 lines


src/lightning/pytorch/utilities/imports.py

Lines changed: 1 addition & 1 deletion
@@ -28,7 +28,7 @@
 _TORCHMETRICS_GREATER_EQUAL_0_11 = RequirementCache("torchmetrics>=0.11.0")  # using new API with task
 _TORCHMETRICS_GREATER_EQUAL_1_0_0 = RequirementCache("torchmetrics>=1.0.0")
 _TORCH_EQUAL_2_8 = RequirementCache("torch>=2.8.0,<2.9.0")
-_TORCH_EQUAL_2_9 = RequirementCache("torch>=2.9.0,<2.10.0")
+_TORCH_GREATER_EQUAL_2_8 = compare_version("torch", operator.ge, "2.8.0")

 _OMEGACONF_AVAILABLE = package_available("omegaconf")
 _TORCHVISION_AVAILABLE = RequirementCache("torchvision")
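
The new flag compares the installed torch version against "2.8.0" with operator.ge, replacing the pinned `>=2.8.0,<2.9.0` / `>=2.9.0,<2.10.0` ranges so the check keeps matching future torch releases. As a rough illustration of how such a check can be built (this is not Lightning's actual `compare_version` implementation, and `_version_satisfies` is a made-up name):

    import operator
    from importlib.metadata import PackageNotFoundError, version

    from packaging.version import Version

    def _version_satisfies(package: str, op, target: str) -> bool:
        """Illustrative sketch: True if op(installed_version, target) holds."""
        try:
            installed = Version(version(package))
        except PackageNotFoundError:
            return False  # package not installed at all
        return op(installed, Version(target))

    # Same spirit as the new check above:
    _TORCH_GREATER_EQUAL_2_8 = _version_satisfies("torch", operator.ge, "2.8.0")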

tests/tests_pytorch/helpers/runif.py

Lines changed: 2 additions & 2 deletions
@@ -14,7 +14,7 @@
 import pytest

 from lightning.fabric.utilities.imports import _IS_WINDOWS
-from lightning.pytorch.utilities.imports import _TORCH_EQUAL_2_8, _TORCH_EQUAL_2_9
+from lightning.pytorch.utilities.imports import _TORCH_GREATER_EQUAL_2_8
 from lightning.pytorch.utilities.testing import _runif_reasons


@@ -27,6 +27,6 @@ def RunIf(**kwargs):
 _xfail_gloo_windows = pytest.mark.xfail(
     RuntimeError,
     strict=True,
-    condition=(_IS_WINDOWS and (_TORCH_EQUAL_2_8 or _TORCH_EQUAL_2_9)),
+    condition=(_IS_WINDOWS and _TORCH_GREATER_EQUAL_2_8),
     reason="makeDeviceForHostname(): unsupported gloo device",
 )
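
With this change the marker applies on any Windows run with torch >= 2.8 rather than only on 2.8.x/2.9.x, and `strict=True` turns an unexpected pass into a test failure. A minimal sketch of the same conditional-xfail pattern, using pytest's documented keyword form with the flag values stubbed for illustration:

    import pytest

    _IS_WINDOWS = False  # stand-ins for the real flags imported above
    _TORCH_GREATER_EQUAL_2_8 = True

    @pytest.mark.xfail(
        condition=(_IS_WINDOWS and _TORCH_GREATER_EQUAL_2_8),
        raises=RuntimeError,
        strict=True,  # an unexpected pass is reported as a failure
        reason="makeDeviceForHostname(): unsupported gloo device",
    )
    def test_gloo_device_creation():
        ...  # body that raises RuntimeError on the affected platforms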

tests/tests_pytorch/trainer/test_trainer.py

Lines changed: 7 additions & 3 deletions
@@ -55,7 +55,7 @@
 from lightning.pytorch.strategies.launchers import _MultiProcessingLauncher, _SubprocessScriptLauncher
 from lightning.pytorch.trainer.states import RunningStage, TrainerFn
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
-from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE, _TORCH_EQUAL_2_8
+from lightning.pytorch.utilities.imports import _OMEGACONF_AVAILABLE, _TORCH_GREATER_EQUAL_2_8
 from tests_pytorch.conftest import mock_cuda_count, mock_mps_count
 from tests_pytorch.helpers.datamodules import ClassifDataModule
 from tests_pytorch.helpers.runif import RunIf

@@ -1730,7 +1730,12 @@ def test_exception_when_lightning_module_is_not_set_on_trainer(fn):

 @RunIf(min_cuda_gpus=1)
 # FixMe: the memory raises to 1024 from expected 512
-@pytest.mark.xfail(AssertionError, strict=True, condition=_TORCH_EQUAL_2_8, reason="temporarily disabled for torch 2.8")
+@pytest.mark.xfail(
+    AssertionError,
+    strict=True,
+    condition=_TORCH_GREATER_EQUAL_2_8,
+    reason="temporarily disabled for torch >= 2.8",
+)
 def test_multiple_trainer_constant_memory_allocated(tmp_path):
     """This tests ensures calling the trainer several times reset the memory back to 0."""

@@ -1750,7 +1755,6 @@ def on_train_epoch_start(self, trainer, *_):
     def current_memory():
         # before measuring the memory force release any leftover allocations, including CUDA tensors
         gc.collect()
-        torch.cuda.empty_cache()
         return torch.cuda.memory_allocated(0)

     model = TestModel()
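
The last hunk drops `torch.cuda.empty_cache()` from the measurement helper: `torch.cuda.memory_allocated()` counts only memory occupied by live tensors, while `empty_cache()` merely returns cached, unoccupied blocks to the driver, so the call did not affect the measured value. A minimal sketch of the measurement pattern (the helper name `current_cuda_memory` is made up here):

    import gc

    import torch

    def current_cuda_memory(device: int = 0) -> int:
        """Bytes occupied by live tensors on `device`."""
        # collect first so tensors kept alive only by reference cycles are freed
        gc.collect()
        return torch.cuda.memory_allocated(device)

    if torch.cuda.is_available():
        baseline = current_cuda_memory()
        x = torch.empty(128, 1024, device="cuda")  # ~512 KiB of float32
        assert current_cuda_memory() > baseline
        del x
        assert current_cuda_memory() == baseline  # allocation fully released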
