Skip to content

Commit 7f6fdf1

Browse files
committed
WIP refactor backend libs
1 parent b7d5f98 commit 7f6fdf1

15 files changed

+555
-605
lines changed
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
import warnings
2+
from abc import abstractmethod
3+
from typing import List, Optional, Sequence, Tuple, Union, final
4+
5+
from bioimageio.spec.model import v0_4, v0_5
6+
7+
from ._model_adapter import (
8+
DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER,
9+
ModelAdapter,
10+
WeightsFormat,
11+
)
12+
from .tensor import Tensor
13+
14+
15+
def create_model_adapter(
16+
model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
17+
*,
18+
devices: Optional[Sequence[str]] = None,
19+
weight_format_priority_order: Optional[Sequence[WeightsFormat]] = None,
20+
):
21+
"""
22+
Creates model adapter based on the passed spec
23+
Note: All specific adapters should happen inside this function to prevent different framework
24+
initializations interfering with each other
25+
"""
26+
if not isinstance(model_description, (v0_4.ModelDescr, v0_5.ModelDescr)):
27+
raise TypeError(
28+
f"expected v0_4.ModelDescr or v0_5.ModelDescr, but got {type(model_description)}"
29+
)
30+
31+
weights = model_description.weights
32+
errors: List[Tuple[WeightsFormat, Exception]] = []
33+
weight_format_priority_order = (
34+
DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER
35+
if weight_format_priority_order is None
36+
else weight_format_priority_order
37+
)
38+
# limit weight formats to the ones present
39+
weight_format_priority_order = [
40+
w for w in weight_format_priority_order if getattr(weights, w) is not None
41+
]
42+
43+
for wf in weight_format_priority_order:
44+
if wf == "pytorch_state_dict" and weights.pytorch_state_dict is not None:
45+
try:
46+
from .model_adapters_old._pytorch_model_adapter import (
47+
PytorchModelAdapter,
48+
)
49+
50+
return PytorchModelAdapter(
51+
outputs=model_description.outputs,
52+
weights=weights.pytorch_state_dict,
53+
devices=devices,
54+
)
55+
except Exception as e:
56+
errors.append((wf, e))
57+
elif (
58+
wf == "tensorflow_saved_model_bundle"
59+
and weights.tensorflow_saved_model_bundle is not None
60+
):
61+
try:
62+
from .model_adapters_old._tensorflow_model_adapter import (
63+
TensorflowModelAdapter,
64+
)
65+
66+
return TensorflowModelAdapter(
67+
model_description=model_description, devices=devices
68+
)
69+
except Exception as e:
70+
errors.append((wf, e))
71+
elif wf == "onnx" and weights.onnx is not None:
72+
try:
73+
from .model_adapters_old._onnx_model_adapter import ONNXModelAdapter
74+
75+
return ONNXModelAdapter(
76+
model_description=model_description, devices=devices
77+
)
78+
except Exception as e:
79+
errors.append((wf, e))
80+
elif wf == "torchscript" and weights.torchscript is not None:
81+
try:
82+
from .model_adapters_old._torchscript_model_adapter import (
83+
TorchscriptModelAdapter,
84+
)
85+
86+
return TorchscriptModelAdapter(
87+
model_description=model_description, devices=devices
88+
)
89+
except Exception as e:
90+
errors.append((wf, e))
91+
elif wf == "keras_hdf5" and weights.keras_hdf5 is not None:
92+
# keras can either be installed as a separate package or used as part of tensorflow
93+
# we try to first import the keras model adapter using the separate package and,
94+
# if it is not available, try to load the one using tf
95+
try:
96+
from .backend.keras import (
97+
KerasModelAdapter,
98+
keras, # type: ignore
99+
)
100+
101+
if keras is None:
102+
from .model_adapters_old._tensorflow_model_adapter import (
103+
KerasModelAdapter,
104+
)
105+
106+
return KerasModelAdapter(
107+
model_description=model_description, devices=devices
108+
)
109+
except Exception as e:
110+
errors.append((wf, e))
111+
112+
assert errors
113+
if len(weight_format_priority_order) == 1:
114+
assert len(errors) == 1
115+
raise ValueError(
116+
f"The '{weight_format_priority_order[0]}' model adapter could not be created"
117+
+ f" in this environment:\n{errors[0][1].__class__.__name__}({errors[0][1]}).\n\n"
118+
) from errors[0][1]
119+
120+
else:
121+
error_list = "\n - ".join(
122+
f"{wf}: {e.__class__.__name__}({e})" for wf, e in errors
123+
)
124+
raise ValueError(
125+
"None of the weight format specific model adapters could be created"
126+
+ f" in this environment. Errors are:\n\n{error_list}.\n\n"
127+
)

