Skip to content

Commit 828fd99

Browse files
authored
Re-enable passing BytesIO as path in .to_onnx() (#20172)
1 parent be0ae06 commit 828fd99

File tree

2 files changed

+9
-2
lines changed

2 files changed

+9
-2
lines changed

src/lightning/pytorch/core/module.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import numbers
1818
import weakref
1919
from contextlib import contextmanager
20+
from io import BytesIO
2021
from pathlib import Path
2122
from typing import (
2223
IO,
@@ -1364,7 +1365,7 @@ def _verify_is_manual_optimization(self, fn_name: str) -> None:
13641365
)
13651366

13661367
@torch.no_grad()
1367-
def to_onnx(self, file_path: Union[str, Path], input_sample: Optional[Any] = None, **kwargs: Any) -> None:
1368+
def to_onnx(self, file_path: Union[str, Path, BytesIO], input_sample: Optional[Any] = None, **kwargs: Any) -> None:
13681369
"""Saves the model in ONNX format.
13691370
13701371
Args:
@@ -1403,7 +1404,8 @@ def forward(self, x):
14031404
input_sample = self._on_before_batch_transfer(input_sample)
14041405
input_sample = self._apply_batch_transfer_handler(input_sample)
14051406

1406-
torch.onnx.export(self, input_sample, str(file_path), **kwargs)
1407+
file_path = str(file_path) if isinstance(file_path, Path) else file_path
1408+
torch.onnx.export(self, input_sample, file_path, **kwargs)
14071409
self.train(mode)
14081410

14091411
@torch.no_grad()

tests/tests_pytorch/models/test_onnx.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
# limitations under the License.
1414
import operator
1515
import os
16+
from io import BytesIO
1617
from pathlib import Path
1718
from unittest.mock import patch
1819

@@ -45,6 +46,10 @@ def test_model_saves_with_input_sample(tmp_path):
4546
assert os.path.isfile(file_path)
4647
assert os.path.getsize(file_path) > 4e2
4748

49+
file_path = BytesIO()
50+
model.to_onnx(file_path=file_path, input_sample=input_sample)
51+
assert len(file_path.getvalue()) > 4e2
52+
4853

4954
@pytest.mark.parametrize(
5055
"accelerator", [pytest.param("mps", marks=RunIf(mps=True)), pytest.param("gpu", marks=RunIf(min_cuda_gpus=True))]

0 commit comments

Comments
 (0)