Skip to content

Commit cdc1113

Browse files
Copilotmshr-h
andauthored
Replace BytesIO with tempfile for ONNX export in tests (#5)
* Initial plan * fix: replace BytesIO with tempfile for ONNX export to fix DeprecationWarning Co-authored-by: mshr-h <8973217+mshr-h@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: mshr-h <8973217+mshr-h@users.noreply.github.com>
1 parent 6a6651f commit cdc1113

File tree

3 files changed

+36
-39
lines changed

3 files changed

+36
-39
lines changed

tests/test_dynamic_shapes.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# SPDX-License-Identifier: Apache-2.0
22
"""Tests for dynamic shape support."""
33

4-
import io
4+
import tempfile
55

66

77
import torch
@@ -20,18 +20,17 @@ def export_to_onnx_dynamic(
2020
model.eval()
2121
dummy_input = torch.randn(*input_shape)
2222

23-
buffer = io.BytesIO()
24-
torch.onnx.export(
25-
model,
26-
dummy_input,
27-
buffer,
28-
input_names=["input"],
29-
output_names=["output"],
30-
dynamic_axes=dynamic_axes,
31-
opset_version=17,
32-
)
33-
buffer.seek(0)
34-
return onnx.load(buffer)
23+
with tempfile.NamedTemporaryFile(suffix=".onnx") as f:
24+
torch.onnx.export(
25+
model,
26+
dummy_input,
27+
f.name,
28+
input_names=["input"],
29+
output_names=["output"],
30+
dynamic_axes=dynamic_axes,
31+
opset_version=17,
32+
)
33+
return onnx.load(f.name)
3534

3635

3736
class TestDynamicBatchSize:

tests/test_e2e_models.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# SPDX-License-Identifier: Apache-2.0
22
"""End-to-end tests with real models."""
33

4-
import io
4+
import tempfile
55

66
import onnx
77
import torch
@@ -18,18 +18,17 @@ def export_to_onnx(
1818
model.eval()
1919
dummy_input = torch.randn(*input_shape)
2020

21-
buffer = io.BytesIO()
22-
torch.onnx.export(
23-
model,
24-
dummy_input,
25-
buffer,
26-
opset_version=opset_version,
27-
input_names=["input"],
28-
output_names=["output"],
29-
dynamic_axes=None,
30-
)
31-
buffer.seek(0)
32-
return onnx.load(buffer)
21+
with tempfile.NamedTemporaryFile(suffix=".onnx") as f:
22+
torch.onnx.export(
23+
model,
24+
dummy_input,
25+
f.name,
26+
opset_version=opset_version,
27+
input_names=["input"],
28+
output_names=["output"],
29+
dynamic_axes=None,
30+
)
31+
return onnx.load(f.name)
3332

3433

3534
def compare_outputs(

tests/test_training.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
4. Handle train/eval modes correctly
99
"""
1010

11-
import io
11+
import tempfile
1212

1313
import onnx
1414
import pytest
@@ -26,18 +26,17 @@ def export_to_onnx(
2626
model.eval()
2727
dummy_input = torch.randn(*input_shape)
2828

29-
buffer = io.BytesIO()
30-
torch.onnx.export(
31-
model,
32-
dummy_input,
33-
buffer,
34-
opset_version=opset_version,
35-
input_names=["input"],
36-
output_names=["output"],
37-
dynamic_axes=None,
38-
)
39-
buffer.seek(0)
40-
return onnx.load(buffer)
29+
with tempfile.NamedTemporaryFile(suffix=".onnx") as f:
30+
torch.onnx.export(
31+
model,
32+
dummy_input,
33+
f.name,
34+
opset_version=opset_version,
35+
input_names=["input"],
36+
output_names=["output"],
37+
dynamic_axes=None,
38+
)
39+
return onnx.load(f.name)
4140

4241

4342
class SimpleMLP(nn.Module):

0 commit comments

Comments
 (0)