bioimageio/core/_model_adapter.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
import warnings
2+
from abc import ABC, abstractmethod
3+
from typing import List, Optional, Sequence, Tuple, Union, final
4+
5+
from bioimageio.spec.model import v0_4, v0_5
6+
7+
from .tensor import Tensor
8+
9+
WeightsFormat = Union[v0_4.WeightsFormat, v0_5.WeightsFormat]
10+
11+
__all__ = [
12+
"ModelAdapter",
13+
"create_model_adapter",
14+
"get_weight_formats",
15+
]
16+
17+
# Known weight formats in order of priority
18+
# First match wins
19+
DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER: Tuple[WeightsFormat, ...] = (
20+
"pytorch_state_dict",
21+
"tensorflow_saved_model_bundle",
22+
"torchscript",
23+
"onnx",
24+
"keras_hdf5",
25+
)
26+
27+
28+
class ModelAdapter(ABC):
29+
"""
30+
Represents model *without* any preprocessing or postprocessing.
31+
32+
```
33+
from bioimageio.core import load_description
34+
35+
model = load_description(...)
36+
37+
# option 1:
38+
adapter = ModelAdapter.create(model)
39+
adapter.forward(...)
40+
adapter.unload()
41+
42+
# option 2:
43+
with ModelAdapter.create(model) as adapter:
44+
adapter.forward(...)
45+
```
46+
"""
47+
48+
@final
49+
@classmethod
50+
def create(
51+
cls,
52+
model_description: Union[v0_4.ModelDescr, v0_5.ModelDescr],
53+
*,
54+
devices: Optional[Sequence[str]] = None,
55+
weight_format_priority_order: Optional[Sequence[WeightsFormat]] = None,
56+
):
57+
"""
58+
Creates model adapter based on the passed spec
59+
Note: All specific adapters should happen inside this function to prevent different framework
60+
initializations interfering with each other
61+
"""
62+
from ._create_model_adapter import create_model_adapter
63+
64+
return create_model_adapter(
65+
model_description,
66+
devices=devices,
67+
weight_format_priority_order=weight_format_priority_order,
68+
)
69+
70+
@final
71+
def load(self, *, devices: Optional[Sequence[str]] = None) -> None:
72+
warnings.warn("Deprecated. ModelAdapter is loaded on initialization")
73+
74+
@abstractmethod
75+
def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]:
76+
"""
77+
Run forward pass of model to get model predictions
78+
"""
79+
# TODO: handle tensor.transpose in here and make _forward_impl the abstract impl
80+
81+
@abstractmethod
82+
def unload(self):
83+
"""
84+
Unload model from any devices, freeing their memory.
85+
The moder adapter should be considered unusable afterwards.
86+
"""
87+
88+
89+
def get_weight_formats() -> List[str]:
90+
"""
91+
Return list of supported weight types
92+
"""
93+
return list(DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER)

bioimageio/core/backend/__init__.py

Whitespace-only changes.

bioimageio/core/model_adapters/_keras_model_adapter.py renamed to bioimageio/core/backend/keras.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010

1111
from .._settings import settings
1212
from ..digest_spec import get_axes_infos
13+
from ..model_adapters import ModelAdapter
1314
from ..tensor import Tensor
14-
from ._model_adapter import ModelAdapter
1515

1616
os.environ["KERAS_BACKEND"] = settings.keras_backend
1717

0 commit comments

Comments
 (0)