Skip to content

Commit aaa4c22

Browse files
authored
Merge pull request #155 from bioimage-io/unload
add unload for PredictionPipeline and ModelAdapters
2 parents 29e28d7 + 30ce61c commit aaa4c22

File tree

9 files changed

+276
-52
lines changed

9 files changed

+276
-52
lines changed

bioimageio/core/prediction_pipeline/_model_adapters/_keras_model_adapter.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,26 @@
11
import warnings
2-
from typing import List, Optional
2+
from typing import List, Optional, Sequence
33

44
import keras
55
import xarray as xr
66

7-
from bioimageio.core.resource_io import nodes
87
from ._model_adapter import ModelAdapter
98

109

1110
class KerasModelAdapter(ModelAdapter):
12-
def __init__(self, *, bioimageio_model: nodes.Model, devices: Optional[List[str]] = None):
11+
def _load(self, *, devices: Optional[Sequence[str]] = None) -> None:
1312
# TODO keras device management
1413
if devices is not None:
15-
warnings.warn(f"Device management is not implemented for tensorflow yet, ignoring the devices {devices}")
14+
warnings.warn(f"Device management is not implemented for keras yet, ignoring the devices {devices}")
1615

17-
weight_file = bioimageio_model.weights["keras_hdf5"].source
16+
weight_file = self.bioimageio_model.weights["keras_hdf5"].source
1817
self._model = keras.models.load_model(weight_file)
19-
self._output_axes = [tuple(out.axes) for out in bioimageio_model.outputs]
18+
self._output_axes = [tuple(out.axes) for out in self.bioimageio_model.outputs]
2019

21-
def forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]:
20+
def _unload(self) -> None:
21+
warnings.warn("Device management is not implemented for keras yet, cannot unload model")
22+
23+
def _forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]:
2224
result = self._model.predict(*input_tensors)
2325
if not isinstance(result, (tuple, list)):
2426
result = [result]

bioimageio/core/prediction_pipeline/_model_adapters/_model_adapter.py

Lines changed: 62 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import abc
2-
from typing import List, Optional, Type
2+
from typing import List, Optional, Sequence, Type
33

44
import xarray as xr
55

66
from bioimageio.core.resource_io import nodes
77

8-
#: Known weigh types in order of priority
8+
#: Known weight formats in order of priority
99
#: First match wins
1010
_WEIGHT_FORMATS = ["pytorch_state_dict", "tensorflow_saved_model_bundle", "pytorch_script", "onnx", "keras_hdf5"]
1111

@@ -15,19 +15,75 @@ class ModelAdapter(abc.ABC):
1515
Represents model *without* any preprocessing and postprocessing
1616
"""
1717

18+
def __init__(self, *, bioimageio_model: nodes.Model, devices: Optional[Sequence[str]] = None):
19+
self.bioimageio_model = bioimageio_model
20+
self.default_devices = devices
21+
self.loaded = False
22+
23+
def __enter__(self):
24+
"""load on entering context"""
25+
assert not self.loaded
26+
self.load() # using default_devices
27+
return self
28+
29+
def __exit__(self, exc_type, exc_val, exc_tb):
30+
"""unload on exiting context"""
31+
assert self.loaded
32+
self.unload()
33+
return False
34+
35+
def load(self, *, devices: Optional[Sequence[str]] = None) -> None:
36+
"""
37+
Note: Use ModelAdapter as context to not worry about calling unload()!
38+
Load model onto devices. If devices is None, self.default_devices are chosen
39+
(which may be None as well, in which case a framework dependent default is chosen)
40+
"""
41+
self._load(devices=devices or self.default_devices)
42+
self.loaded = True
43+
1844
@abc.abstractmethod
19-
def __init__(self, *, bioimageio_model: nodes.Model, devices=Optional[List[str]]):
45+
def _load(self, *, devices: Optional[Sequence[str]] = None) -> None:
46+
"""
47+
Load model onto devices. If devices is None a framework dependent default is chosen
48+
"""
2049
...
2150

22-
@abc.abstractmethod
2351
def forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]:
52+
"""
53+
Load model if unloaded/outside context; then run forward pass of model to get model predictions
54+
"""
55+
if not self.loaded:
56+
self.load()
57+
58+
assert self.loaded
59+
return self._forward(*input_tensors)
60+
61+
@abc.abstractmethod
62+
def _forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]:
2463
"""
2564
Run forward pass of model to get model predictions
2665
Note: model is responsible for converting its data representation to
2766
xarray.DataArray
2867
"""
2968
...
3069

70+
def unload(self):
71+
"""
72+
Unload model from any devices, freeing their memory.
73+
Note: Use ModelAdapter as context to not worry about calling unload()!
74+
"""
75+
# implementation of non-state-machine logic in _unload()
76+
assert self.loaded
77+
self._unload()
78+
self.loaded = False
79+
80+
@abc.abstractmethod
81+
def _unload(self) -> None:
82+
"""
83+
Unload model from any devices, freeing their memory.
84+
"""
85+
...
86+
3187

3288
def get_weight_formats() -> List[str]:
3389
"""
@@ -37,14 +93,13 @@ def get_weight_formats() -> List[str]:
3793

