47
47
from lightning .fabric .utilities .apply_func import convert_to_tensors
48
48
from lightning .fabric .utilities .cloud_io import get_filesystem
49
49
from lightning .fabric .utilities .device_dtype_mixin import _DeviceDtypeModuleMixin
50
+ from lightning .fabric .utilities .imports import _TORCH_GREATER_EQUAL_2_5
50
51
from lightning .fabric .utilities .types import _MAP_LOCATION_TYPE , _PATH
51
52
from lightning .fabric .wrappers import _FabricOptimizer
52
53
from lightning .pytorch .callbacks .callback import Callback
60
61
from lightning .pytorch .trainer .connectors .logger_connector .result import _get_default_dtype
61
62
from lightning .pytorch .utilities import GradClipAlgorithmType
62
63
from lightning .pytorch .utilities .exceptions import MisconfigurationException
63
- from lightning .pytorch .utilities .imports import _TORCHMETRICS_GREATER_EQUAL_0_9_1
64
+ from lightning .pytorch .utilities .imports import _TORCH_GREATER_EQUAL_2_6 , _TORCHMETRICS_GREATER_EQUAL_0_9_1
64
65
from lightning .pytorch .utilities .model_helpers import _restricted_classmethod
65
66
from lightning .pytorch .utilities .rank_zero import WarningCache , rank_zero_warn
66
67
from lightning .pytorch .utilities .signature_utils import is_param_in_hook_signature
72
73
OptimizerLRScheduler ,
73
74
)
74
75
76
+ _ONNX_AVAILABLE = RequirementCache ("onnx" )
77
+ _ONNXSCRIPT_AVAILABLE = RequirementCache ("onnxscript" )
78
+
75
79
if TYPE_CHECKING :
76
80
from torch .distributed .device_mesh import DeviceMesh
77
81
78
- _ONNX_AVAILABLE = RequirementCache ("onnx" )
82
+ if _TORCH_GREATER_EQUAL_2_5 :
83
+ if _TORCH_GREATER_EQUAL_2_6 :
84
+ from torch .onnx import ONNXProgram
85
+ else :
86
+ from torch .onnx ._internal .exporter import ONNXProgram # type: ignore[no-redef]
79
87
80
88
warning_cache = WarningCache ()
81
89
log = logging .getLogger (__name__ )
@@ -1416,12 +1424,18 @@ def _verify_is_manual_optimization(self, fn_name: str) -> None:
1416
1424
)
1417
1425
1418
1426
@torch .no_grad ()
1419
- def to_onnx (self , file_path : Union [str , Path , BytesIO ], input_sample : Optional [Any ] = None , ** kwargs : Any ) -> None :
1427
+ def to_onnx (
1428
+ self ,
1429
+ file_path : Union [str , Path , BytesIO , None ] = None ,
1430
+ input_sample : Optional [Any ] = None ,
1431
+ ** kwargs : Any ,
1432
+ ) -> Optional ["ONNXProgram" ]:
1420
1433
"""Saves the model in ONNX format.
1421
1434
1422
1435
Args:
1423
- file_path: The path of the file the onnx model should be saved to.
1436
+ file_path: The path of the file the onnx model should be saved to. Default: None (no file saved).
1424
1437
input_sample: An input for tracing. Default: None (Use self.example_input_array)
1438
+
1425
1439
**kwargs: Will be passed to torch.onnx.export function.
1426
1440
1427
1441
Example::
@@ -1442,6 +1456,12 @@ def forward(self, x):
1442
1456
if not _ONNX_AVAILABLE :
1443
1457
raise ModuleNotFoundError (f"`{ type (self ).__name__ } .to_onnx()` requires `onnx` to be installed." )
1444
1458
1459
+ if kwargs .get ("dynamo" , False ) and not (_ONNXSCRIPT_AVAILABLE and _TORCH_GREATER_EQUAL_2_5 ):
1460
+ raise ModuleNotFoundError (
1461
+ f"`{ type (self ).__name__ } .to_onnx(dynamo=True)` "
1462
+ "requires `onnxscript` and `torch>=2.5.0` to be installed."
1463
+ )
1464
+
1445
1465
mode = self .training
1446
1466
1447
1467
if input_sample is None :
@@ -1458,8 +1478,9 @@ def forward(self, x):
1458
1478
file_path = str (file_path ) if isinstance (file_path , Path ) else file_path
1459
1479
# PyTorch (2.5) declares file_path to be str | PathLike[Any] | None, but
1460
1480
# BytesIO does work, too.
1461
- torch .onnx .export (self , input_sample , file_path , ** kwargs ) # type: ignore
1481
+ ret = torch .onnx .export (self , input_sample , file_path , ** kwargs ) # type: ignore
1462
1482
self .train (mode )
1483
+ return ret
1463
1484
1464
1485
@torch .no_grad ()
1465
1486
def to_torchscript (
0 commit comments