Skip to content

Commit 3275a99

Browse files
authored
Merge pull request #399 from bioimage-io/improve_pytorch_adapter
Improve model adapters
2 parents 66665c5 + 81ea7db commit 3275a99

File tree

10 files changed

+117
-85
lines changed

10 files changed

+117
-85
lines changed

bioimageio/core/_resource_tests.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import traceback
22
import warnings
3+
from itertools import product
34
from typing import Dict, Hashable, List, Literal, Optional, Set, Tuple, Union
45

56
import numpy as np
@@ -179,31 +180,38 @@ def _test_model_inference_parametrized(
179180
model: v0_5.ModelDescr,
180181
weight_format: Optional[WeightsFormat],
181182
devices: Optional[List[str]],
182-
test_cases: Set[Tuple[v0_5.ParameterizedSize.N, BatchSize]] = {
183-
(0, 2),
184-
(1, 3),
185-
(2, 1),
186-
(3, 2),
187-
},
188183
) -> None:
189-
if not test_cases:
190-
return
191-
192-
logger.info(
193-
"Testing inference with {} different input tensor sizes", len(test_cases)
194-
)
195-
196184
if not any(
197185
isinstance(a.size, v0_5.ParameterizedSize)
198186
for ipt in model.inputs
199187
for a in ipt.axes
200188
):
201189
# no parameterized sizes => set n=0
202-
test_cases = {(0, b) for _n, b in test_cases}
190+
ns: Set[v0_5.ParameterizedSize.N] = {0}
191+
else:
192+
ns = {0, 1, 2}
203193

204-
if not any(isinstance(a, v0_5.BatchAxis) for ipt in model.inputs for a in ipt.axes):
205-
# no batch axis => set b=1
206-
test_cases = {(n, 1) for n, _b in test_cases}
194+
given_batch_sizes = {
195+
a.size
196+
for ipt in model.inputs
197+
for a in ipt.axes
198+
if isinstance(a, v0_5.BatchAxis)
199+
}
200+
if given_batch_sizes:
201+
batch_sizes = {gbs for gbs in given_batch_sizes if gbs is not None}
202+
if not batch_sizes:
203+
# only arbitrary batch sizes
204+
batch_sizes = {1, 2}
205+
else:
206+
# no batch axis
207+
batch_sizes = {1}
208+
209+
test_cases: Set[Tuple[v0_5.ParameterizedSize.N, BatchSize]] = {
210+
(n, b) for n, b in product(sorted(ns), sorted(batch_sizes))
211+
}
212+
logger.info(
213+
"Testing inference with {} different input tensor sizes", len(test_cases)
214+
)
207215

208216
def generate_test_cases():
209217
tested: Set[Hashable] = set()