3894

3995
def create_model_adapter(
40-
*, bioimageio_model: nodes.Model, devices=Optional[List[str]], weight_format: Optional[str] = None
96+
*, bioimageio_model: nodes.Model, devices=Optional[Sequence[str]], weight_format: Optional[str] = None
4197
) -> ModelAdapter:
4298
"""
4399
Creates model adapter based on the passed spec
44100
Note: All specific adapters should happen inside this function to prevent different framework
45101
initializations interfering with each other
46102
"""
47-
spec = bioimageio_model
48103
weights = bioimageio_model.weights
49104
weight_formats = get_weight_formats()
50105

@@ -59,7 +114,7 @@ def create_model_adapter(
59114
return adapter_cls(bioimageio_model=bioimageio_model, devices=devices)
60115

61116
raise RuntimeError(
62-
f"weight format {weight_format} not among weight formats listed in model: {list(spec.weights.keys())}"
117+
f"weight format {weight_format} not among weight formats listed in model: {list(bioimageio_model.weights.keys())}"
63118
)
64119

65120

bioimageio/core/prediction_pipeline/_model_adapters/_onnx_model_adapter.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,30 +5,30 @@
55
import onnxruntime as rt
66
import xarray as xr
77

8-
from bioimageio.core.resource_io import nodes
98
from ._model_adapter import ModelAdapter
109

1110
logger = logging.getLogger(__name__)
1211

1312

1413
class ONNXModelAdapter(ModelAdapter):
15-
def __init__(self, *, bioimageio_model: nodes.Model, devices: Optional[List[str]] = None):
16-
spec = bioimageio_model
14+
def _load(self, *, devices: Optional[List[str]] = None):
15+
self._internal_output_axes = [tuple(out.axes) for out in self.bioimageio_model.outputs]
1716

18-
self._internal_output_axes = [tuple(out.axes) for out in bioimageio_model.outputs]
19-
20-
self._session = rt.InferenceSession(str(spec.weights["onnx"].source))
17+
self._session = rt.InferenceSession(str(self.bioimageio_model.weights["onnx"].source))
2118
onnx_inputs = self._session.get_inputs()
2219
self._input_names = [ipt.name for ipt in onnx_inputs]
2320

2421
if devices is not None:
2522
warnings.warn(f"Device management is not implemented for onnx yet, ignoring the devices {devices}")
2623

27-
def forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]:
24+
def _forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]:
2825
assert len(input_tensors) == len(self._input_names)
2926
input_arrays = [ipt.data for ipt in input_tensors]
3027
result = self._session.run(None, dict(zip(self._input_names, input_arrays)))
3128
if not isinstance(result, (list, tuple)):
3229
result = []
3330

3431
return [xr.DataArray(r, dims=axes) for r, axes in zip(result, self._internal_output_axes)]
32+
33+
def _unload(self) -> None:
34+
warnings.warn("Device management is not implemented for onnx yet, cannot unload model")

bioimageio/core/prediction_pipeline/_model_adapters/_pytorch_model_adapter.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import gc
2+
import warnings
13
from typing import List, Optional
24

35
import torch
@@ -9,24 +11,28 @@
911

1012

1113
class PytorchModelAdapter(ModelAdapter):
12-
def __init__(self, *, bioimageio_model: nodes.Model, devices: Optional[List[str]] = None):
13-
self._model = self.get_nn_instance(bioimageio_model)
14+
def _load(self, *, devices: Optional[List[str]] = None):
15+
self._model = self.get_nn_instance(self.bioimageio_model)
1416

