File tree Expand file tree Collapse file tree 3 files changed +36
-39
lines changed
Expand file tree Collapse file tree 3 files changed +36
-39
lines changed Original file line number Diff line number Diff line change 11# SPDX-License-Identifier: Apache-2.0
22"""Tests for dynamic shape support."""
33
4- import io
4+ import tempfile
55
66
77import torch
@@ -20,18 +20,17 @@ def export_to_onnx_dynamic(
2020 model .eval ()
2121 dummy_input = torch .randn (* input_shape )
2222
23- buffer = io .BytesIO ()
24- torch .onnx .export (
25- model ,
26- dummy_input ,
27- buffer ,
28- input_names = ["input" ],
29- output_names = ["output" ],
30- dynamic_axes = dynamic_axes ,
31- opset_version = 17 ,
32- )
33- buffer .seek (0 )
34- return onnx .load (buffer )
23+ with tempfile .NamedTemporaryFile (suffix = ".onnx" ) as f :
24+ torch .onnx .export (
25+ model ,
26+ dummy_input ,
27+ f .name ,
28+ input_names = ["input" ],
29+ output_names = ["output" ],
30+ dynamic_axes = dynamic_axes ,
31+ opset_version = 17 ,
32+ )
33+ return onnx .load (f .name )
3534
3635
3736class TestDynamicBatchSize :
Original file line number Diff line number Diff line change 11# SPDX-License-Identifier: Apache-2.0
22"""End-to-end tests with real models."""
33
4- import io
4+ import tempfile
55
66import onnx
77import torch
@@ -18,18 +18,17 @@ def export_to_onnx(
1818 model .eval ()
1919 dummy_input = torch .randn (* input_shape )
2020
21- buffer = io .BytesIO ()
22- torch .onnx .export (
23- model ,
24- dummy_input ,
25- buffer ,
26- opset_version = opset_version ,
27- input_names = ["input" ],
28- output_names = ["output" ],
29- dynamic_axes = None ,
30- )
31- buffer .seek (0 )
32- return onnx .load (buffer )
21+ with tempfile .NamedTemporaryFile (suffix = ".onnx" ) as f :
22+ torch .onnx .export (
23+ model ,
24+ dummy_input ,
25+ f .name ,
26+ opset_version = opset_version ,
27+ input_names = ["input" ],
28+ output_names = ["output" ],
29+ dynamic_axes = None ,
30+ )
31+ return onnx .load (f .name )
3332
3433
3534def compare_outputs (
Original file line number Diff line number Diff line change 884. Handle train/eval modes correctly
99"""
1010
11- import io
11+ import tempfile
1212
1313import onnx
1414import pytest
@@ -26,18 +26,17 @@ def export_to_onnx(
2626 model .eval ()
2727 dummy_input = torch .randn (* input_shape )
2828
29- buffer = io .BytesIO ()
30- torch .onnx .export (
31- model ,
32- dummy_input ,
33- buffer ,
34- opset_version = opset_version ,
35- input_names = ["input" ],
36- output_names = ["output" ],
37- dynamic_axes = None ,
38- )
39- buffer .seek (0 )
40- return onnx .load (buffer )
29+ with tempfile .NamedTemporaryFile (suffix = ".onnx" ) as f :
30+ torch .onnx .export (
31+ model ,
32+ dummy_input ,
33+ f .name ,
34+ opset_version = opset_version ,
35+ input_names = ["input" ],
36+ output_names = ["output" ],
37+ dynamic_axes = None ,
38+ )
39+ return onnx .load (f .name )
4140
4241
4342class SimpleMLP (nn .Module ):
You can’t perform that action at this time.
0 commit comments