Skip to content

Commit 8f050ea

Browse files
committed
feat: enable ONNXProgram export on torch 2.5.0
1 parent 1396f35 commit 8f050ea

File tree

3 files changed

+31
-8
lines changed

3 files changed

+31
-8
lines changed

src/lightning/fabric/utilities/imports.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
_TORCH_EQUAL_2_4_0 = compare_version("torch", operator.eq, "2.4.0")
3535
_TORCH_GREATER_EQUAL_2_4 = compare_version("torch", operator.ge, "2.4.0")
3636
_TORCH_GREATER_EQUAL_2_4_1 = compare_version("torch", operator.ge, "2.4.1")
37+
_TORCH_GREATER_EQUAL_2_5 = compare_version("torch", operator.ge, "2.5.0")
3738
_TORCH_LESS_EQUAL_2_6 = compare_version("torch", operator.le, "2.6.0")
3839

3940
_PYTHON_GREATER_EQUAL_3_10_0 = (sys.version_info.major, sys.version_info.minor) >= (3, 10)

src/lightning/pytorch/core/module.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
from lightning.fabric.utilities.apply_func import convert_to_tensors
4848
from lightning.fabric.utilities.cloud_io import get_filesystem
4949
from lightning.fabric.utilities.device_dtype_mixin import _DeviceDtypeModuleMixin
50+
from lightning.fabric.utilities.imports import _TORCH_GREATER_EQUAL_2_5
5051
from lightning.fabric.utilities.types import _MAP_LOCATION_TYPE, _PATH
5152
from lightning.fabric.wrappers import _FabricOptimizer
5253
from lightning.pytorch.callbacks.callback import Callback
@@ -1394,9 +1395,10 @@ def forward(self, x):
13941395
if not _ONNX_AVAILABLE:
13951396
raise ModuleNotFoundError(f"`{type(self).__name__}.to_onnx()` requires `onnx` to be installed.")
13961397

1397-
if kwargs.get("dynamo", False) and not _ONNXSCRIPT_AVAILABLE:
1398+
if kwargs.get("dynamo", False) and not (_ONNXSCRIPT_AVAILABLE and _TORCH_GREATER_EQUAL_2_5):
13981399
raise ModuleNotFoundError(
1399-
f"`{type(self).__name__}.to_onnx(dynamo=True)` requires `onnxscript` to be installed."
1400+
f"`{type(self).__name__}.to_onnx(dynamo=True)` "
1401+
"requires `onnxscript` and `torch>=2.5.0` to be installed."
14001402
)
14011403

14021404
mode = self.training

tests/tests_pytorch/models/test_onnx.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,8 @@ def test_error_if_no_input(tmp_path):
145145
"dynamo",
146146
[
147147
None,
148-
pytest.param(False, marks=RunIf(min_torch="2.6.0", dynamo=True, onnxscript=True)),
149-
pytest.param(True, marks=RunIf(min_torch="2.6.0", dynamo=True, onnxscript=True)),
148+
pytest.param(False, marks=RunIf(min_torch="2.5.0", dynamo=True, onnxscript=True)),
149+
pytest.param(True, marks=RunIf(min_torch="2.5.0", dynamo=True, onnxscript=True)),
150150
],
151151
)
152152
@RunIf(onnx=True)
@@ -184,7 +184,7 @@ def to_numpy(tensor):
184184
assert np.allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)
185185

186186

187-
@RunIf(min_torch="2.6.0", dynamo=True)
187+
@RunIf(min_torch="2.5.0", dynamo=True)
188188
@pytest.mark.skipif(_ONNXSCRIPT_AVAILABLE, reason="Run this test only if onnxscript is not available.")
189189
def test_model_onnx_export_missing_onnxscript():
190190
"""Test that an error is raised if onnxscript is not available."""
@@ -193,21 +193,41 @@ def test_model_onnx_export_missing_onnxscript():
193193

194194
with pytest.raises(
195195
ModuleNotFoundError,
196-
match=re.escape(f"`{type(model).__name__}.to_onnx(dynamo=True)` requires `onnxscript` to be installed."),
196+
match=re.escape(
197+
f"`{type(model).__name__}.to_onnx(dynamo=True)` requires `onnxscript` and `torch>=2.5.0` to be installed.",
198+
),
197199
):
198200
model.to_onnx(dynamo=True)
199201

200202

201-
@RunIf(onnx=True, min_torch="2.6.0", dynamo=True, onnxscript=True)
203+
@RunIf(onnx=True, min_torch="2.5.0", dynamo=True, onnxscript=True)
202204
def test_model_return_type():
203205
model = BoringModel()
204206
model.example_input_array = torch.randn((1, 32))
205207
model.eval()
206208

207209
onnx_pg = model.to_onnx(dynamo=True)
208-
assert isinstance(onnx_pg, torch.onnx.ONNXProgram)
210+
211+
onnx_cls = torch.onnx.ONNXProgram if torch.__version__ >= "2.6.0" else torch.onnx._internal.exporter.ONNXProgram
212+
213+
assert isinstance(onnx_pg, onnx_cls)
209214

210215
model_ret = model(model.example_input_array)
211216
inf_ret = onnx_pg(model.example_input_array)
212217

213218
assert torch.allclose(model_ret, inf_ret[0], rtol=1e-03, atol=1e-05)
219+
220+
221+
@RunIf(max_torch="2.5.0")
222+
def test_model_onnx_export_wrong_torch_version():
223+
"""Test that an error is raised if onnxscript is not available."""
224+
model = BoringModel()
225+
model.example_input_array = torch.randn(5, 32)
226+
227+
with pytest.raises(
228+
ModuleNotFoundError,
229+
match=re.escape(
230+
f"`{type(model).__name__}.to_onnx(dynamo=True)` requires `onnxscript` and `torch>=2.5.0` to be installed.",
231+
),
232+
):
233+
model.to_onnx(dynamo=True)

0 commit comments

Comments
 (0)