Skip to content

Commit 87bdfd9

Browse files
committed
import
1 parent eeb1a13 commit 87bdfd9

File tree

3 files changed

+18
-12
lines changed

3 files changed

+18
-12
lines changed

src/lightning/pytorch/core/module.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@
6161
from lightning.pytorch.trainer.connectors.logger_connector.result import _get_default_dtype
6262
from lightning.pytorch.utilities import GradClipAlgorithmType
6363
from lightning.pytorch.utilities.exceptions import MisconfigurationException
64-
from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_0_9_1
64+
from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_0_9_1, _TORCH_GREATER_EQUAL_2_6
6565
from lightning.pytorch.utilities.model_helpers import _restricted_classmethod
6666
from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_warn
6767
from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature
@@ -73,14 +73,16 @@
7373
OptimizerLRScheduler,
7474
)
7575

76-
if TYPE_CHECKING:
77-
from torch.distributed.device_mesh import DeviceMesh
78-
7976
_ONNX_AVAILABLE = RequirementCache("onnx")
8077
_ONNXSCRIPT_AVAILABLE = RequirementCache("onnxscript")
8178

82-
if TYPE_CHECKING and _ONNXSCRIPT_AVAILABLE:
83-
from torch.onnx import ONNXProgram
79+
if TYPE_CHECKING:
80+
from torch.distributed.device_mesh import DeviceMesh
81+
82+
if _TORCH_GREATER_EQUAL_2_6:
83+
from torch.onnx import ONNXProgram
84+
else:
85+
from torch.onnx._internal.exporter import ONNXProgram
8486

8587
warning_cache = WarningCache()
8688
log = logging.getLogger(__name__)

src/lightning/pytorch/utilities/imports.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@
1414
"""General utilities."""
1515

1616
import functools
17+
import operator
1718
import sys
1819

19-
from lightning_utilities.core.imports import RequirementCache, package_available
20+
from lightning_utilities.core.imports import RequirementCache, package_available, compare_version
2021

2122
from lightning.pytorch.utilities.rank_zero import rank_zero_warn
2223

2324
_PYTHON_GREATER_EQUAL_3_11_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 11)
25+
_TORCH_GREATER_EQUAL_2_6 = compare_version("torch", operator.ge, "2.6.0")
2426
_TORCHMETRICS_GREATER_EQUAL_0_8_0 = RequirementCache("torchmetrics>=0.8.0")
2527
_TORCHMETRICS_GREATER_EQUAL_0_9_1 = RequirementCache("torchmetrics>=0.9.1")
2628
_TORCHMETRICS_GREATER_EQUAL_0_11 = RequirementCache("torchmetrics>=0.11.0") # using new API with task

tests/tests_pytorch/models/test_onnx.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,18 @@
2525
from lightning_utilities import compare_version
2626

2727
import tests_pytorch.helpers.pipelines as tpipes
28+
from lightning.pytorch.utilities.imports import _TORCH_GREATER_EQUAL_2_6
2829
from lightning.pytorch import Trainer
2930
from lightning.pytorch.core.module import _ONNXSCRIPT_AVAILABLE
3031
from lightning.pytorch.demos.boring_classes import BoringModel
3132
from tests_pytorch.helpers.runif import RunIf
3233
from tests_pytorch.utilities.test_model_summary import UnorderedModel
3334

35+
if _TORCH_GREATER_EQUAL_2_6:
36+
from torch.onnx import ONNXProgram
37+
else:
38+
from torch.onnx._internal.exporter import ONNXProgram
39+
3440

3541
@RunIf(onnx=True)
3642
def test_model_saves_with_input_sample(tmp_path):
@@ -207,14 +213,10 @@ def test_model_return_type():
207213
model.eval()
208214

209215
onnx_pg = model.to_onnx(dynamo=True)
210-
211-
onnx_cls = torch.onnx.ONNXProgram if torch.__version__ >= "2.6.0" else torch.onnx._internal.exporter.ONNXProgram
212-
213-
assert isinstance(onnx_pg, onnx_cls)
216+
assert isinstance(onnx_pg, ONNXProgram)
214217

215218
model_ret = model(model.example_input_array)
216219
inf_ret = onnx_pg(model.example_input_array)
217-
218220
assert torch.allclose(model_ret, inf_ret[0], rtol=1e-03, atol=1e-05)
219221

220222

0 commit comments

Comments
 (0)