Skip to content

Commit 1f5e99e

Browse files
authored
Merge pull request #106 from bioimage-io/clean_pred_pipe
Simplify PredictionPipeline
2 parents 431d758 + b1cf2e4 commit 1f5e99e

File tree

6 files changed

+37
-124
lines changed

6 files changed

+37
-124
lines changed

MANIFEST.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
include bioimageio/core/VERSION

bioimageio/core/prediction.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import xarray as xr
1212

1313
from bioimageio.core import load_resource_description
14-
from bioimageio.core.resource_io.nodes import Model
14+
from bioimageio.core.resource_io.nodes import InputTensor, Model, OutputTensor
1515
from bioimageio.core.prediction_pipeline import PredictionPipeline, create_prediction_pipeline
1616
from tqdm import tqdm
1717

@@ -93,15 +93,19 @@ def pad(im, axes: Sequence[str], padding, pad_right=True) -> Tuple[np.ndarray, D
9393
return im, crop
9494

9595

96-
def load_image(in_path, axes):
96+
def load_image(in_path, axes: Sequence[str]) -> xr.DataArray:
9797
ext = os.path.splitext(in_path)[1]
9898
if ext == ".npy":
9999
im = np.load(in_path)
100100
else:
101101
is_volume = "z" in axes
102102
im = imageio.volread(in_path) if is_volume else imageio.imread(in_path)
103103
im = require_axes(im, axes)
104-
return im
104+
return xr.DataArray(im, dims=axes)
105+
106+
107+
def load_tensors(sources, tensor_specs: List[Union[InputTensor, OutputTensor]]) -> List[xr.DataArray]:
108+
return [load_image(s, sspec.axes) for s, sspec in zip(sources, tensor_specs)]
105109

106110

107111
def _to_channel_last(image):
@@ -239,11 +243,14 @@ def load_tile(tile):
239243
#
240244

241245

242-
def predict(prediction_pipeline, inputs) -> List[xr.DataArray]:
246+
def predict(prediction_pipeline: PredictionPipeline, inputs) -> List[xr.DataArray]:
243247
if not isinstance(inputs, (tuple, list)):
244248
inputs = [inputs]
245249

246-
tagged_data = [xr.DataArray(ipt, dims=axes) for ipt, axes in zip(inputs, prediction_pipeline.input_axes)]
250+
assert len(inputs) == len(prediction_pipeline.input_specs)
251+
tagged_data = [
252+
xr.DataArray(ipt, dims=ipt_spec.axes) for ipt, ipt_spec in zip(inputs, prediction_pipeline.input_specs)
253+
]
247254
return prediction_pipeline.forward(*tagged_data)
248255

249256

@@ -411,10 +418,15 @@ def predict_image(model_rdf, inputs, outputs, padding=None, tiling=None, weight_
411418

412419
padding = parse_padding(padding, model)
413420
tiling = parse_tiling(tiling, model)
421+
422+
_predict_sample(prediction_pipeline, inputs, outputs, padding, tiling)
423+
424+
425+
def _predict_sample(prediction_pipeline, inputs, outputs, padding, tiling):
414426
if padding is not None and tiling is not None:
415427
raise ValueError("Only one of padding or tiling is supported")
416428

417-
input_data = [load_image(inp, axes) for inp, axes in zip(inputs, prediction_pipeline.input_axes)]
429+
input_data = load_tensors(inputs, prediction_pipeline.input_specs)
418430
if padding is not None:
419431
result = predict_with_padding(prediction_pipeline, input_data, padding)
420432
elif tiling is not None:
@@ -439,6 +451,7 @@ def predict_images(
439451
verbose=False,
440452
):
441453
"""Predict multiple inputs with a bioimage.io model."""
454+
442455
model = load_resource_description(model_rdf)
443456
assert isinstance(model, Model)
444457

@@ -448,8 +461,6 @@ def predict_images(
448461

449462
padding = parse_padding(padding, model)
450463
tiling = parse_tiling(tiling, model)
451-
if padding is not None and tiling is not None:
452-
raise ValueError("Only one of padding or tiling is supported")
453464

454465
prog = zip(inputs, outputs)
455466
if verbose:
@@ -462,17 +473,7 @@ def predict_images(
462473
if not isinstance(outp, (tuple, list)):
463474
outp = [outp]
464475

465-
inp = [load_image(im, sp.axes) for im, sp in zip(inp, prediction_pipeline.input_specs)]
466-
if padding is not None:
467-
res = predict_with_padding(prediction_pipeline, inp, padding)
468-
elif tiling is not None:
469-
res = predict_with_tiling(prediction_pipeline, inp, tiling)
470-
else:
471-
res = predict(prediction_pipeline, inp)
472-
473-
assert isinstance(res, list)
474-
for out, r in zip(outp, res):
475-
save_image(out, r)
476+
_predict_sample(prediction_pipeline, inp, outp, padding, tiling)
476477

477478

478479
def test_model(model_rdf: Union[URI, Path, str], weight_format=None, devices=None, decimal=4):

bioimageio/core/prediction_pipeline/_model_adapters/_model_adapter.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import abc
2-
from typing import Any, List, Optional, Type, Union
2+
from typing import List, Optional, Type
33

44
import xarray as xr
5+
56
from bioimageio.core.resource_io import nodes
67

78
#: Known weigh types in order of priority
@@ -18,7 +19,6 @@ class ModelAdapter(abc.ABC):
1819
def __init__(self, *, bioimageio_model: nodes.Model, devices=Optional[List[str]]):
1920
...
2021

21-
# todo: separate preprocessing/actual forward/postprocessing
2222
@abc.abstractmethod
2323
def forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]:
2424
"""

bioimageio/core/prediction_pipeline/_prediction_pipeline.py

Lines changed: 0 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -61,73 +61,20 @@ def output_specs(self) -> List[OutputTensor]:
6161
"""
6262
...
6363

64-
# todo: replace all uses of properties below with 'input_specs' and 'output_specs'
65-
@property
66-
@abc.abstractmethod
67-
def input_axes(self) -> List[Tuple[str, ...]]:
68-
"""
69-
Input axes excepted by this pipeline
70-
Note: one character axes names
71-
"""
72-
...
73-
74-
@property
75-
@abc.abstractmethod
76-
def input_shape(self) -> List[List[Tuple[str, int]]]:
77-
"""
78-
Named input dimensions
79-
"""
80-
...
81-
82-
@property
83-
@abc.abstractmethod
84-
def output_axes(self) -> List[Tuple[str, ...]]:
85-
"""
86-
Output axes of this pipeline
87-
Note: one character axes names
88-
"""
89-
...
90-
91-
@property
92-
@abc.abstractmethod
93-
def output_shape(self) -> List[Union[List[Tuple[str, float]], NamedImplicitOutputShape]]:
94-
"""
95-
Named output dimensions. Either explicitly defined or implicitly in relation to an input
96-
"""
97-
...
98-
99-
@property
100-
@abc.abstractmethod
101-
def halo(self) -> List[List[Tuple[str, int]]]:
102-
"""
103-
Size of output borders that have unreliable data due to artifacts (after application of postprocessing)
104-
"""
105-
...
106-
10764

10865
class _PredictionPipelineImpl(PredictionPipeline):
10966
def __init__(
11067
self,
11168
*,
11269
name: str,
11370
bioimageio_model: Model,
114-
input_axes: Sequence[str],
115-
input_shape: Sequence[List[Tuple[str, int]]],
116-
output_axes: Sequence[str],
117-
output_shape: Sequence[Union[List[Tuple[str, int]], NamedImplicitOutputShape]],
118-
halo: Sequence[List[Tuple[str, int]]],
11971
preprocessing: Sequence[Transform],
12072
model: ModelAdapter,
12173
postprocessing: Sequence[Transform],
12274
) -> None:
12375
self._name = name
12476
self._input_specs = bioimageio_model.inputs
12577
self._output_specs = bioimageio_model.outputs
126-
self._input_axes = [tuple(axes) for axes in input_axes]
127-
self._input_shape = input_shape
128-
self._output_axes = [tuple(axes) for axes in output_axes]
129-
self._output_shape = output_shape
130-
self._halo = halo
13178
self._preprocessing = preprocessing
13279
self._model: ModelAdapter = model
13380
self._postprocessing = postprocessing
@@ -144,26 +91,6 @@ def input_specs(self):
14491
def output_specs(self):
14592
return self._output_specs
14693

147-
@property
148-
def input_axes(self):
149-
return self._input_axes
150-
151-
@property
152-
def input_shape(self):
153-
return self._input_shape
154-
155-
@property
156-
def output_axes(self):
157-
return self._output_axes
158-
159-
@property
160-
def output_shape(self):
161-
return self._output_shape
162-
163-
@property
164-
def halo(self):
165-
return self._halo
166-
16794
def predict(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]:
16895
"""Predict input_tensor with the model without applying pre/postprocessing."""
16996
return self._model.forward(*input_tensors)
@@ -230,8 +157,6 @@ def create_prediction_pipeline(
230157
bioimageio_model=bioimageio_model, devices=devices, weight_format=weight_format
231158
)
232159

233-
input_axes: List[str] = []
234-
named_input_shape: List[List[Tuple[str, int]]] = []
235160
preprocessing: List[Transform] = []
236161
for ipt in bioimageio_model.inputs:
237162
try:
@@ -241,42 +166,17 @@ def create_prediction_pipeline(
241166
except AttributeError:
242167
input_shape = ipt.shape
243168

244-
input_axes.append(ipt.axes)
245-
named_input_shape.append(list(zip(ipt.axes, input_shape)))
246169
preprocessing_spec = [] if ipt.preprocessing is missing else ipt.preprocessing.copy()
247170
preprocessing.append(make_preprocessing(preprocessing_spec))
248171

249-
output_axes: List[str] = []
250-
named_output_shape: List[Union[List[Tuple[str, int]], NamedImplicitOutputShape]] = []
251-
named_halo: List[List[Tuple[str, int]]] = []
252172
postprocessing: List[Transform] = []
253173
for out in bioimageio_model.outputs:
254-
output_axes.append(out.axes)
255-
if isinstance(out.shape, list): # explict output shape
256-
named_output_shape.append(list(zip(out.axes, out.shape)))
257-
elif isinstance(out.shape, ImplicitOutputShape):
258-
named_output_shape.append(
259-
NamedImplicitOutputShape(
260-
reference_input=out.shape.reference_tensor,
261-
scale=list(zip(out.axes, out.shape.scale)),
262-
offset=list(zip(out.axes, out.shape.offset)),
263-
)
264-
)
265-
else:
266-
raise TypeError(f"Unexpected type for output shape: {type(out.shape)}")
267-
268-
named_halo.append(list(zip(out.axes, out.halo or [0 for _ in out.axes])))
269174
postprocessing_spec = [] if out.postprocessing is missing else out.postprocessing.copy()
270175
postprocessing.append(make_postprocessing(postprocessing_spec))
271176

272177
return _PredictionPipelineImpl(
273178
name=bioimageio_model.name,
274179
bioimageio_model=bioimageio_model,
275-
input_axes=input_axes,
276-
input_shape=named_input_shape,
277-
output_axes=output_axes,
278-
output_shape=named_output_shape,
279-
halo=named_halo,
280180
preprocessing=preprocessing,
281181
model=model_adapter,
282182
postprocessing=postprocessing,

bioimageio/core/resource_io/nodes.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pathlib
22
from dataclasses import dataclass
33
from pathlib import Path
4-
from typing import Callable, Dict, List, Union
4+
from typing import Callable, Dict, List, Tuple, Union
55

66
from marshmallow import missing
77
from marshmallow.utils import _Missing
@@ -107,12 +107,22 @@ class Postprocessing(Node, model_raw_nodes.Postprocessing):
107107

108108
@dataclass
109109
class InputTensor(Node, model_raw_nodes.InputTensor):
110-
pass
110+
axes: Tuple[str, ...] = missing
111+
112+
def __post_init__(self):
113+
super().__post_init__()
114+
# raw node has string with single letter axes names. Here we use a tuple of string here (like xr.DataArray).
115+
self.axes = tuple(self.axes)
111116

112117

113118
@dataclass
114119
class OutputTensor(Node, model_raw_nodes.OutputTensor):
115-
pass
120+
axes: Tuple[str, ...] = missing
121+
122+
def __post_init__(self):
123+
super().__post_init__()
124+
# raw node has string with single letter axes names. Here we use a tuple of string here (like xr.DataArray).
125+
self.axes = tuple(self.axes)
116126

117127

118128
@dataclass

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
],
2727
packages=find_namespace_packages(exclude=["tests"]), # Required
2828
install_requires=["bioimageio.spec", "imageio>=2.5", "numpy", "xarray"],
29+
include_package_data=True,
2930
extras_require={
3031
"test": ["pytest", "tox"],
3132
"dev": ["pre-commit"],

0 commit comments

Comments
 (0)