2020from ._prediction_pipeline import PredictionPipeline , create_prediction_pipeline
2121from .axis import AxisId
2222from .common import BlocksizeParameter , MemberId , PerMember
23- from .digest_spec import TensorSource , create_sample_for_model
23+ from .digest_spec import TensorSource , create_sample_for_model , get_member_id
2424from .io import save_sample
2525from .sample import Sample
2626
@@ -45,6 +45,8 @@ def predict(
4545 May be given as RDF source, model description or prediction pipeline.
4646 inputs: the input sample or the named input(s) for this model as a dictionary
4747 sample_id: the sample id.
48+ The **sample_id** is used to format **save_output_path**
49+ and to distinguish sample specific log messages.
4850 blocksize_parameter: (optional) Tile the input into blocks parametrized by
4951 **blocksize_parameter** according to any parametrized axis sizes defined
5052 by the **model**.
@@ -55,17 +57,15 @@ def predict(
5557 run prediction independent of the exact block shape.
5658 skip_preprocessing: Flag to skip the model's preprocessing.
5759 skip_postprocessing: Flag to skip the model's postprocessing.
58- save_output_path: A path with `{member_id}` `{sample_id}` in it
59- to save the output to.
60+ save_output_path: A path with to save the output to. M
61+ Must contain:
62+ - `{output_id}` (or `{member_id}`) if the model has multiple output tensors
63+ May contain:
64+ - `{sample_id}` to avoid overwriting recurrent calls
6065 """
61- if save_output_path is not None :
62- if "{member_id}" not in str (save_output_path ):
63- raise ValueError (
64- f"Missing `{{member_id}}` in save_output_path={ save_output_path } "
65- )
66-
6766 if isinstance (model , PredictionPipeline ):
6867 pp = model
68+ model = pp .model_description
6969 else :
7070 if not isinstance (model , (v0_4 .ModelDescr , v0_5 .ModelDescr )):
7171 loaded = load_description (model )
@@ -75,6 +75,18 @@ def predict(
7575
7676 pp = create_prediction_pipeline (model )
7777
78+ if save_output_path is not None :
79+ if (
80+ "{output_id}" not in str (save_output_path )
81+ and "{member_id}" not in str (save_output_path )
82+ and len (model .outputs ) > 1
83+ ):
84+ raise ValueError (
85+ f"Missing `{{output_id}}` in save_output_path={ save_output_path } to "
86+ + "distinguish model outputs "
87+ + str ([get_member_id (d ) for d in model .outputs ])
88+ )
89+
7890 if isinstance (inputs , Sample ):
7991 sample = inputs
8092 else :
@@ -120,7 +132,7 @@ def predict_many(
120132 model : Union [
121133 PermissiveFileSource , v0_4 .ModelDescr , v0_5 .ModelDescr , PredictionPipeline
122134 ],
123- inputs : Iterable [Union [TensorSource , PerMember [TensorSource ] ]],
135+ inputs : Union [ Iterable [PerMember [TensorSource ]], Iterable [TensorSource ]],
124136 sample_id : str = "sample{i:03}" ,
125137 blocksize_parameter : Optional [
126138 Union [
@@ -135,31 +147,27 @@ def predict_many(
135147 """Run prediction for a multiple sets of inputs with a bioimage.io model
136148
137149 Args:
138- model: model to predict with.
150+ model: Model to predict with.
139151 May be given as RDF source, model description or prediction pipeline.
140152 inputs: An iterable of the named input(s) for this model as a dictionary.
141- sample_id: the sample id.
153+ sample_id: The sample id.
142154 note: `{i}` will be formatted as the i-th sample.
143- If `{i}` (or `{i:`) is not present and `inputs` is an iterable `{i:03}` is appended.
144- blocksize_parameter: (optional) tile the input into blocks parametrized by
145- blocksize according to any parametrized axis sizes defined in the model RDF
146- skip_preprocessing: flag to skip the model's preprocessing
147- skip_postprocessing: flag to skip the model's postprocessing
148- save_output_path: A path with `{member_id}` `{sample_id}` in it
149- to save the output to.
155+ If `{i}` (or `{i:`) is not present and `inputs` is not an iterable `{i:03}`
156+ is appended.
157+ blocksize_parameter: (optional) Tile the input into blocks parametrized by
158+ blocksize according to any parametrized axis sizes defined in the model RDF.
159+ skip_preprocessing: Flag to skip the model's preprocessing.
160+ skip_postprocessing: Flag to skip the model's postprocessing.
161+ save_output_path: A path to save the output to.
162+ Must contain:
163+ - `{sample_id}` to differentiate predicted samples
164+ - `{output_id}` (or `{member_id}`) if the model has multiple outputs
150165 """
151- if save_output_path is not None :
152- if "{member_id}" not in str (save_output_path ):
153- raise ValueError (
154- f"Missing `{{member_id}}` in save_output_path={ save_output_path } "
155- )
156-
157- if not isinstance (inputs , collections .abc .Mapping ) and "{sample_id}" not in str (
158- save_output_path
159- ):
160- raise ValueError (
161- f"Missing `{{sample_id}}` in save_output_path={ save_output_path } "
162- )
166+ if save_output_path is not None and "{sample_id}" not in str (save_output_path ):
167+ raise ValueError (
168+ f"Missing `{{sample_id}}` in save_output_path={ save_output_path } "
169+ + " to differentiate predicted samples."
170+ )
163171
164172 if isinstance (model , PredictionPipeline ):
165173 pp = model
@@ -173,7 +181,6 @@ def predict_many(
173181 pp = create_prediction_pipeline (model )
174182
175183 if not isinstance (inputs , collections .abc .Mapping ):
176- sample_id = str (sample_id )
177184 if "{i}" not in sample_id and "{i:" not in sample_id :
178185 sample_id += "{i:03}"
179186
0 commit comments