Skip to content

Commit e3891d5

Browse files
committed
improve dynamo export and update onnx adapter
(for external weight data files)
1 parent c80ba2f commit e3891d5

File tree

3 files changed

+73
-18
lines changed

3 files changed

+73
-18
lines changed

src/bioimageio/core/backends/onnx_backend.py

Lines changed: 60 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
# pyright: reportUnknownVariableType=false
2+
import shutil
3+
import tempfile
24
import warnings
5+
from pathlib import Path
36
from typing import Any, List, Optional, Sequence, Union
47

58
import onnxruntime as rt # pyright: ignore[reportMissingTypeStubs]
6-
from numpy.typing import NDArray
7-
89
from bioimageio.spec.model import v0_4, v0_5
10+
from loguru import logger
11+
from numpy.typing import NDArray
912

1013
from ..model_adapters import ModelAdapter
1114
from ..utils._type_guards import is_list, is_tuple
@@ -20,11 +23,63 @@ def __init__(
2023
):
2124
super().__init__(model_description=model_description)
2225

23-
if model_description.weights.onnx is None:
26+
onnx_descr = model_description.weights.onnx
27+
if onnx_descr is None:
2428
raise ValueError("No ONNX weights specified for {model_description.name}")
2529

26-
reader = model_description.weights.onnx.get_reader()
27-
self._session = rt.InferenceSession(reader.read())
30+
providers = None
31+
if hasattr(rt, "get_available_providers"):
32+
providers = rt.get_available_providers()
33+
34+
if (
35+
isinstance(onnx_descr, v0_5.OnnxWeightsDescr)
36+
and onnx_descr.external_data is not None
37+
):
38+
src = onnx_descr.source.absolute()
39+
src_data = onnx_descr.external_data.source.absolute()
40+
if (
41+
isinstance(src, Path)
42+
and isinstance(src_data, Path)
43+
and src.parent == src_data.parent
44+
):
45+
logger.debug(
46+
"Loading ONNX model with external data from {}",
47+
src.parent,
48+
)
49+
self._session = rt.InferenceSession(
50+
src,
51+
providers=providers, # pyright: ignore[reportUnknownArgumentType]
52+
)
53+
else:
54+
src_reader = onnx_descr.get_reader()
55+
src_data_reader = onnx_descr.external_data.get_reader()
56+
with tempfile.TemporaryDirectory() as tmpdir:
57+
logger.debug(
58+
"Loading ONNX model with external data from {}",
59+
tmpdir,
60+
)
61+
src = Path(tmpdir) / src_reader.original_file_name
62+
src_data = Path(tmpdir) / src_data_reader.original_file_name
63+
with src.open("wb") as f:
64+
shutil.copyfileobj(src_reader, f)
65+
with src_data.open("wb") as f:
66+
shutil.copyfileobj(src_data_reader, f)
67+
68+
self._session = rt.InferenceSession(
69+
src,
70+
providers=providers, # pyright: ignore[reportUnknownArgumentType]
71+
)
72+
else:
73+
# load single source file from bytes (without external data, so probably <2GB)
74+
logger.debug(
75+
"Loading ONNX model from bytes (read from {})", onnx_descr.source
76+
)
77+
reader = onnx_descr.get_reader()
78+
self._session = rt.InferenceSession(
79+
reader.read(),
80+
providers=providers, # pyright: ignore[reportUnknownArgumentType]
81+
)
82+
2883
onnx_inputs = self._session.get_inputs()
2984
self._input_names: List[str] = [ipt.name for ipt in onnx_inputs]
3085

