diff --git a/pyproject.toml b/pyproject.toml
index b45f60489c6fe..a63da5f246392 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -180,6 +180,7 @@ markers = [
 ]
 filterwarnings = [
     "error::FutureWarning",
+    "ignore::FutureWarning:onnxscript",  # Temporary ignore until onnxscript is updated
 ]
 xfail_strict = true
 junit_duration_report = "call"
diff --git a/requirements/fabric/base.txt b/requirements/fabric/base.txt
index 6b55dc93bb089..b905e18f9f5d6 100644
--- a/requirements/fabric/base.txt
+++ b/requirements/fabric/base.txt
@@ -4,5 +4,5 @@
 torch >=2.1.0, <2.8.0
 fsspec[http] >=2022.5.0, <2025.8.0
 packaging >=20.0, <=25.0
-typing-extensions >=4.5.0, <4.15.0
+typing-extensions >4.5.0, <4.15.0
 lightning-utilities >=0.10.0, <0.16.0
diff --git a/requirements/pytorch/base.txt b/requirements/pytorch/base.txt
index e62ebce95053d..c5157430c9e2a 100644
--- a/requirements/pytorch/base.txt
+++ b/requirements/pytorch/base.txt
@@ -7,5 +7,5 @@ PyYAML >5.4, <6.1.0
 fsspec[http] >=2022.5.0, <2025.8.0
 torchmetrics >0.7.0, <1.9.0
 packaging >=20.0, <=25.0
-typing-extensions >=4.5.0, <4.15.0
+typing-extensions >4.5.0, <4.15.0
 lightning-utilities >=0.10.0, <0.16.0
diff --git a/requirements/pytorch/docs.txt b/requirements/pytorch/docs.txt
index ac4a8df67ccf6..35cc6234ae5d2 100644
--- a/requirements/pytorch/docs.txt
+++ b/requirements/pytorch/docs.txt
@@ -4,4 +4,6 @@
 nbformat # used for generate empty notebook
 ipython[notebook] <8.19.0
 setuptools<81.0 # workaround for `error in ipython setup command: use_2to3 is invalid.`
+onnxscript >= 0.2.2, <0.4.0
+
 #-r ../../_notebooks/.actions/requires.txt
diff --git a/requirements/pytorch/test.txt b/requirements/pytorch/test.txt
index dec156747f714..d721ec7130eee 100644
--- a/requirements/pytorch/test.txt
+++ b/requirements/pytorch/test.txt
@@ -11,6 +11,7 @@ scikit-learn >0.22.1, <1.7.0
 numpy >=1.17.2, <1.27.0
 onnx >=1.12.0, <1.19.0
 onnxruntime >=1.12.0, <1.21.0
+onnxscript >= 0.2.2, <0.4.0
 psutil <7.0.1 # for `DeviceStatsMonitor`
 pandas >2.0, <2.4.0 # needed in benchmarks
 fastapi # for `ServableModuleValidator` # not setting version as re-defined in App
diff --git a/src/lightning/fabric/utilities/imports.py b/src/lightning/fabric/utilities/imports.py
index a618371d7f2b4..1962e336b3eb9 100644
--- a/src/lightning/fabric/utilities/imports.py
+++ b/src/lightning/fabric/utilities/imports.py
@@ -34,6 +34,7 @@
 _TORCH_EQUAL_2_4_0 = compare_version("torch", operator.eq, "2.4.0")
 _TORCH_GREATER_EQUAL_2_4 = compare_version("torch", operator.ge, "2.4.0")
 _TORCH_GREATER_EQUAL_2_4_1 = compare_version("torch", operator.ge, "2.4.1")
+_TORCH_GREATER_EQUAL_2_5 = compare_version("torch", operator.ge, "2.5.0")
 _TORCH_LESS_EQUAL_2_6 = compare_version("torch", operator.le, "2.6.0")
 
 _PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)
diff --git a/src/lightning/fabric/utilities/testing/_runif.py b/src/lightning/fabric/utilities/testing/_runif.py
index 6f5d933f9dae3..a2c605f411f0a 100644
--- a/src/lightning/fabric/utilities/testing/_runif.py
+++ b/src/lightning/fabric/utilities/testing/_runif.py
@@ -17,7 +17,7 @@
 from typing import Optional
 
 import torch
-from lightning_utilities.core.imports import RequirementCache, compare_version
+from lightning_utilities.core.imports import compare_version
 from packaging.version import Version
 
 from lightning.fabric.accelerators import XLAAccelerator
@@ -112,9 +112,7 @@ def _runif_reasons(
         reasons.append("Standalone execution")
         kwargs["standalone"] = True
 
-    if deepspeed and not (
-        _DEEPSPEED_AVAILABLE and not _TORCH_GREATER_EQUAL_2_4 and RequirementCache(module="deepspeed.utils")
-    ):
+    if deepspeed and not (_DEEPSPEED_AVAILABLE and not _TORCH_GREATER_EQUAL_2_4):
         reasons.append("Deepspeed")
 
     if dynamo:
diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index b18b2d51ccf2c..14301c47281fa 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -16,7 +16,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
--
+- Allow returning `ONNXProgram` when calling `to_onnx(dynamo=True)` ([#20811](https://github.com/Lightning-AI/pytorch-lightning/pull/20811))
 
 ### Removed
 
diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py
index 905cc9a67c85a..6b8d9a368b001 100644
--- a/src/lightning/pytorch/core/module.py
+++ b/src/lightning/pytorch/core/module.py
@@ -47,6 +47,7 @@
 from lightning.fabric.utilities.apply_func import convert_to_tensors
 from lightning.fabric.utilities.cloud_io import get_filesystem
 from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_5
 from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
 from lightning.fabric.wrappers import _FabricOptimizer
 from lightning.pytorch.callbacks.callback import Callback
@@ -60,7 +61,7 @@
 from lightning.pytorch.trainer.connectors.logger_connector.result import _get_default_dtype
 from lightning.pytorch.utilities import GradClipAlgorithmType
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
-from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_0_9_1
+from lightning.pytorch.utilities.imports import _TORCH_GREATER_EQUAL_2_6, _TORCHMETRICS_GREATER_EQUAL_0_9_1
 from lightning.pytorch.utilities.model_helpers import _restricted_classmethod
 from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_warn
 from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature
@@ -72,10 +73,17 @@
     OptimizerLRScheduler,
 )
 
+_ONNX_AVAILABLE = RequirementCache("onnx")
+_ONNXSCRIPT_AVAILABLE = RequirementCache("onnxscript")
+
 if TYPE_CHECKING:
     from torch.distributed.device_mesh import DeviceMesh
 
-_ONNX_AVAILABLE = RequirementCache("onnx")
+    if _TORCH_GREATER_EQUAL_2_5:
+        if _TORCH_GREATER_EQUAL_2_6:
+            from torch.onnx import ONNXProgram
+        else:
+            from torch.onnx._internal.exporter import ONNXProgram  # type: ignore[no-redef]
 
 warning_cache = WarningCache()
 log = logging.getLogger(__name__)
@@ -1386,12 +1394,18 @@ def _verify_is_manual_optimization(self, fn_name: str) -> None:
         )
 
     @torch.no_grad()
-    def to_onnx(self, file_path: Union[str, Path, BytesIO], input_sample: Optional[Any] = None, **kwargs: Any) -> None:
+    def to_onnx(
+        self,
+        file_path: Union[str, Path, BytesIO, None] = None,
+        input_sample: Optional[Any] = None,
+        **kwargs: Any,
+    ) -> Optional["ONNXProgram"]:
         """Saves the model in ONNX format.
 
         Args:
-            file_path: The path of the file the onnx model should be saved to.
+            file_path: The path of the file the onnx model should be saved to. Default: None (no file saved).
             input_sample: An input for tracing. Default: None (Use self.example_input_array)
+            **kwargs: Will be passed to the ``torch.onnx.export`` function.
 
         Example::
 
@@ -1412,6 +1426,12 @@ def forward(self, x):
         if not _ONNX_AVAILABLE:
             raise ModuleNotFoundError(f"`{type(self).__name__}.to_onnx()` requires `onnx` to be installed.")
 
+        if kwargs.get("dynamo", False) and not (_ONNXSCRIPT_AVAILABLE and _TORCH_GREATER_EQUAL_2_5):
+            raise ModuleNotFoundError(
+                f"`{type(self).__name__}.to_onnx(dynamo=True)` "
+                "requires `onnxscript` and `torch>=2.5.0` to be installed."
+            )
+
         mode = self.training
 
         if input_sample is None:
@@ -1428,8 +1448,9 @@ def forward(self, x):
         file_path = str(file_path) if isinstance(file_path, Path) else file_path
         # PyTorch (2.5) declares file_path to be str | PathLike[Any] | None, but
         # BytesIO does work, too.
-        torch.onnx.export(self, input_sample, file_path, **kwargs)  # type: ignore
+        ret = torch.onnx.export(self, input_sample, file_path, **kwargs)  # type: ignore
         self.train(mode)
+        return ret
 
     @torch.no_grad()
     def to_torchscript(
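For orientation (not part of the patch): a minimal usage sketch of what the `module.py` change above enables, assuming `onnx`, `onnxscript`, and `torch>=2.5.0` are installed. The parity check mirrors the tests further down.

```python
import torch

from lightning.pytorch.demos.boring_classes import BoringModel

model = BoringModel()
model.example_input_array = torch.randn(5, 32)
model.eval()

# `dynamo` is forwarded to `torch.onnx.export`, which returns an `ONNXProgram`
# instead of `None`; with the new signature, `file_path` may be omitted entirely.
program = model.to_onnx(dynamo=True)

# The returned program is directly callable, so its output can be compared
# against eager PyTorch without spinning up an onnxruntime session.
torch_out = model(model.example_input_array)
onnx_out = program(model.example_input_array)[0]
assert torch.allclose(torch_out, onnx_out, rtol=1e-03, atol=1e-05)
```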
""" @@ -96,4 +98,7 @@ def _runif_reasons( if onnx and not _ONNX_AVAILABLE: reasons.append("onnx") + if onnxscript and not _ONNXSCRIPT_AVAILABLE: + reasons.append("onnxscript") + return reasons, kwargs diff --git a/tests/tests_pytorch/models/test_onnx.py b/tests/tests_pytorch/models/test_onnx.py index 81fd5631a3400..9f51332fbbdfa 100644 --- a/tests/tests_pytorch/models/test_onnx.py +++ b/tests/tests_pytorch/models/test_onnx.py @@ -13,6 +13,7 @@ # limitations under the License. import operator import os +import re from io import BytesIO from pathlib import Path from unittest.mock import patch @@ -25,7 +26,9 @@ import tests_pytorch.helpers.pipelines as tpipes from lightning.pytorch import Trainer +from lightning.pytorch.core.module import _ONNXSCRIPT_AVAILABLE from lightning.pytorch.demos.boring_classes import BoringModel +from lightning.pytorch.utilities.imports import _TORCH_GREATER_EQUAL_2_6 from tests_pytorch.helpers.runif import RunIf from tests_pytorch.utilities.test_model_summary import UnorderedModel @@ -139,8 +142,16 @@ def test_error_if_no_input(tmp_path): model.to_onnx(file_path) +@pytest.mark.parametrize( + "dynamo", + [ + None, + pytest.param(False, marks=RunIf(min_torch="2.5.0", dynamo=True, onnxscript=True)), + pytest.param(True, marks=RunIf(min_torch="2.5.0", dynamo=True, onnxscript=True)), + ], +) @RunIf(onnx=True) -def test_if_inference_output_is_valid(tmp_path): +def test_if_inference_output_is_valid(tmp_path, dynamo): """Test that the output inferred from ONNX model is same as from PyTorch.""" model = BoringModel() model.example_input_array = torch.randn(5, 32) @@ -153,7 +164,12 @@ def test_if_inference_output_is_valid(tmp_path): torch_out = model(model.example_input_array) file_path = os.path.join(tmp_path, "model.onnx") - model.to_onnx(file_path, model.example_input_array, export_params=True) + kwargs = { + "export_params": True, + } + if dynamo is not None: + kwargs["dynamo"] = dynamo + model.to_onnx(file_path, model.example_input_array, **kwargs) ort_kwargs = {"providers": "CPUExecutionProvider"} if compare_version("onnxruntime", operator.ge, "1.16.0") else {} ort_session = onnxruntime.InferenceSession(file_path, **ort_kwargs) @@ -167,3 +183,53 @@ def to_numpy(tensor): # compare ONNX Runtime and PyTorch results assert np.allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05) + + +@RunIf(min_torch="2.5.0", dynamo=True) +@pytest.mark.skipif(_ONNXSCRIPT_AVAILABLE, reason="Run this test only if onnxscript is not available.") +def test_model_onnx_export_missing_onnxscript(): + """Test that an error is raised if onnxscript is not available.""" + model = BoringModel() + model.example_input_array = torch.randn(5, 32) + + with pytest.raises( + ModuleNotFoundError, + match=re.escape( + f"`{type(model).__name__}.to_onnx(dynamo=True)` requires `onnxscript` and `torch>=2.5.0` to be installed.", + ), + ): + model.to_onnx(dynamo=True) + + +@RunIf(onnx=True, min_torch="2.5.0", dynamo=True, onnxscript=True) +def test_model_return_type(): + if _TORCH_GREATER_EQUAL_2_6: + from torch.onnx import ONNXProgram + else: + from torch.onnx._internal.exporter import ONNXProgram + + model = BoringModel() + model.example_input_array = torch.randn((1, 32)) + model.eval() + + onnx_pg = model.to_onnx(dynamo=True) + assert isinstance(onnx_pg, ONNXProgram) + + model_ret = model(model.example_input_array) + inf_ret = onnx_pg(model.example_input_array) + assert torch.allclose(model_ret, inf_ret[0], rtol=1e-03, atol=1e-05) + + +@RunIf(max_torch="2.5.0") +def 
+@RunIf(max_torch="2.5.0")
+def test_model_onnx_export_wrong_torch_version():
+    """Test that an error is raised if the installed torch version does not support `dynamo=True`."""
+    model = BoringModel()
+    model.example_input_array = torch.randn(5, 32)
+
+    with pytest.raises(
+        ModuleNotFoundError,
+        match=re.escape(
+            f"`{type(model).__name__}.to_onnx(dynamo=True)` requires `onnxscript` and `torch>=2.5.0` to be installed.",
+        ),
+    ):
+        model.to_onnx(dynamo=True)
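A usage note on the file-less export path exercised above (not part of the patch): when `file_path` is omitted, the export artifact lives only on the returned program, which can be serialized afterwards via the `save` method carried by the `ONNXProgram` class on both import paths shown in the patch. A minimal sketch; the output filename is hypothetical.

```python
import torch

from lightning.pytorch.demos.boring_classes import BoringModel

model = BoringModel()
model.example_input_array = torch.randn(5, 32)
model.eval()

# `file_path=None` (the new default) skips writing any file during export.
program = model.to_onnx(dynamo=True)

# Serialize once a destination is known ("boring_model.onnx" is a hypothetical path).
program.save("boring_model.onnx")
```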