
Commit a600f41

Commit message: update
Parent: 0e9d81b
2 files changed: 10 additions, 1 deletion

src/lightning/pytorch/utilities/imports.py (1 addition, 0 deletions)
@@ -28,6 +28,7 @@
 _TORCHMETRICS_GREATER_EQUAL_0_11 = RequirementCache("torchmetrics>=0.11.0")  # using new API with task
 _TORCHMETRICS_GREATER_EQUAL_1_0_0 = RequirementCache("torchmetrics>=1.0.0")
 _TORCH_EQUAL_2_8 = RequirementCache("torch>=2.8.0,<2.9.0")
+_TORCH_EQUAL_2_9 = RequirementCache("torch>=2.9.0,<2.10.0")
 _TORCH_GREATER_EQUAL_2_8 = compare_version("torch", operator.ge, "2.8.0")

 _OMEGACONF_AVAILABLE = package_available("omegaconf")

tests/tests_pytorch/models/test_torch_tensorrt.py (9 additions, 1 deletion)
@@ -10,6 +10,7 @@
 from lightning.pytorch.core.module import _TORCH_TRT_AVAILABLE
 from lightning.pytorch.demos.boring_classes import BoringModel
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
+from lightning.pytorch.utilities.imports import _TORCH_EQUAL_2_9
 from tests_pytorch.helpers.runif import RunIf


@@ -110,7 +111,14 @@ def test_tensorrt_saves_on_multi_gpu(tmp_path):
     [
         ("default", torch.fx.GraphModule),
         ("dynamo", torch.fx.GraphModule),
-        ("ts", torch.jit.ScriptModule),
+        pytest.param(
+            "ts",
+            torch.jit.ScriptModule,
+            marks=pytest.mark.skipif(
+                _TORCH_EQUAL_2_9,
+                reason="TorchScript IR crashes with torch_tensorrt on PyTorch 2.9",
+            ),
+        ),
     ],
 )
 @RunIf(tensorrt=True, min_cuda_gpus=1, min_torch="2.2.0")
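
For readers unfamiliar with the pattern used above, here is a self-contained sketch (illustrative names only, independent of the Lightning test suite) of how pytest.param with a skipif mark skips a single parametrized case while the other cases still run.

# Standalone sketch of the per-case skip pattern; names are illustrative.
import pytest

_CONDITION = True  # stands in for _TORCH_EQUAL_2_9 in the real test


@pytest.mark.parametrize(
    "ir",
    [
        "default",
        "dynamo",
        # only this case is skipped when the condition holds; the others run
        pytest.param("ts", marks=pytest.mark.skipif(_CONDITION, reason="example skip")),
    ],
)
def test_ir_modes(ir):
    assert ir in {"default", "dynamo", "ts"}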