src/bioimageio/core/weight_converters/_utils_torch_onnx.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import torch
99
from bioimageio.spec.model.v0_5 import (
10+
BatchAxis,
1011
FileDescr,
1112
InputAxis,
1213
ModelDescr,
@@ -71,17 +72,17 @@ def _get_dynamic_shapes_impl(model_descr: ModelDescr):
7172
# dynamic shapes as list to match the source code which may have
7273
# different arg names than the tensor ids in the model description
7374

74-
dynamic_shapes: List[Dict[int, TensorDim]] = []
75+
dynamic_shapes: List[Dict[int, Union[int, TensorDim]]] = []
7576
potential_ref_axes: Dict[str, Tuple[InputAxis, int]] = {}
7677
# add dynamic dims from parameterized input sizes (and fixed sizes as None)
7778
for d in model_descr.inputs:
78-
dynamic_tensor_dims: Dict[int, TensorDim] = {}
79+
dynamic_tensor_dims: Dict[int, Union[int, TensorDim]] = {}
7980
for i, ax in enumerate(d.axes):
8081
dim_name = f"{d.id}_{ax.id}"
8182
if isinstance(ax.size, int):
82-
dim = STATIC_DIM # fixed size
83-
elif ax.size is None:
84-
dim = Dim(dim_name, min=1)
83+
dim = ax.size
84+
elif isinstance(ax, BatchAxis):
85+
dim = Dim("batch", min=1)
8586
elif isinstance(ax.size, ParameterizedSize):
8687
dim = Dim(dim_name, min=ax.size.min)
8788
elif isinstance(ax.size, SizeReference):

tests/test_weight_converters.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,15 @@
44
from pathlib import Path
55

66
import pytest
7-
8-
from bioimageio.spec import load_description
7+
from bioimageio.spec import load_model_description
98
from bioimageio.spec.model import v0_5
109

1110

1211
def test_pytorch_to_torchscript(any_torch_model, tmp_path):
1312
from bioimageio.core import test_model
1413
from bioimageio.core.weight_converters.pytorch_to_torchscript import convert
1514

16-
model_descr = load_description(any_torch_model, perform_io_checks=False)
15+
model_descr = load_model_description(any_torch_model, perform_io_checks=False)
1716
if model_descr.implemented_format_version_tuple[:2] == (0, 4):
1817
pytest.skip("cannot convert to old 0.4 format")
1918

@@ -31,9 +30,9 @@ def test_pytorch_to_onnx(convert_to_onnx, tmp_path):
3130
from bioimageio.core import test_model
3231
from bioimageio.core.weight_converters.pytorch_to_onnx import convert
3332

34-
model_descr = load_description(convert_to_onnx, format_version="latest")
33+
model_descr = load_model_description(convert_to_onnx, format_version="latest")
3534
out_path = tmp_path / "weights.onnx"
36-
opset_version = 15
35+
opset_version = 18
3736
ret_val = convert(
3837
model_descr=model_descr,
3938
output_path=out_path,
@@ -55,7 +54,7 @@ def test_keras_to_tensorflow(any_keras_model: Path, tmp_path: Path):
5554
from bioimageio.core.weight_converters.keras_to_tensorflow import convert
5655

5756
out_path = tmp_path / "weights.zip"
58-
model_descr = load_description(any_keras_model)
57+
model_descr = load_model_description(any_keras_model)
5958
ret_val = convert(model_descr, out_path)
6059

6160
assert out_path.exists()
@@ -75,7 +74,7 @@ def test_keras_to_tensorflow(any_keras_model: Path, tmp_path: Path):
7574
# def test_tensorflow_to_keras(any_tensorflow_model: Path, tmp_path: Path):
7675
# from bioimageio.core.weight_converters.tensorflow_to_keras import convert
7776

78-
# model_descr = load_description(any_tensorflow_model)
77+
# model_descr = load_model_description(any_tensorflow_model)
7978
# out_path = tmp_path / "weights.h5"
8079
# ret_val = convert(model_descr, output_path=out_path)
8180
# assert out_path.exists()
@@ -92,7 +91,7 @@ def test_keras_to_tensorflow(any_keras_model: Path, tmp_path: Path):
9291
# from bioimageio.core.weight_converters.tensorflow_to_keras import convert
9392

9493
# out_path = tmp_path / "weights.zip"
95-
# model_descr = load_description(any_tensorflow_model)
94+
# model_descr = load_model_description(any_tensorflow_model)
9695
# ret_val = convert(model_descr, out_path)
9796

9897
# assert out_path.exists()

0 commit comments

Comments
 (0)