Skip to content

Commit a17d7d6

Browse files
committed
update torch to onnx conversion
1 parent 1ed1c6f commit a17d7d6

File tree

4 files changed

+200
-128
lines changed

4 files changed

+200
-128
lines changed

src/bioimageio/core/weight_converters/_utils_onnx.py

Lines changed: 0 additions & 49 deletions
This file was deleted.
Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
"""helper to export both TorchScript or PytorchStateDict to ONNX"""
2+
3+
from collections import defaultdict
4+
from itertools import chain
5+
from pathlib import Path
6+
from typing import TYPE_CHECKING, DefaultDict, Dict, List, Literal, Tuple, Union
7+
8+
import torch
9+
from bioimageio.spec.model.v0_5 import (
10+
FileDescr,
11+
InputAxis,
12+
ModelDescr,
13+
OnnxWeightsDescr,
14+
ParameterizedSize,
15+
SizeReference,
16+
)
17+
from loguru import logger
18+
from typing_extensions import assert_never
19+
20+
from .. import __version__
21+
from ..digest_spec import get_member_id, get_test_input_sample
22+
from ..proc_setup import get_pre_and_postprocessing
23+
24+
if TYPE_CHECKING:
25+
import torch.jit
26+
from torch.export.dynamic_shapes import (
27+
_DimHint as DimHint, # pyright: ignore[reportPrivateUsage]
28+
)
29+
30+
31+
def _get_dynamic_axes_noop(model_descr: ModelDescr):
32+
"""noop for dynamo=True which uses `get_dynamic_shapes` instead"""
33+
34+
return None
35+
36+
37+
def _get_dynamic_axes_impl(model_descr: ModelDescr):
38+
"""dynamic axes for (old) onnx export with dynamo=False"""
39+
dynamic_axes: DefaultDict[str, Dict[int, str]] = defaultdict(dict)
40+
for d in chain(model_descr.inputs, model_descr.outputs):
41+
for i, ax in enumerate(d.axes):
42+
if not isinstance(ax.size, int):
43+
dynamic_axes[str(d.id)][i] = str(ax.id)
44+
45+
return dynamic_axes
46+
47+
48+
try:
49+
from torch.export import Dim
50+
51+
STATIC_DIM = Dim.STATIC if hasattr(Dim, "STATIC") else None
52+
TensorDim = Union[Dim, "DimHint", None]
53+
54+
except Exception as e:
55+
use_dynamo = False
56+
logger.info(f"Not using torch dynamo for ONNX export due to:\n{e}")
57+
58+
def _get_dynamic_shapes_noop(model_descr: ModelDescr):
59+
"""noop for dynamo=False which uses `get_dynamic_axes` instead"""
60+
61+
return None
62+
63+
get_dynamic_shapes = _get_dynamic_shapes_noop
64+
get_dynamic_axes = _get_dynamic_axes_impl
65+
else:
66+
use_dynamo = True
67+
logger.info("Using torch dynamo for ONNX export")
68+
69+
def _get_dynamic_shapes_impl(model_descr: ModelDescr):
70+
"""Get dynamic shapes for torch dynamo export"""
71+
# dynamic shapes as list to match the source code which may have
72+
# different arg names than the tensor ids in the model description
73+
74+
dynamic_shapes: List[Dict[int, TensorDim]] = []
75+
potential_ref_axes: Dict[str, Tuple[InputAxis, int]] = {}
76+
# add dynamic dims from parameterized input sizes (and fixed sizes as None)
77+
for d in model_descr.inputs:
78+
dynamic_tensor_dims: Dict[int, TensorDim] = {}
79+
for i, ax in enumerate(d.axes):
80+
dim_name = f"{d.id}_{ax.id}"
81+
if isinstance(ax.size, int):
82+
dim = STATIC_DIM # fixed size
83+
elif ax.size is None:
84+
dim = Dim(dim_name, min=1)
85+
elif isinstance(ax.size, ParameterizedSize):
86+
dim = Dim(dim_name, min=ax.size.min)
87+
elif isinstance(ax.size, SizeReference):
88+
continue # handled below
89+
else:
90+
assert_never(ax.size)
91+
92+
dynamic_tensor_dims[i] = dim
93+
potential_ref_axes[dim_name] = (ax, i)
94+
95+
dynamic_shapes.append(dynamic_tensor_dims)
96+
97+
# add dynamic dims from size references
98+
for d, dynamic_tensor_dims in zip(model_descr.inputs, dynamic_shapes):
99+
for i, ax in enumerate(d.axes):
100+
if not isinstance(ax.size, SizeReference):
101+
continue # handled above
102+
103+
dim_name_ref = f"{ax.size.tensor_id}_{ax.size.axis_id}"
104+
ax_ref, i_ref = potential_ref_axes[dim_name_ref]
105+
dim_ref = dynamic_tensor_dims[i_ref]
106+
if isinstance(dim_ref, Dim):
107+
a = ax_ref.scale / ax.scale
108+
b = ax.size.offset
109+
dim = a * dim_ref + b
110+
else:
111+
dim = STATIC_DIM
112+
113+
dynamic_tensor_dims[i] = dim
114+
115+
return dynamic_shapes
116+
117+
get_dynamic_shapes = _get_dynamic_shapes_impl
118+
get_dynamic_axes = _get_dynamic_axes_noop
119+
120+
121+
def export_to_onnx(
122+
model_descr: ModelDescr,
123+
model: Union[torch.nn.Module, "torch.jit.ScriptModule"],
124+
output_path: Path,
125+
verbose: bool,
126+
opset_version: int,
127+
parent: Literal["torchscript", "pytorch_state_dict"],
128+
) -> OnnxWeightsDescr:
129+
sample = get_test_input_sample(model_descr)
130+
procs = get_pre_and_postprocessing(
131+
model_descr, dataset_for_initial_statistics=[sample]
132+
)
133+
procs.pre(sample)
134+
inputs_numpy = [
135+
sample.members[get_member_id(ipt)].data.data for ipt in model_descr.inputs
136+
]
137+
inputs_torch = [torch.from_numpy(ipt) for ipt in inputs_numpy]
138+
139+
save_weights_externally = use_dynamo
140+
with torch.no_grad():
141+
outputs_original_torch = model(*inputs_torch)
142+
if isinstance(outputs_original_torch, torch.Tensor):
143+
outputs_original_torch = [outputs_original_torch]
144+
145+
_ = torch.onnx.export(
146+
model,
147+
tuple(inputs_torch),
148+
str(output_path),
149+
dynamo=use_dynamo,
150+
external_data=save_weights_externally,
151+
input_names=[str(d.id) for d in model_descr.inputs],
152+
output_names=[str(d.id) for d in model_descr.outputs],
153+
dynamic_axes=get_dynamic_axes(model_descr),
154+
dynamic_shapes=get_dynamic_shapes(model_descr),
155+
verbose=verbose,
156+
opset_version=opset_version,
157+
)
158+
159+
if save_weights_externally:
160+
external_data_path = output_path.with_suffix(
161+
output_path.suffix + ".data"
162+
).absolute()
163+
if not external_data_path.exists():
164+
raise FileNotFoundError(
165+
f"Expected external data file at {external_data_path} not found."
166+
)
167+
external_data_descr = FileDescr(source=external_data_path)
168+
else:
169+
external_data_descr = None
170+
171+
return OnnxWeightsDescr(
172+
source=output_path.absolute(),
173+
external_data=external_data_descr,
174+
parent=parent,
175+
opset_version=opset_version,
176+
comment=f"Converted with bioimageio.core {__version__}, dynamo={use_dynamo}.",
177+
)
Lines changed: 10 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,17 @@
11
from pathlib import Path
22

