Skip to content

Commit f01ee3d

Browse files
committed
fix predict_many and improve save_sample
1 parent d3787a5 commit f01ee3d

File tree

2 files changed

+62
-45
lines changed

2 files changed

+62
-45
lines changed

bioimageio/core/io.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -215,23 +215,33 @@ def save_tensor(path: Union[Path, str], tensor: Tensor) -> None:
215215
imwrite(path, data)
216216

217217

218-
def save_sample(path: Union[Path, str, PerMember[Path]], sample: Sample) -> None:
219-
"""save a sample to path
218+
def save_sample(
219+
path: Union[Path, str, PerMember[Union[Path, str]]], sample: Sample
220+
) -> None:
221+
"""Save a **sample** to a **path** pattern
222+
or all sample members in the **path** mapping.
220223
221-
If `path` is a pathlib.Path or a string it must contain `{member_id}` and may contain `{sample_id}`,
222-
which are resolved with the `sample` object.
223-
"""
224-
225-
if not isinstance(path, collections.abc.Mapping) and "{member_id}" not in str(path):
226-
raise ValueError(f"missing `{{member_id}}` in path {path}")
224+
If **path** is a pathlib.Path or a string and the **sample** has multiple members,
225+
**path** it must contain `{member_id}` (or `{input_id}` or `{output_id}`).
227226
228-
for m, t in sample.members.items():
229-
if isinstance(path, collections.abc.Mapping):
230-
p = path[m]
227+
(Each) **path** may contain `{sample_id}` to be formatted with the **sample** object.
228+
"""
229+
if not isinstance(path, collections.abc.Mapping):
230+
if len(sample.members) < 2 or any(
231+
m in str(path) for m in ("{member_id}", "{input_id}", "{output_id}")
232+
):
233+
path = {m: path for m in sample.members}
231234
else:
232-
p = Path(str(path).format(sample_id=sample.id, member_id=m))
235+
raise ValueError(
236+
f"path {path} must contain '{{member_id}}' for sample with multiple members {list(sample.members)}."
237+
)
233238

234-
save_tensor(p, t)
239+
for m, p in path.items():
240+
t = sample.members[m]
241+
p_formatted = Path(
242+
str(p).format(sample_id=sample.id, member_id=m, input_id=m, output_id=m)
243+
)
244+
save_tensor(p_formatted, t)
235245

236246

