Skip to content

Commit 0eea296

Browse files
committed
warn
1 parent dd19f16 commit 0eea296

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

tests/tests_pytorch/utilities/test_compile.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,14 @@
1313
# limitations under the License.
1414
import os
1515
import sys
16+
from contextlib import nullcontext
1617
from unittest import mock
1718

1819
import pytest
1920
import torch
2021
from lightning_utilities.core.imports import RequirementCache
2122

22-
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2
23+
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_2, _TORCH_GREATER_EQUAL_2_4
2324
from lightning.pytorch import LightningModule, Trainer
2425
from lightning.pytorch.demos.boring_classes import BoringModel
2526
from lightning.pytorch.utilities.compile import from_compiled, to_uncompiled
@@ -73,7 +74,13 @@ def test_trainer_compiled_model(_, tmp_path, monkeypatch, mps_count_0):
7374
mock_cuda_count(monkeypatch, 2)
7475

7576
# TODO: Update deepspeed to avoid deprecation warning for `torch.cuda.amp.custom_fwd` on import
76-
with pytest.warns(FutureWarning, match="torch.cuda.amp.*is deprecated"):
77+
warn_context = (
78+
pytest.warns(FutureWarning, match="torch.cuda.amp.*is deprecated")
79+
if _TORCH_GREATER_EQUAL_2_4
80+
else nullcontext()
81+
)
82+
83+
with warn_context:
7784
trainer = Trainer(strategy="deepspeed", accelerator="cuda", **trainer_kwargs)
7885

7986
with pytest.raises(RuntimeError, match="Using a compiled model is incompatible with the current strategy.*"):

0 commit comments

Comments
 (0)