bioimageio/core/model_adapters/_pytorch_model_adapter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def __init__(
4242

4343
self._primary_device = self._devices[0]
4444
state: Any = torch.load(
45-
download(weights.source).path,
45+
download(weights).path,
4646
map_location=self._primary_device, # pyright: ignore[reportUnknownArgumentType]
4747
)
4848
self._network.load_state_dict(state)

bioimageio/core/model_adapters/_tensorflow_model_adapter.py

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
import warnings
21
import zipfile
32
from typing import List, Literal, Optional, Sequence, Union
43

54
import numpy as np
5+
from loguru import logger
66

77
from bioimageio.spec.common import FileSource
88
from bioimageio.spec.model import v0_4, v0_5
@@ -46,19 +46,19 @@ def __init__(
4646
)
4747
model_tf_version = weights.tensorflow_version
4848
if model_tf_version is None:
49-
warnings.warn(
49+
logger.warning(
5050
"The model does not specify the tensorflow version."
5151
+ f"Cannot check if it is compatible with intalled tensorflow {tf_version}."
5252
)
5353
elif model_tf_version > tf_version:
54-
warnings.warn(
54+
logger.warning(
5555
f"The model specifies a newer tensorflow version than installed: {model_tf_version} > {tf_version}."
5656
)
5757
elif (model_tf_version.major, model_tf_version.minor) != (
5858
tf_version.major,
5959
tf_version.minor,
6060
):
61-
warnings.warn(
61+
logger.warning(
6262
"The tensorflow version specified by the model does not match the installed: "
6363
+ f"{model_tf_version} != {tf_version}."
6464
)
@@ -70,7 +70,7 @@ def __init__(
7070

7171
# TODO tf device management
7272
if devices is not None:
73-
warnings.warn(
73+
logger.warning(
7474
f"Device management is not implemented for tensorflow yet, ignoring the devices {devices}"
7575
)
7676

@@ -98,9 +98,20 @@ def _get_network( # pyright: ignore[reportUnknownParameterType]
9898
weight_file = self.require_unzipped(weight_file)
9999
assert tf is not None
100100
if self.use_keras_api:
101-
return tf.keras.models.load_model(
102-
weight_file, compile=False
103-
) # pyright: ignore[reportUnknownVariableType]
101+
try:
102+
return tf.keras.layers.TFSMLayer(
103+
weight_file, call_endpoint="serve"
104+
) # pyright: ignore[reportUnknownVariableType]
105+
except Exception as e:
106+
try:
107+
return tf.keras.layers.TFSMLayer(
108+
weight_file, call_endpoint="serving_default"
109+
) # pyright: ignore[reportUnknownVariableType]
110+
except Exception as ee:
111+
logger.opt(exception=ee).info(
112+
"keras.layers.TFSMLayer error for alternative call_endpoint='serving_default'"
113+
)
114+
raise e
104115
else:
105116
# NOTE in tf1 the model needs to be loaded inside of the session, so we cannot preload the model
106117
return str(weight_file)
@@ -189,24 +200,15 @@ def _forward_keras( # pyright: ignore[reportUnknownParameterType]
189200
None if ipt is None else tf.convert_to_tensor(ipt) for ipt in input_tensors
190201
]
191202

192-
try:
193-
result = ( # pyright: ignore[reportUnknownVariableType]
194-
self._network.forward(*tf_tensor)
195-
)
196-
except AttributeError:
197-
result = ( # pyright: ignore[reportUnknownVariableType]
198-
self._network.predict(*tf_tensor)
199-
)
203+
result = self._network(*tf_tensor) # pyright: ignore[reportUnknownVariableType]
200204

201-
if not isinstance(result, (tuple, list)):
202-
result = [result] # pyright: ignore[reportUnknownVariableType]
205+
assert isinstance(result, dict)
206+
207+
# TODO: Use RDF's `outputs[i].id` here
208+
result = list(result.values())
203209

204210
return [ # pyright: ignore[reportUnknownVariableType]
205-
(
206-
None
207-
if r is None
208-
else r if isinstance(r, np.ndarray) else tf.make_ndarray(r)
209-
)
211+
(None if r is None else r if isinstance(r, np.ndarray) else r.numpy())
210212
for r in result # pyright: ignore[reportUnknownVariableType]
211213
]
212214

@@ -230,7 +232,7 @@ def forward(self, *input_tensors: Optional[Tensor]) -> List[Optional[Tensor]]:
230232
]
231233

232234
def unload(self) -> None:
233-
warnings.warn(
235+
logger.warning(
234236
"Device management is not implemented for keras yet, cannot unload model"
235237
)
236238

dev/env-py38.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ channels:
44
- conda-forge
55
- defaults
66
dependencies:
7-
- bioimageio.spec>=0.5.2.post1
7+
- bioimageio.spec>=0.5.3
88
- black
99
- crick # uncommented
1010
- filelock

dev/env-tf.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ channels:
44
- conda-forge
55
- defaults
66
dependencies:
7-
- bioimageio.spec>=0.5.2.post1
7+
- bioimageio.spec>=0.5.3
88
- black
99
# - crick # currently requires python<=3.9
1010
- filelock

dev/env-wo-python.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ channels:
44
- conda-forge
55
- defaults
66
dependencies:
7-
- bioimageio.spec>=0.5.2.post1
7+
- bioimageio.spec>=0.5.3
88
- black
99
# - crick # currently requires python<=3.9
1010
- filelock

dev/env.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ channels:
33
- conda-forge
44
- defaults
55
dependencies:
6-
- bioimageio.spec>=0.5.2.post1
6+
- bioimageio.spec>=0.5.3
77
- black
88
# - crick # currently requires python<=3.9
99
- filelock

0 commit comments

Comments
 (0)