1517
if devices is None:
1618
self._devices = ["cuda" if torch.cuda.is_available() else "cpu"]
1719
else:
1820
self._devices = [torch.device(d) for d in devices]
21+
22+
if len(self._devices) > 1:
23+
warnings.warn("Multiple devices for single pytorch model not yet implemented")
24+
1925
self._model.to(self._devices[0])
2026

2127
assert isinstance(self._model, torch.nn.Module)
22-
weights = bioimageio_model.weights.get("pytorch_state_dict")
28+
weights = self.bioimageio_model.weights.get("pytorch_state_dict")
2329
if weights is not None and weights.source:
2430
state = torch.load(weights.source, map_location=self._devices[0])
2531
self._model.load_state_dict(state)
2632

27-
self._internal_output_axes = [tuple(out.axes) for out in bioimageio_model.outputs]
33+
self._internal_output_axes = [tuple(out.axes) for out in self.bioimageio_model.outputs]
2834

29-
def forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]:
35+
def _forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]:
3036
with torch.no_grad():
3137
tensors = [torch.from_numpy(ipt.data) for ipt in input_tensors]
3238
tensors = [t.to(self._devices[0]) for t in tensors]
@@ -38,6 +44,12 @@ def forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]:
3844

3945
return [xr.DataArray(r, dims=axes) for r, axes in zip(result, self._internal_output_axes)]
4046

47+
def _unload(self) -> None:
48+
self._devices = None
49+
del self._model
50+
gc.collect() # deallocate memory
51+
torch.cuda.empty_cache() # release reserved memory
52+
4153
@staticmethod
4254
def get_nn_instance(model_node: nodes.Model, **kwargs):
4355
assert isinstance(model_node.source, nodes.ImportedSource)

bioimageio/core/prediction_pipeline/_model_adapters/_tensorflow_model_adapter.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,15 @@
99
from bioimageio.core.resource_io import nodes
1010
from ._model_adapter import ModelAdapter
1111

12+
try:
13+
from typing import Literal
14+
except ImportError:
15+
from typing_extensions import Literal # type: ignore
16+
1217

1318
class TensorflowModelAdapterBase(ModelAdapter):
19+
weight_format: Literal["keras_hdf5", "tensorflow_saved_model_bundle"]
20+
1421
def require_unzipped(self, weight_file):
1522
if zipfile.is_zipfile(weight_file):
1623
out_path = weight_file.with_suffix("")
@@ -27,30 +34,28 @@ def _load_model(self, weight_file):
2734
# NOTE in tf1 the model needs to be loaded inside of the session, so we cannot preload the model
2835
return str(weight_file)
2936

30-
def __init__(self, *, bioimageio_model: nodes.Model, weight_format: str, devices: Optional[List[str]] = None):
31-
self.spec = bioimageio_model
32-
37+
def _load(self, *, devices: Optional[List[str]] = None):
3338
try:
34-
tf_version = self.spec.weights[weight_format].tensorflow_version.version
39+
tf_version = self.bioimageio_model.weights[self.weight_format].tensorflow_version.version
3540
except AttributeError:
3641
tf_version = (1, 14, 0)
3742
tf_major_ver = tf_version[0]
3843
assert tf_major_ver in (1, 2)
39-
self.use_keras_api = tf_major_ver > 1 or weight_format == "keras_hdf5"
44+
self.use_keras_api = tf_major_ver > 1 or self.weight_format == KerasModelAdapter.weight_format
4045

4146
# TODO tf device management
4247
if devices is not None:
4348
warnings.warn(f"Device management is not implemented for tensorflow yet, ignoring the devices {devices}")
4449

45-
weight_file = self.require_unzipped(self.spec.weights[weight_format].source)
50+
weight_file = self.require_unzipped(self.bioimageio_model.weights[self.weight_format].source)
4651
self._model = self._load_model(weight_file)
47-
self._internal_output_axes = [tuple(out.axes) for out in bioimageio_model.outputs]
52+
self._internal_output_axes = [tuple(out.axes) for out in self.bioimageio_model.outputs]
4853

