File tree Expand file tree Collapse file tree 3 files changed +18
-12
lines changed
tests/tests_pytorch/models Expand file tree Collapse file tree 3 files changed +18
-12
lines changed Original file line number Diff line number Diff line change 6161from lightning .pytorch .trainer .connectors .logger_connector .result import _get_default_dtype
6262from lightning .pytorch .utilities import GradClipAlgorithmType
6363from lightning .pytorch .utilities .exceptions import MisconfigurationException
64- from lightning .pytorch .utilities .imports import _TORCHMETRICS_GREATER_EQUAL_0_9_1
64+ from lightning .pytorch .utilities .imports import _TORCHMETRICS_GREATER_EQUAL_0_9_1 , _TORCH_GREATER_EQUAL_2_6
6565from lightning .pytorch .utilities .model_helpers import _restricted_classmethod
6666from lightning .pytorch .utilities .rank_zero import WarningCache , rank_zero_warn
6767from lightning .pytorch .utilities .signature_utils import is_param_in_hook_signature
7373 OptimizerLRScheduler ,
7474)
7575
76- if TYPE_CHECKING :
77- from torch .distributed .device_mesh import DeviceMesh
78-
7976_ONNX_AVAILABLE = RequirementCache ("onnx" )
8077_ONNXSCRIPT_AVAILABLE = RequirementCache ("onnxscript" )
8178
82- if TYPE_CHECKING and _ONNXSCRIPT_AVAILABLE :
83- from torch .onnx import ONNXProgram
79+ if TYPE_CHECKING :
80+ from torch .distributed .device_mesh import DeviceMesh
81+
82+ if _TORCH_GREATER_EQUAL_2_6 :
83+ from torch .onnx import ONNXProgram
84+ else :
85+ from torch .onnx ._internal .exporter import ONNXProgram
8486
8587warning_cache = WarningCache ()
8688log = logging .getLogger (__name__ )
Original file line number Diff line number Diff line change 1414"""General utilities."""
1515
1616import functools
17+ import operator
1718import sys
1819
19- from lightning_utilities .core .imports import RequirementCache , package_available
20+ from lightning_utilities .core .imports import RequirementCache , package_available , compare_version
2021
2122from lightning .pytorch .utilities .rank_zero import rank_zero_warn
2223
2324_PYTHON_GREATER_EQUAL_3_11_0 = (sys .version_info .major , sys .version_info .minor ) >= (3 , 11 )
25+ _TORCH_GREATER_EQUAL_2_6 = compare_version ("torch" , operator .ge , "2.6.0" )
2426_TORCHMETRICS_GREATER_EQUAL_0_8_0 = RequirementCache ("torchmetrics>=0.8.0" )
2527_TORCHMETRICS_GREATER_EQUAL_0_9_1 = RequirementCache ("torchmetrics>=0.9.1" )
2628_TORCHMETRICS_GREATER_EQUAL_0_11 = RequirementCache ("torchmetrics>=0.11.0" ) # using new API with task
Original file line number Diff line number Diff line change 2525from lightning_utilities import compare_version
2626
2727import tests_pytorch .helpers .pipelines as tpipes
28+ from lightning .pytorch .utilities .imports import _TORCH_GREATER_EQUAL_2_6
2829from lightning .pytorch import Trainer
2930from lightning .pytorch .core .module import _ONNXSCRIPT_AVAILABLE
3031from lightning .pytorch .demos .boring_classes import BoringModel
3132from tests_pytorch .helpers .runif import RunIf
3233from tests_pytorch .utilities .test_model_summary import UnorderedModel
3334
35+ if _TORCH_GREATER_EQUAL_2_6 :
36+ from torch .onnx import ONNXProgram
37+ else :
38+ from torch .onnx ._internal .exporter import ONNXProgram
39+
3440
3541@RunIf (onnx = True )
3642def test_model_saves_with_input_sample (tmp_path ):
@@ -207,14 +213,10 @@ def test_model_return_type():
207213 model .eval ()
208214
209215 onnx_pg = model .to_onnx (dynamo = True )
210-
211- onnx_cls = torch .onnx .ONNXProgram if torch .__version__ >= "2.6.0" else torch .onnx ._internal .exporter .ONNXProgram
212-
213- assert isinstance (onnx_pg , onnx_cls )
216+ assert isinstance (onnx_pg , ONNXProgram )
214217
215218 model_ret = model (model .example_input_array )
216219 inf_ret = onnx_pg (model .example_input_array )
217-
218220 assert torch .allclose (model_ret , inf_ret [0 ], rtol = 1e-03 , atol = 1e-05 )
219221
220222
You can’t perform that action at this time.
0 commit comments