3-
import torch
43
from bioimageio.spec.model.v0_5 import ModelDescr, OnnxWeightsDescr
54

6-
from .. import __version__
75
from ..backends.pytorch_backend import load_torch_model
8-
from ..digest_spec import get_member_id, get_test_input_sample
9-
from ..proc_setup import get_pre_and_postprocessing
10-
from ._utils_onnx import get_dynamic_shapes
6+
from ._utils_torch_onnx import export_to_onnx
117

128

139
def convert(
1410
model_descr: ModelDescr,
1511
output_path: Path,
1612
*,
1713
verbose: bool = False,
18-
opset_version: int = 15,
14+
opset_version: int = 18,
1915
) -> OnnxWeightsDescr:
2016
"""
2117
Convert model weights from the Torchscript state_dict format to the ONNX format.
@@ -28,14 +24,14 @@ def convert(
2824
verbose:
2925
If True, will print out detailed information during the ONNX export process. Defaults to False.
3026
opset_version:
31-
The ONNX opset version to use for the export. Defaults to 15.
27+
The ONNX opset version to use for the export. Defaults to 18.
3228
3329
Raises:
3430
ValueError:
3531
If the provided model does not have weights in the PyTorch state_dict format.
3632
3733
Returns:
38-
A descriptor object that contains information about the exported ONNX weights.
34+
A description of the exported ONNX weights.
3935
"""
4036

4137
state_dict_weights_descr = model_descr.weights.pytorch_state_dict
@@ -44,36 +40,13 @@ def convert(
4440
"The provided model does not have weights in the pytorch state dict format"
4541
)
4642

47-
sample = get_test_input_sample(model_descr)
48-
procs = get_pre_and_postprocessing(
49-
model_descr, dataset_for_initial_statistics=[sample]
50-
)
51-
procs.pre(sample)
52-
inputs_numpy = [
53-
sample.members[get_member_id(ipt)].data.data for ipt in model_descr.inputs
54-
]
55-
inputs_torch = [torch.from_numpy(ipt) for ipt in inputs_numpy]
5643
model = load_torch_model(state_dict_weights_descr, load_state=True)
57-
with torch.no_grad():
58-
outputs_original_torch = model(*inputs_torch)
59-
if isinstance(outputs_original_torch, torch.Tensor):
60-
outputs_original_torch = [outputs_original_torch]
61-
62-
_ = torch.onnx.export(
63-
model,
64-
tuple(inputs_torch),
65-
str(output_path),
66-
dynamo=True,
67-
input_names=[str(d.id) for d in model_descr.inputs],
68-
output_names=[str(d.id) for d in model_descr.outputs],
69-
dynamic_shapes=get_dynamic_shapes(model_descr),
70-
verbose=verbose,
71-
opset_version=opset_version,
72-
)
7344

74-
return OnnxWeightsDescr(
75-
source=output_path.absolute(),
45+
return export_to_onnx(
46+
model_descr,
47+
model,
48+
output_path,
49+
verbose,
50+
opset_version,
7651
parent="pytorch_state_dict",
77-
opset_version=opset_version,
78-
comment=f"Converted with bioimageio.core {__version__}.",
7952
)

0 commit comments

Comments
 (0)