237247
class _SerializedDatasetStatsEntry(

bioimageio/core/prediction.py

Lines changed: 39 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from ._prediction_pipeline import PredictionPipeline, create_prediction_pipeline
2121
from .axis import AxisId
2222
from .common import BlocksizeParameter, MemberId, PerMember
23-
from .digest_spec import TensorSource, create_sample_for_model
23+
from .digest_spec import TensorSource, create_sample_for_model, get_member_id
2424
from .io import save_sample
2525
from .sample import Sample
2626

@@ -45,6 +45,8 @@ def predict(
4545
May be given as RDF source, model description or prediction pipeline.
4646
inputs: the input sample or the named input(s) for this model as a dictionary
4747
sample_id: the sample id.
48+
The **sample_id** is used to format **save_output_path**
49+
and to distinguish sample specific log messages.
4850
blocksize_parameter: (optional) Tile the input into blocks parametrized by
4951
**blocksize_parameter** according to any parametrized axis sizes defined
5052
by the **model**.
@@ -55,17 +57,15 @@ def predict(
5557
run prediction independent of the exact block shape.
5658
skip_preprocessing: Flag to skip the model's preprocessing.
5759
skip_postprocessing: Flag to skip the model's postprocessing.
58-
save_output_path: A path with `{member_id}` `{sample_id}` in it
59-
to save the output to.
60+
save_output_path: A path with to save the output to. M
61+
Must contain:
62+
- `{output_id}` (or `{member_id}`) if the model has multiple output tensors
63+
May contain:
64+
- `{sample_id}` to avoid overwriting recurrent calls
6065
"""
61-
if save_output_path is not None:
62-
if "{member_id}" not in str(save_output_path):
63-
raise ValueError(
64-
f"Missing `{{member_id}}` in save_output_path={save_output_path}"
65-
)
66-
6766
if isinstance(model, PredictionPipeline):
6867
pp = model
68+
model = pp.model_description
6969
else:
7070
if not isinstance(model, (v0_4.ModelDescr, v0_5.ModelDescr)):
7171
loaded = load_description(model)
@@ -75,6 +75,18 @@ def predict(
7575

7676
pp = create_prediction_pipeline(model)
7777

78+
if save_output_path is not None:
79+
if (
80+
"{output_id}" not in str(save_output_path)
81+
and "{member_id}" not in str(save_output_path)
82+
and len(model.outputs) > 1
83+
):
84+
raise ValueError(
85+
f"Missing `{{output_id}}` in save_output_path={save_output_path} to "
86+
+ "distinguish model outputs "
87+
+ str([get_member_id(d) for d in model.outputs])
88+
)
89+
7890
if isinstance(inputs, Sample):
7991
sample = inputs
8092
else:
@@ -120,7 +132,7 @@ def predict_many(
120132
model: Union[
121133
PermissiveFileSource, v0_4.ModelDescr, v0_5.ModelDescr, PredictionPipeline
122134
],
123-
inputs: Iterable[Union[TensorSource, PerMember[TensorSource]]],
135+
inputs: Union[Iterable[PerMember[TensorSource]], Iterable[TensorSource]],
124136
sample_id: str = "sample{i:03}",
125137
blocksize_parameter: Optional[
126138
Union[
@@ -135,31 +147,27 @@ def predict_many(
135147
"""Run prediction for a multiple sets of inputs with a bioimage.io model
136148
137149
Args:
138-
model: model to predict with.
150+
model: Model to predict with.
139151
May be given as RDF source, model description or prediction pipeline.
140152
inputs: An iterable of the named input(s) for this model as a dictionary.
141-
sample_id: the sample id.
153+
sample_id: The sample id.
142154
note: `{i}` will be formatted as the i-th sample.
143-
If `{i}` (or `{i:`) is not present and `inputs` is an iterable `{i:03}` is appended.
144-
blocksize_parameter: (optional) tile the input into blocks parametrized by
145-
blocksize according to any parametrized axis sizes defined in the model RDF
146-
skip_preprocessing: flag to skip the model's preprocessing
147-
skip_postprocessing: flag to skip the model's postprocessing
148-
save_output_path: A path with `{member_id}` `{sample_id}` in it
149-
to save the output to.
155+
If `{i}` (or `{i:`) is not present and `inputs` is not an iterable `{i:03}`
156+
is appended.
157+
blocksize_parameter: (optional) Tile the input into blocks parametrized by
158+
blocksize according to any parametrized axis sizes defined in the model RDF.
159+
skip_preprocessing: Flag to skip the model's preprocessing.
160+
skip_postprocessing: Flag to skip the model's postprocessing.
161+
save_output_path: A path to save the output to.
162+
Must contain:
163+
- `{sample_id}` to differentiate predicted samples
164+
- `{output_id}` (or `{member_id}`) if the model has multiple outputs
150165
"""
151-
if save_output_path is not None:
152-
if "{member_id}" not in str(save_output_path):
153-
raise ValueError(
154-
f"Missing `{{member_id}}` in save_output_path={save_output_path}"
155-
)
156-
157-
if not isinstance(inputs, collections.abc.Mapping) and "{sample_id}" not in str(
158-
save_output_path
159-
):
160-
raise ValueError(
161-
f"Missing `{{sample_id}}` in save_output_path={save_output_path}"
162-
)
166+
if save_output_path is not None and "{sample_id}" not in str(save_output_path):
167+
raise ValueError(
168+
f"Missing `{{sample_id}}` in save_output_path={save_output_path}"
169+
+ " to differentiate predicted samples."
170+
)
163171

164172
if isinstance(model, PredictionPipeline):
165173
pp = model
@@ -173,7 +181,6 @@ def predict_many(
173181
pp = create_prediction_pipeline(model)
174182

175183
if not isinstance(inputs, collections.abc.Mapping):
176-
sample_id = str(sample_id)
177184
if "{i}" not in sample_id and "{i:" not in sample_id:
178185
sample_id += "{i:03}"
179186

0 commit comments

Comments
 (0)