
Commit 577c04d

GdoongMathew, Borda, and pre-commit-ci[bot] authored
Allow to return ONNXProgram when calling to_onnx(dynamo=True) (#20811)
* feat: return `ONNXProgram` when exporting with dynamo=True.
* test: add to_onnx(dynamo=True) unittests.
* fix: add ignore filter in pyproject.toml
* fix: change the return type annotation of `to_onnx`.
* test: add parametrized `dynamo` to test `test_if_inference_output_is_valid`.
* test: add difference check in `test_model_return_type`.
* fix: fix unittest.
* test: add test `test_model_onnx_export_missing_onnxscript`.
* feat: enable ONNXProgram export on torch 2.5.0
* extensions

---------

Co-authored-by: Jirka B <[email protected]>
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 105bb20 commit 577c04d
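For context, a minimal usage sketch of the behaviour this commit adds, assuming `torch>=2.5.0` and `onnxscript` are installed; the `BoringModel` demo module and the input shape are illustrative choices, not taken from this commit's tests:

```python
import torch
from lightning.pytorch.demos.boring_classes import BoringModel

model = BoringModel()
model.example_input_array = torch.randn(2, 32)  # BoringModel's linear layer expects 32 features

# TorchScript-based exporter (default): writes the file and returns None, as before.
model.to_onnx("model.onnx")

# Dynamo-based exporter: file_path may now be omitted, and the call returns the
# ONNXProgram produced by torch.onnx.export(..., dynamo=True).
onnx_program = model.to_onnx(dynamo=True)
print(type(onnx_program).__name__)  # ONNXProgram
```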

File tree

12 files changed: +113 -16 lines changed


pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -180,6 +180,7 @@ markers = [
 ]
 filterwarnings = [
     "error::FutureWarning",
+    "ignore::FutureWarning:onnxscript",  # Temporary ignore until onnxscript is updated
 ]
 xfail_strict = true
 junit_duration_report = "call"

requirements/fabric/base.txt

Lines changed: 1 addition & 1 deletion
@@ -4,5 +4,5 @@
 torch >=2.1.0, <2.9.0
 fsspec[http] >=2022.5.0, <2025.8.0
 packaging >=20.0, <=25.0
-typing-extensions >=4.5.0, <4.15.0
+typing-extensions >4.5.0, <4.15.0
 lightning-utilities >=0.10.0, <0.16.0

requirements/pytorch/base.txt

Lines changed: 1 addition & 1 deletion
@@ -7,5 +7,5 @@ PyYAML >5.4, <6.1.0
 fsspec[http] >=2022.5.0, <2025.8.0
 torchmetrics >0.7.0, <1.9.0
 packaging >=20.0, <=25.0
-typing-extensions >=4.5.0, <4.15.0
+typing-extensions >4.5.0, <4.15.0
 lightning-utilities >=0.10.0, <0.16.0

requirements/pytorch/docs.txt

Lines changed: 2 additions & 0 deletions
@@ -4,4 +4,6 @@ nbformat # used for generate empty notebook
 ipython[notebook] <8.19.0
 setuptools<81.0 # workaround for `error in ipython setup command: use_2to3 is invalid.`
 
+onnxscript >= 0.2.2, <0.4.0
+
 #-r ../../_notebooks/.actions/requires.txt

requirements/pytorch/test.txt

Lines changed: 1 addition & 0 deletions
@@ -11,6 +11,7 @@ scikit-learn >0.22.1, <1.7.0
 numpy >=1.17.2, <1.27.0
 onnx >=1.12.0, <1.19.0
 onnxruntime >=1.12.0, <1.21.0
+onnxscript >= 0.2.2, <0.4.0
 psutil <7.0.1  # for `DeviceStatsMonitor`
 pandas >2.0, <2.4.0  # needed in benchmarks
 fastapi  # for `ServableModuleValidator`  # not setting version as re-defined in App

src/lightning/fabric/utilities/imports.py

Lines changed: 1 addition & 0 deletions
@@ -34,6 +34,7 @@
 _TORCH_EQUAL_2_4_0 = compare_version("torch", operator.eq, "2.4.0")
 _TORCH_GREATER_EQUAL_2_4 = compare_version("torch", operator.ge, "2.4.0")
 _TORCH_GREATER_EQUAL_2_4_1 = compare_version("torch", operator.ge, "2.4.1")
+_TORCH_GREATER_EQUAL_2_5 = compare_version("torch", operator.ge, "2.5.0")
 _TORCH_LESS_EQUAL_2_6 = compare_version("torch", operator.le, "2.6.0")
 
 _PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)

src/lightning/fabric/utilities/testing/_runif.py

Lines changed: 2 additions & 4 deletions
@@ -17,7 +17,7 @@
 from typing import Optional
 
 import torch
-from lightning_utilities.core.imports import RequirementCache, compare_version
+from lightning_utilities.core.imports import compare_version
 from packaging.version import Version
 
 from lightning.fabric.accelerators import XLAAccelerator
@@ -112,9 +112,7 @@ def _runif_reasons(
         reasons.append("Standalone execution")
         kwargs["standalone"] = True
 
-    if deepspeed and not (
-        _DEEPSPEED_AVAILABLE and not _TORCH_GREATER_EQUAL_2_4 and RequirementCache(module="deepspeed.utils")
-    ):
+    if deepspeed and not (_DEEPSPEED_AVAILABLE and not _TORCH_GREATER_EQUAL_2_4):
         reasons.append("Deepspeed")
 
     if dynamo:

src/lightning/pytorch/CHANGELOG.md

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
--
+- Allow returning `ONNXProgram` when calling `to_onnx(dynamo=True)` ([#20811](https://github.com/Lightning-AI/pytorch-lightning/pull/20811))
 
 
 ### Removed

src/lightning/pytorch/core/module.py

Lines changed: 26 additions & 5 deletions
@@ -47,6 +47,7 @@
 from lightning.fabric.utilities.apply_func import convert_to_tensors
 from lightning.fabric.utilities.cloud_io import get_filesystem
 from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
+from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_5
 from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
 from lightning.fabric.wrappers import _FabricOptimizer
 from lightning.pytorch.callbacks.callback import Callback
@@ -60,7 +61,7 @@
 from lightning.pytorch.trainer.connectors.logger_connector.result import _get_default_dtype
 from lightning.pytorch.utilities import GradClipAlgorithmType
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
-from lightning.pytorch.utilities.imports import _TORCHMETRICS_GREATER_EQUAL_0_9_1
+from lightning.pytorch.utilities.imports import _TORCH_GREATER_EQUAL_2_6, _TORCHMETRICS_GREATER_EQUAL_0_9_1
 from lightning.pytorch.utilities.model_helpers import _restricted_classmethod
 from lightning.pytorch.utilities.rank_zero import WarningCache, rank_zero_warn
 from lightning.pytorch.utilities.signature_utils import is_param_in_hook_signature
@@ -72,10 +73,17 @@
     OptimizerLRScheduler,
 )
 
+_ONNX_AVAILABLE = RequirementCache("onnx")
+_ONNXSCRIPT_AVAILABLE = RequirementCache("onnxscript")
+
 if TYPE_CHECKING:
     from torch.distributed.device_mesh import DeviceMesh
 
-_ONNX_AVAILABLE = RequirementCache("onnx")
+    if _TORCH_GREATER_EQUAL_2_5:
+        if _TORCH_GREATER_EQUAL_2_6:
+            from torch.onnx import ONNXProgram
+        else:
+            from torch.onnx._internal.exporter import ONNXProgram  # type: ignore[no-redef]
 
 warning_cache = WarningCache()
 log = logging.getLogger(__name__)
@@ -1416,12 +1424,18 @@ def _verify_is_manual_optimization(self, fn_name: str) -> None:
         )
 
     @torch.no_grad()
-    def to_onnx(self, file_path: Union[str, Path, BytesIO], input_sample: Optional[Any] = None, **kwargs: Any) -> None:
+    def to_onnx(
+        self,
+        file_path: Union[str, Path, BytesIO, None] = None,
+        input_sample: Optional[Any] = None,
+        **kwargs: Any,
+    ) -> Optional["ONNXProgram"]:
         """Saves the model in ONNX format.
 
         Args:
-            file_path: The path of the file the onnx model should be saved to.
+            file_path: The path of the file the onnx model should be saved to. Default: None (no file saved).
             input_sample: An input for tracing. Default: None (Use self.example_input_array)
+
             **kwargs: Will be passed to torch.onnx.export function.
 
         Example::
@@ -1442,6 +1456,12 @@ def forward(self, x):
         if not _ONNX_AVAILABLE:
             raise ModuleNotFoundError(f"`{type(self).__name__}.to_onnx()` requires `onnx` to be installed.")
 
+        if kwargs.get("dynamo", False) and not (_ONNXSCRIPT_AVAILABLE and _TORCH_GREATER_EQUAL_2_5):
+            raise ModuleNotFoundError(
+                f"`{type(self).__name__}.to_onnx(dynamo=True)` "
+                "requires `onnxscript` and `torch>=2.5.0` to be installed."
+            )
+
         mode = self.training
 
         if input_sample is None:
@@ -1458,8 +1478,9 @@ def forward(self, x):
         file_path = str(file_path) if isinstance(file_path, Path) else file_path
         # PyTorch (2.5) declares file_path to be str | PathLike[Any] | None, but
         # BytesIO does work, too.
-        torch.onnx.export(self, input_sample, file_path, **kwargs)  # type: ignore
+        ret = torch.onnx.export(self, input_sample, file_path, **kwargs)  # type: ignore
         self.train(mode)
+        return ret
 
     @torch.no_grad()
     def to_torchscript(
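A hedged sketch of why returning the `ONNXProgram` is useful, assuming the `save()` and `model_proto` members exposed by recent torch releases (they are not part of this diff):

```python
# `model` is any LightningModule with example_input_array set (see the sketch above).
onnx_program = model.to_onnx(dynamo=True)  # nothing is written to disk yet
onnx_program.save("model_dynamo.onnx")     # serialize once the exported program has been inspected
proto = onnx_program.model_proto           # or work with the in-memory onnx.ModelProto directly
```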

src/lightning/pytorch/utilities/imports.py

Lines changed: 3 additions & 1 deletion
@@ -14,13 +14,15 @@
 """General utilities."""
 
 import functools
+import operator
 import sys
 
-from lightning_utilities.core.imports import RequirementCache, package_available
+from lightning_utilities.core.imports import RequirementCache, compare_version, package_available
 
 from lightning.pytorch.utilities.rank_zero import rank_zero_warn
 
 _PYTHON_GREATER_EQUAL_3_11_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 11)
+_TORCH_GREATER_EQUAL_2_6 = compare_version("torch", operator.ge, "2.6.0")
 _TORCHMETRICS_GREATER_EQUAL_0_8_0 = RequirementCache("torchmetrics>=0.8.0")
 _TORCHMETRICS_GREATER_EQUAL_0_9_1 = RequirementCache("torchmetrics>=0.9.1")
 _TORCHMETRICS_GREATER_EQUAL_0_11 = RequirementCache("torchmetrics>=0.11.0")  # using new API with task
