Commit 0853f96

add convenience functions for prediction
1 parent d8830d1 commit 0853f96

4 files changed: +234 -6 lines changed

bioimageio/core/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -27,6 +27,8 @@
 from .axis import AxisId as AxisId
 from .block_meta import BlockMeta as BlockMeta
 from .common import MemberId as MemberId
+from .prediction import predict as predict
+from .prediction import predict_many as predict_many
 from .sample import Sample as Sample
 from .tensor import Tensor as Tensor
 from .utils import VERSION
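
With these re-exports, the new helpers can be imported directly from the package root:

```python
from bioimageio.core import predict, predict_many
```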

bioimageio/core/digest_spec.py

Lines changed: 7 additions & 2 deletions
@@ -2,6 +2,7 @@
 
 import importlib.util
 from itertools import chain
+from pathlib import Path
 from typing import (
     Any,
     Callable,
@@ -16,6 +17,7 @@
     Union,
 )
 
+import xarray as xr
 from numpy.typing import NDArray
 from typing_extensions import Unpack, assert_never
 
@@ -33,6 +35,7 @@
 from .axis import AxisId, AxisInfo, PerAxis
 from .block_meta import split_multiple_shapes_into_blocks
 from .common import Halo, MemberId, PerMember, SampleId, TotalNumberOfBlocks
+from .io import get_tensor
 from .sample import (
     LinearSampleAxisTransform,
     Sample,
@@ -334,7 +337,9 @@ def create_sample_for_model(
     *,
     stat: Optional[Stat] = None,
     sample_id: SampleId = None,
-    inputs: Optional[PerMember[NDArray[Any]]] = None,  # TODO: make non-optional
+    inputs: Optional[
+        PerMember[Union[Tensor, xr.DataArray, NDArray[Any], Path]]
+    ] = None,  # TODO: make non-optional
     **kwargs: NDArray[Any],  # TODO: deprecate in favor of `inputs`
 ) -> Sample:
     """Create a sample from a single set of input(s) for a specific bioimage.io model
@@ -359,7 +364,7 @@ def create_sample_for_model(
 
     return Sample(
         members={
-            m: Tensor.from_numpy(inputs[m], dims=get_axes_infos(ipt))
+            m: get_tensor(inputs[m], ipt)
             for m, ipt in model_inputs.items()
             if m in inputs
         },
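
With the widened `inputs` argument, `create_sample_for_model` now accepts ready-made `Tensor`s, `xarray.DataArray`s, numpy arrays, or paths. A minimal sketch (the model source `"my_model.zip"`, the member id `"input0"`, and the array shape are illustrative placeholders, not taken from this commit):

```python
import numpy as np

from bioimageio.core.digest_spec import create_sample_for_model
from bioimageio.spec import load_description

model_descr = load_description("my_model.zip")  # hypothetical model RDF source

# each value may be a Tensor, an xr.DataArray, a numpy array, or a Path to an image/.npy file
sample = create_sample_for_model(
    model_descr,
    inputs={"input0": np.zeros((1, 1, 64, 64), dtype="float32")},
    sample_id="sample0",
)
```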

bioimageio/core/io.py

Lines changed: 51 additions & 3 deletions
@@ -1,12 +1,15 @@
 from pathlib import Path
-from typing import Any, Dict, Optional, Sequence
+from typing import Any, Dict, Optional, Sequence, Union
 
 import imageio
+import numpy as np
+import xarray as xr
 from loguru import logger
 from numpy.typing import NDArray
+from typing_extensions import assert_never
 
-from bioimageio.spec.model import AnyModelDescr
-from bioimageio.spec.utils import load_array
+from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
+from bioimageio.spec.utils import load_array, save_array
 
 from .axis import Axis, AxisLike
 from .common import MemberId, PerMember, SampleId
@@ -26,6 +29,7 @@ def load_image(path: Path, is_volume: bool) -> NDArray[Any]:
 
 
 def load_tensor(path: Path, axes: Optional[Sequence[AxisLike]] = None) -> Tensor:
+    # TODO: load axis meta data
     array = load_image(
         path,
         is_volume=(
@@ -36,6 +40,50 @@ def load_tensor(path: Path, axes: Optional[Sequence[AxisLike]] = None) -> Tensor
     return Tensor.from_numpy(array, dims=axes)
 
 
+def get_tensor(
+    src: Union[Tensor, xr.DataArray, NDArray[Any], Path],
+    ipt: Union[v0_4.InputTensorDescr, v0_5.InputTensorDescr],
+):
+    """helper to cast/load various tensor sources"""
+
+    if isinstance(src, Tensor):
+        return src
+
+    if isinstance(src, xr.DataArray):
+        return Tensor.from_xarray(src)
+
+    if isinstance(src, np.ndarray):
+        return Tensor.from_numpy(src, dims=get_axes_infos(ipt))
+
+    if isinstance(src, Path):
+        return load_tensor(src, axes=get_axes_infos(ipt))
+
+    assert_never(src)
+
+
+def save_tensor(path: Path, tensor: Tensor) -> None:
+    # TODO: save axis meta data
+    data: NDArray[Any] = tensor.data.to_numpy()
+    if path.suffix == ".npy":
+        save_array(path, data)
+    else:
+        imageio.volwrite(path, data)
+
+
+def save_sample(path: Union[Path, str], sample: Sample) -> None:
+    """save a sample to path
+
+    `path` must contain `{member_id}` and may contain `{sample_id}`,
+    which are resolved with the `sample` object.
+    """
+    path = str(path).format(sample_id=sample.id)
+    if "{member_id}" not in path:
+        raise ValueError(f"missing `{{member_id}}` in path {path}")
+
+    for m, t in sample.members.items():
+        save_tensor(Path(path.format(member_id=m)), t)
+
+
 def load_sample_for_model(
     *,
     model: AnyModelDescr,
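
To illustrate the `save_sample` path template (the directory and ids below are made up), its two-step formatting boils down to plain `str.format` calls: `{sample_id}` is filled in once from the sample, then one file per member is written via `{member_id}`:

```python
# stand-ins for a Sample's id and member ids; save_sample resolves them the same way
path_template = "outputs/{sample_id}/{member_id}.npy"

path = str(path_template).format(sample_id="sample0")
for member_id in ("input0", "output0"):
    print(path.format(member_id=member_id))
# outputs/sample0/input0.npy
# outputs/sample0/output0.npy
```

With a `.npy` suffix, `save_tensor` delegates to `save_array`; any other suffix is written with `imageio.volwrite`.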

bioimageio/core/prediction.py

Lines changed: 174 additions & 1 deletion
@@ -4,4 +4,177 @@
 e..g load samples with core.io.load_sample_for_model()
 """
 
-# TODO: add convenience functions for predictions
+import collections.abc
+from pathlib import Path
+from typing import (
+    Any,
+    Generator,
+    Hashable,
+    Iterable,
+    Iterator,
+    List,
+    Mapping,
+    Optional,
+    Tuple,
+    Union,
+)
+
+import xarray as xr
+from numpy.typing import NDArray
+from tqdm import tqdm
+
+from bioimageio.core.axis import AxisId
+from bioimageio.core.io import save_sample
+from bioimageio.spec import load_description
+from bioimageio.spec.common import PermissiveFileSource
+from bioimageio.spec.model import v0_4, v0_5
+
+from ._prediction_pipeline import PredictionPipeline, create_prediction_pipeline
+from .common import MemberId, PerMember
+from .digest_spec import create_sample_for_model
+from .sample import Sample
+from .tensor import Tensor
+
+
+def predict(
+    *,
+    model: Union[
+        PermissiveFileSource, v0_4.ModelDescr, v0_5.ModelDescr, PredictionPipeline
+    ],
+    inputs: PerMember[Union[Tensor, xr.DataArray, NDArray[Any], Path]],
+    sample_id: Hashable = "sample",
+    blocksize_parameter: Optional[
+        Union[
+            v0_5.ParameterizedSize.N,
+            Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize.N],
+        ]
+    ] = None,
+    skip_preprocessing: bool = False,
+    skip_postprocessing: bool = False,
+    save_output_path: Optional[Union[Path, str]] = None,
+) -> Sample:
+    """Run prediction for a single set of input(s) with a bioimage.io model
+
+    Args:
+        model: model to predict with.
+            May be given as an RDF source, a model description, or a prediction pipeline.
+        inputs: the named input(s) for this model as a dictionary
+        sample_id: the sample id.
+        blocksize_parameter: (optional) tile the input into blocks parametrized by
+            blocksize according to any parametrized axis sizes defined in the model RDF
+        skip_preprocessing: flag to skip the model's preprocessing
+        skip_postprocessing: flag to skip the model's postprocessing
+        save_output_path: a path with `{member_id}` (and optionally `{sample_id}`) in it
+            to save the output to.
+    """
+    if save_output_path is not None:
+        if "{member_id}" not in str(save_output_path):
+            raise ValueError(
+                f"Missing `{{member_id}}` in save_output_path={save_output_path}"
+            )
+
+    if isinstance(model, PredictionPipeline):
+        pp = model
+    else:
+        if not isinstance(model, (v0_4.ModelDescr, v0_5.ModelDescr)):
+            loaded = load_description(model)
+            if not isinstance(loaded, (v0_4.ModelDescr, v0_5.ModelDescr)):
+                raise ValueError(f"expected model description, but got {loaded}")
+            model = loaded
+
+        pp = create_prediction_pipeline(model)
+
+    sample = create_sample_for_model(
+        pp.model_description, inputs=inputs, sample_id=sample_id
+    )
+
+    if blocksize_parameter is None:
+        output = pp.predict_sample_without_blocking(
+            sample,
+            skip_preprocessing=skip_preprocessing,
+            skip_postprocessing=skip_postprocessing,
+        )
+    else:
+        output = pp.predict_sample_with_blocking(
+            sample,
+            skip_preprocessing=skip_preprocessing,
+            skip_postprocessing=skip_postprocessing,
+            ns=blocksize_parameter,
+        )
+    if save_output_path:
+        save_sample(save_output_path, output)
+
+    return output
+
+
+def predict_many(
+    *,
+    model: Union[
+        PermissiveFileSource, v0_4.ModelDescr, v0_5.ModelDescr, PredictionPipeline
+    ],
+    inputs: Iterable[PerMember[Union[Tensor, xr.DataArray, NDArray[Any], Path]]],
+    sample_id: str = "sample{i:03}",
+    blocksize_parameter: Optional[
+        Union[
+            v0_5.ParameterizedSize.N,
+            Mapping[Tuple[MemberId, AxisId], v0_5.ParameterizedSize.N],
+        ]
+    ] = None,
+    skip_preprocessing: bool = False,
+    skip_postprocessing: bool = False,
+    save_output_path: Optional[Union[Path, str]] = None,
+) -> Iterator[Sample]:
+    """Run prediction for multiple sets of inputs with a bioimage.io model
+
+    Args:
+        model: model to predict with.
+            May be given as an RDF source, a model description, or a prediction pipeline.
+        inputs: an iterable of the named input(s) for this model as dictionaries
+        sample_id: the sample id.
+            note: `{i}` is formatted with the index of the current sample.
+            If `{i}` (or `{i:`) is not present and `inputs` is an iterable, `{i:03}` is appended.
+        blocksize_parameter: (optional) tile the input into blocks parametrized by
+            blocksize according to any parametrized axis sizes defined in the model RDF
+        skip_preprocessing: flag to skip the model's preprocessing
+        skip_postprocessing: flag to skip the model's postprocessing
+        save_output_path: a path with `{member_id}` and `{sample_id}` in it
+            to save the output to.
+    """
+    if save_output_path is not None:
+        if "{member_id}" not in str(save_output_path):
+            raise ValueError(
+                f"Missing `{{member_id}}` in save_output_path={save_output_path}"
+            )
+
+        if not isinstance(inputs, collections.abc.Mapping) and "{sample_id}" not in str(
+            save_output_path
+        ):
+            raise ValueError(
+                f"Missing `{{sample_id}}` in save_output_path={save_output_path}"
+            )
+
+    if isinstance(model, PredictionPipeline):
+        pp = model
+    else:
+        if not isinstance(model, (v0_4.ModelDescr, v0_5.ModelDescr)):
+            loaded = load_description(model)
+            if not isinstance(loaded, (v0_4.ModelDescr, v0_5.ModelDescr)):
+                raise ValueError(f"expected model description, but got {loaded}")
+            model = loaded
+
+        pp = create_prediction_pipeline(model)
+
+    if not isinstance(inputs, collections.abc.Mapping):
+        sample_id = str(sample_id)
+        if "{i}" not in sample_id and "{i:" not in sample_id:
+            sample_id += "{i:03}"
+        for i, ipts in tqdm(enumerate(inputs)):
+            yield predict(
+                model=pp,
+                inputs=ipts,
+                sample_id=sample_id.format(i=i),
+                blocksize_parameter=blocksize_parameter,
+                skip_preprocessing=skip_preprocessing,
+                skip_postprocessing=skip_postprocessing,
+                save_output_path=save_output_path,
+            )
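
A usage sketch for the two new entry points (the model source `"my_model.zip"`, the member id `"input0"`, the shapes, and the output paths are placeholders, not taken from this commit):

```python
import numpy as np

from bioimageio.core import predict, predict_many

# single sample: inputs are keyed by the model's input member id (assumed here to be "input0")
output = predict(
    model="my_model.zip",  # hypothetical RDF source; a ModelDescr or PredictionPipeline also works
    inputs={"input0": np.zeros((1, 1, 64, 64), dtype="float32")},
    sample_id="demo",
    save_output_path="outputs/{sample_id}/{member_id}.npy",
)

# multiple samples: an iterable of such input dictionaries; predict_many is a generator,
# so its results must be consumed for prediction to actually run
outputs = list(
    predict_many(
        model="my_model.zip",
        inputs=[
            {"input0": np.zeros((1, 1, 64, 64), dtype="float32")},
            {"input0": np.ones((1, 1, 64, 64), dtype="float32")},
        ],
        save_output_path="outputs/{sample_id}/{member_id}.npy",
    )
)
```

Passing a `PredictionPipeline` as `model` (as `predict_many` itself does when delegating to `predict`) avoids reloading the model for every sample.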