4954
# TODO currently we reload the model every time. It would be better to keep the graph and session
5055
# alive between forward passes (but then the sessions need to be properly opened / closed)
5156
def _forward_tf(self, *input_tensors):
52-
input_keys = [ipt.name for ipt in self.spec.inputs]
53-
output_keys = [out.name for out in self.spec.outputs]
57+
input_keys = [ipt.name for ipt in self.bioimageio_model.inputs]
58+
output_keys = [out.name for out in self.bioimageio_model.outputs]
5459

5560
# TODO read from spec
5661
tag = tf.saved_model.tag_constants.SERVING
@@ -85,7 +90,7 @@ def _forward_keras(self, input_tensors):
8590

8691
return [r if isinstance(r, np.ndarray) else tf.make_ndarray(r) for r in result]
8792

88-
def forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]:
93+
def _forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]:
8994
data = [ipt.data for ipt in input_tensors]
9095
if self.use_keras_api:
9196
result = self._forward_keras(*data)
@@ -94,14 +99,13 @@ def forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]:
9499

95100
return [xr.DataArray(r, dims=axes) for r, axes in zip(result, self._internal_output_axes)]
96101

102+
def _unload(self) -> None:
103+
warnings.warn("Device management is not implemented for keras yet, cannot unload model")
104+
97105

98106
class TensorflowModelAdapter(TensorflowModelAdapterBase):
99-
def __init__(self, *, bioimageio_model: nodes.Model, devices=List[str]):
100-
weight_format = "tensorflow_saved_model_bundle"
101-
super().__init__(bioimageio_model=bioimageio_model, weight_format=weight_format, devices=devices)
107+
weight_format = "tensorflow_saved_model_bundle"
102108

103109

104110
class KerasModelAdapter(TensorflowModelAdapterBase):
105-
def __init__(self, *, bioimageio_model: nodes.Model, devices=List[str]):
106-
weight_format = "keras_hdf5"
107-
super().__init__(bioimageio_model=bioimageio_model, weight_format=weight_format, devices=devices)
111+
weight_format = "keras_hdf5"
Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,30 @@
1+
import gc
2+
import warnings
13
from typing import List, Optional
24

35
import numpy as np
46
import torch
57
import xarray as xr
68

7-
from bioimageio.core.resource_io import nodes
89
from ._model_adapter import ModelAdapter
910

1011

1112
class TorchscriptModelAdapter(ModelAdapter):
12-
def __init__(self, *, bioimageio_model: nodes.Model, devices: Optional[List[str]] = None):
13-
weight_path = str(bioimageio_model.weights["pytorch_script"].source.resolve())
13+
def _load(self, *, devices: Optional[List[str]] = None):
14+
weight_path = str(self.bioimageio_model.weights["pytorch_script"].source.resolve())
1415
if devices is None:
1516
self.devices = ["cuda" if torch.cuda.is_available() else "cpu"]
1617
else:
1718
self.devices = [torch.device(d) for d in devices]
1819

20+
if len(self.devices) > 1:
21+
warnings.warn("Multiple devices for single torchscript model not yet implemented")
22+
1923
self._model = torch.jit.load(weight_path)
2024
self._model.to(self.devices[0])
21-
self._internal_output_axes = [tuple(out.axes) for out in bioimageio_model.outputs]
25+
self._internal_output_axes = [tuple(out.axes) for out in self.bioimageio_model.outputs]
2226

23-
def forward(self, *batch: xr.DataArray) -> List[xr.DataArray]:
27+
def _forward(self, *batch: xr.DataArray) -> List[xr.DataArray]:
2428
with torch.no_grad():
2529
torch_tensor = [torch.from_numpy(b.data).to(self.devices[0]) for b in batch]
2630
result = self._model.forward(*torch_tensor)
@@ -31,3 +35,9 @@ def forward(self, *batch: xr.DataArray) -> List[xr.DataArray]:
3135

3236
assert len(result) == len(self._internal_output_axes)
3337
return [xr.DataArray(r, dims=axes) for r, axes in zip(result, self._internal_output_axes)]
38+
39+
def _unload(self) -> None:
40+
self._devices = None
41+
del self._model
42+
gc.collect() # deallocate memory
43+
torch.cuda.empty_cache() # release reserved memory

0 commit comments

Comments
 (0)