Skip to content

Commit e335077

Browse files
authored
Merge pull request #284 from bioimage-io/raw_pred
allow to create prediction pipeline from raw_nodes.Model
2 parents 7cfa502 + f3a4aa3 commit e335077

File tree

4 files changed

+54
-20
lines changed

4 files changed

+54
-20
lines changed

bioimageio/core/prediction_pipeline/_model_adapters/_model_adapter.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
import abc
2-
from typing import List, Optional, Sequence, Type
2+
from typing import List, Optional, Sequence, Type, Union
33

44
import xarray as xr
55

6+
from bioimageio.core import load_resource_description
67
from bioimageio.core.resource_io import nodes
78

89
#: Known weight formats in order of priority
910
#: First match wins
11+
from bioimageio.spec.model import raw_nodes
12+
1013
_WEIGHT_FORMATS = ["pytorch_state_dict", "tensorflow_saved_model_bundle", "torchscript", "onnx", "keras_hdf5"]
1114

1215

@@ -15,11 +18,24 @@ class ModelAdapter(abc.ABC):
1518
Represents model *without* any preprocessing and postprocessing
1619
"""
1720

18-
def __init__(self, *, bioimageio_model: nodes.Model, devices: Optional[Sequence[str]] = None):
19-
self.bioimageio_model = bioimageio_model
21+
def __init__(
22+
self, *, bioimageio_model: Union[nodes.Model, raw_nodes.Model], devices: Optional[Sequence[str]] = None
23+
):
24+
self.bioimageio_model = self._prepare_model(bioimageio_model)
2025
self.default_devices = devices
2126
self.loaded = False
2227

28+
@staticmethod
29+
def _prepare_model(bioimageio_model):
30+
"""the (raw) model node is prepared (here: loaded as non-raw model node) for the model adapter to be ready
31+
for operation.
32+
Note: To write a model adapter that uses the raw model node one can overwrite this method.
33+
"""
34+
if isinstance(bioimageio_model, nodes.Model):
35+
return bioimageio_model
36+
else:
37+
return load_resource_description(bioimageio_model)
38+
2339
def __enter__(self):
2440
"""load on entering context"""
2541
assert not self.loaded
@@ -93,7 +109,10 @@ def get_weight_formats() -> List[str]:
93109

94110

95111
def create_model_adapter(
96-
*, bioimageio_model: nodes.Model, devices=Optional[Sequence[str]], weight_format: Optional[str] = None
112+
*,
113+
bioimageio_model: Union[nodes.Model, raw_nodes.Model],
114+
devices=Optional[Sequence[str]],
115+
weight_format: Optional[str] = None,
97116
) -> ModelAdapter:
98117
"""
99118
Creates model adapter based on the passed spec
@@ -105,7 +124,7 @@ def create_model_adapter(
105124

106125
if weight_format is not None:
107126
if weight_format not in weight_formats:
108-
raise ValueError(f"Weight format {weight_format} is not in supported formats {_WEIGHT_FORMATS}")
127+
raise ValueError(f"Weight format {weight_format} is not in supported formats {weight_formats}")
109128
weight_formats = [weight_format]
110129

111130
for weight in weight_formats:

bioimageio/core/prediction_pipeline/_prediction_pipeline.py

Lines changed: 26 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
import abc
22
from dataclasses import dataclass
3-
from typing import Iterable, List, Optional, Sequence, Tuple
3+
from typing import Iterable, List, Optional, Sequence, Tuple, Union
44

55
import xarray as xr
66
from marshmallow import missing
77

88
from bioimageio.core.resource_io import nodes
9-
from bioimageio.core.resource_io.nodes import InputTensor, Model, OutputTensor
9+
from bioimageio.spec.model import raw_nodes
1010
from ._combined_processing import CombinedProcessing
1111
from ._model_adapters import ModelAdapter, create_model_adapter
1212
from ._stat_state import StatsState
1313
from ._utils import ComputedMeasures, Sample, TensorName
14+
from .. import load_resource_description
15+
from ..resource_io.utils import resolve_raw_node
1416

1517

1618
@dataclass
@@ -54,15 +56,15 @@ def name(self) -> str:
5456

5557
@property
5658
@abc.abstractmethod
57-
def input_specs(self) -> List[InputTensor]:
59+
def input_specs(self) -> List[nodes.InputTensor]:
5860
"""
5961
specs of inputs
6062
"""
6163
...
6264

6365
@property
6466
@abc.abstractmethod
65-
def output_specs(self) -> List[OutputTensor]:
67+
def output_specs(self) -> List[nodes.OutputTensor]:
6668
"""
6769
specs of outputs
6870
"""
@@ -88,7 +90,7 @@ def __init__(
8890
self,
8991
*,
9092
name: str,
91-
bioimageio_model: Model,
93+
bioimageio_model: Union[nodes.Model, raw_nodes.Model],
9294
preprocessing: CombinedProcessing,
9395
postprocessing: CombinedProcessing,
9496
ipt_stats: StatsState,
@@ -99,8 +101,14 @@ def __init__(
99101
raise NotImplementedError(f"Not yet implemented inference for run mode '{bioimageio_model.run_mode.name}'")
100102

101103
self._name = name
102-
self._input_specs = bioimageio_model.inputs
103-
self._output_specs = bioimageio_model.outputs
104+
if isinstance(bioimageio_model, nodes.Model):
105+
self._input_specs = bioimageio_model.inputs
106+
self._output_specs = bioimageio_model.outputs
107+
else:
108+
assert isinstance(bioimageio_model, raw_nodes.Model)
109+
self._input_specs = [resolve_raw_node(s, nodes) for s in bioimageio_model.inputs]
110+
self._output_specs = [resolve_raw_node(s, nodes) for s in bioimageio_model.outputs]
111+
104112
self._preprocessing = preprocessing
105113
self._postprocessing = postprocessing
106114
self._ipt_stats = ipt_stats
@@ -176,7 +184,7 @@ def unload(self):
176184

177185

178186
def create_prediction_pipeline(
179-
bioimageio_model: nodes.Model,
187+
bioimageio_model: Union[nodes.Model, raw_nodes.Model],
180188
*,
181189
devices: Optional[Sequence[str]] = None,
182190
weight_format: Optional[str] = None,
@@ -196,8 +204,16 @@ def create_prediction_pipeline(
196204
model_adapter: ModelAdapter = model_adapter or create_model_adapter(
197205
bioimageio_model=bioimageio_model, devices=devices, weight_format=weight_format
198206
)
207+
if isinstance(bioimageio_model, nodes.Model):
208+
ipts = bioimageio_model.inputs
209+
outs = bioimageio_model.outputs
210+
211+
else:
212+
assert isinstance(bioimageio_model, raw_nodes.Model)
213+
ipts = [resolve_raw_node(s, nodes) for s in bioimageio_model.inputs]
214+
outs = [resolve_raw_node(s, nodes) for s in bioimageio_model.outputs]
199215

200-
preprocessing = CombinedProcessing(bioimageio_model.inputs)
216+
preprocessing = CombinedProcessing(ipts)
201217

202218
def sample_dataset():
203219
for tensors in dataset_for_initial_statistics:
@@ -209,7 +225,7 @@ def sample_dataset():
209225
update_dataset_stats_after_n_samples=update_dataset_stats_after_n_samples,
210226
update_dataset_stats_for_n_samples=update_dataset_stats_for_n_samples,
211227
)
212-
postprocessing = CombinedProcessing(bioimageio_model.outputs)
228+
postprocessing = CombinedProcessing(outs)
213229
out_stats = StatsState(
214230
postprocessing.required_measures,
215231
dataset=tuple(),

bioimageio/core/resource_io/io_.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
)
2020
from bioimageio.spec.shared.raw_nodes import ResourceDescription as RawResourceDescription
2121
from . import nodes
22-
from .utils import resolve_raw_resource_description, resolve_source
22+
from .utils import resolve_raw_node, resolve_source
2323

2424
serialize_raw_resource_description = spec.io_.serialize_raw_resource_description
2525
save_raw_resource_description = spec.io_.save_raw_resource_description
@@ -54,7 +54,7 @@ def load_resource_description(
5454
else:
5555
raise ValueError(f"Not found any of the specified weights formats {weights_priority_order}")
5656

57-
rd: ResourceDescription = resolve_raw_resource_description(raw_rd=raw_rd, nodes_module=nodes)
57+
rd: ResourceDescription = resolve_raw_node(raw_rd=raw_rd, nodes_module=nodes)
5858
assert isinstance(rd, getattr(nodes, get_class_name_from_type(raw_rd.type)))
5959

6060
return rd

bioimageio/core/resource_io/utils.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
from bioimageio.spec.shared import raw_nodes, resolve_source, source_available
1010
from bioimageio.spec.shared.node_transformer import (
1111
GenericRawNode,
12-
GenericRawRD,
1312
GenericResolvedNode,
1413
NodeTransformer,
1514
NodeVisitor,
@@ -105,8 +104,8 @@ def all_sources_available(
105104
return True
106105

107106

108-
def resolve_raw_resource_description(
109-
raw_rd: GenericRawRD, nodes_module: typing.Any, uri_only_if_in_package: bool = True
107+
def resolve_raw_node(
108+
raw_rd: GenericRawNode, nodes_module: typing.Any, uri_only_if_in_package: bool = True
110109
) -> GenericResolvedNode:
111110
"""resolve all uris and paths (that are included when packaging)"""
112111
rd = UriNodeTransformer(root_path=raw_rd.root_path, uri_only_if_in_package=uri_only_if_in_package).transform(raw_rd)

0 commit comments

Comments
 (0)