Skip to content

Commit 2766727

Browse files
committed
update torch export
1 parent 82238d5 commit 2766727

File tree

3 files changed

+49
-15
lines changed

3 files changed

+49
-15
lines changed
Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,49 @@
11
from collections import defaultdict
2-
from itertools import chain
3-
from typing import DefaultDict, Dict
2+
from typing import DefaultDict, Dict, Optional, Tuple
43

5-
from bioimageio.spec.model.v0_5 import ModelDescr
4+
from bioimageio.spec.model.v0_5 import (
5+
InputAxis,
6+
ModelDescr,
7+
ParameterizedSize,
8+
SizeReference,
9+
)
10+
from torch.export import Dim
11+
from typing_extensions import assert_never
612

713

8-
def get_dynamic_axes(model_descr: ModelDescr):
9-
dynamic_axes: DefaultDict[str, Dict[int, str]] = defaultdict(dict)
10-
for d in chain(model_descr.inputs, model_descr.outputs):
14+
def get_dynamic_shapes(model_descr: ModelDescr):
15+
dynamic_shapes: DefaultDict[str, Dict[int, Optional[Dim]]] = defaultdict(dict)
16+
potential_ref_axes: Dict[str, Tuple[InputAxis, int]] = {}
17+
# add dynamic dims from parameterized input sizes (and fixed sizes as None)
18+
for d in model_descr.inputs:
1119
for i, ax in enumerate(d.axes):
12-
if not isinstance(ax.size, int):
13-
dynamic_axes[str(d.id)][i] = str(ax.id)
20+
dim_name = f"{d.id}_{ax.id}"
21+
if isinstance(ax.size, int):
22+
dim = None # fixed size (could also be left out)
23+
elif ax.size is None:
24+
dim = Dim(dim_name, min=1)
25+
elif isinstance(ax.size, ParameterizedSize):
26+
dim = Dim(dim_name, min=ax.size.min)
27+
elif isinstance(ax.size, SizeReference):
28+
continue # handled below
29+
else:
30+
assert_never(ax.size)
1431

15-
return dynamic_axes
32+
dynamic_shapes[str(d.id)][i] = dim
33+
potential_ref_axes[dim_name] = (ax, i)
34+
35+
# add dynamic dims from size references
36+
for d in model_descr.inputs:
37+
for i, ax in enumerate(d.axes):
38+
if not isinstance(ax.size, SizeReference):
39+
continue # handled above
40+
41+
dim_name_ref = f"{ax.size.tensor_id}_{ax.size.axis_id}"
42+
ax_ref, i_ref = potential_ref_axes[dim_name_ref]
43+
a = ax_ref.scale / ax.scale
44+
b = ax.size.offset
45+
dim_ref = dynamic_shapes[str(ax.size.tensor_id)][i_ref]
46+
dim = a * dim_ref + b if dim_ref is not None else None
47+
dynamic_shapes[str(d.id)][i] = dim
48+
49+
return dict(dynamic_shapes)

src/bioimageio/core/weight_converters/pytorch_to_onnx.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
from pathlib import Path
22

33
import torch
4-
54
from bioimageio.spec.model.v0_5 import ModelDescr, OnnxWeightsDescr
65

76
from .. import __version__
87
from ..backends.pytorch_backend import load_torch_model
98
from ..digest_spec import get_member_id, get_test_input_sample
109
from ..proc_setup import get_pre_and_postprocessing
11-
from ._utils_onnx import get_dynamic_axes
10+
from ._utils_onnx import get_dynamic_shapes
1211

1312

1413
def convert(
@@ -64,9 +63,10 @@ def convert(
6463
model,
6564
tuple(inputs_torch),
6665
str(output_path),
66+
dynamo=True,
6767
input_names=[str(d.id) for d in model_descr.inputs],
6868
output_names=[str(d.id) for d in model_descr.outputs],
69-
dynamic_axes=get_dynamic_axes(model_descr),
69+
dynamic_shapes=get_dynamic_shapes(model_descr),
7070
verbose=verbose,
7171
opset_version=opset_version,
7272
)

src/bioimageio/core/weight_converters/torchscript_to_onnx.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
from pathlib import Path
22

33
import torch.jit
4-
54
from bioimageio.spec.model.v0_5 import ModelDescr, OnnxWeightsDescr
65

76
from .. import __version__
87
from ..digest_spec import get_member_id, get_test_input_sample
98
from ..proc_setup import get_pre_and_postprocessing
10-
from ._utils_onnx import get_dynamic_axes
9+
from ._utils_onnx import get_dynamic_shapes
1110

1211

1312
def convert(
@@ -68,9 +67,10 @@ def convert(
6867
model, # type: ignore
6968
tuple(inputs_torch),
7069
str(output_path),
70+
dynamo=True,
7171
input_names=[str(d.id) for d in model_descr.inputs],
7272
output_names=[str(d.id) for d in model_descr.outputs],
73-
dynamic_axes=get_dynamic_axes(model_descr),
73+
dynamic_shapes=get_dynamic_shapes(model_descr),
7474
verbose=verbose,
7575
opset_version=opset_version,
7676
)

0 commit comments

Comments
 (0)