Skip to content

Commit 256399b

Browse files
committed
reduce required input to CombinedProcessing constructor
1 parent 29005ca commit 256399b

File tree

3 files changed

+51
-36
lines changed

3 files changed

+51
-36
lines changed

bioimageio/core/prediction_pipeline/_combined_processing.py

Lines changed: 48 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
1-
from typing import List, Optional, Sequence, Union
1+
import dataclasses
2+
from typing import Dict, Iterable, List, NamedTuple, Optional, Sequence, TypedDict, Union
23

34
from bioimageio.core.resource_io import nodes
4-
from ._processing import EnsureDtype, KNOWN_PROCESSING, Processing
5+
from ._processing import EnsureDtype, KNOWN_PROCESSING, Processing, TensorName
56
from ._utils import ComputedMeasures, PER_DATASET, PER_SAMPLE, RequiredMeasures, Sample
67

78
try:
@@ -10,44 +11,58 @@
1011
from typing_extensions import Literal # type: ignore
1112

1213

14+
@dataclasses.dataclass
class TensorProcessingInfo:
    """Describe the processing that CombinedProcessing applies to a single tensor.

    Instances are the values of the ``combine_tensors`` mapping handed to
    ``CombinedProcessing.__init__``; the mapping key is the tensor name.
    """

    # Processing steps for the tensor, in application order. Each step is looked up
    # by its ``name`` and instantiated with its ``kwargs``.
    processing_steps: Union[List[nodes.Preprocessing], List[nodes.Postprocessing]]
    # If not None, the tensor is cast to this dtype before any processing step runs.
    data_type_before: Optional[str] = None
    # If not None, the tensor is cast to this dtype right after its own processing steps.
    data_type_after: Optional[str] = None
19+
20+
1321
class CombinedProcessing:
14-
def __init__(self, tensor_specs: Union[List[nodes.InputTensor], List[nodes.OutputTensor]]):
15-
PRE: Literal["pre"] = "pre"
16-
POST: Literal["post"] = "post"
17-
proc_prefix: Optional[Literal["pre", "post"]] = None
22+
def __init__(self, combine_tensors: Dict[TensorName, TensorProcessingInfo]):
    """Assemble the ordered list of processing operations for the given tensors.

    Args:
        combine_tensors: maps each tensor name to the processing steps and the
            optional dtype casts (before/after processing) to apply to it.

    The resulting operation order is:
      1. a dtype cast for every tensor that specifies ``data_type_before``,
      2. per tensor: its processing steps, followed by a dtype cast if
         ``data_type_after`` is specified.
    """
    # A single lookup table for pre- and postprocessing; on a name clash the
    # postprocessing entry replaces the preprocessing one.
    known = dict(KNOWN_PROCESSING["pre"])
    known.update(KNOWN_PROCESSING["post"])

    # ensure all tensors have correct data type before any processing
    self._procs = [
        EnsureDtype(tensor_name=name, dtype=spec.data_type_before)
        for name, spec in combine_tensors.items()
        if spec.data_type_before is not None
    ]

    for name, spec in combine_tensors.items():
        tensor_procs = [known[step.name](tensor_name=name, **step.kwargs) for step in spec.processing_steps]
        self._procs.extend(tensor_procs)

        # ensure tensor has correct data type right after its processing
        if spec.data_type_after is None:
            continue

        self._procs.append(EnsureDtype(tensor_name=name, dtype=spec.data_type_after))

    self.required_measures: RequiredMeasures = self._collect_required_measures(self._procs)
    self.tensor_names = list(combine_tensors)
42+
43+
@classmethod
def from_tensor_specs(cls, tensor_specs: List[Union[nodes.InputTensor, nodes.OutputTensor]]):
    """Create a CombinedProcessing from input and/or output tensor specs.

    Args:
        tensor_specs: tensor descriptions; input tensors contribute their
            preprocessing steps, output tensors their postprocessing steps
            plus a final cast to the output's ``data_type``.

    Returns:
        The CombinedProcessing instance covering all given tensors.

    Raises:
        ValueError: if two tensor specs share the same name.
        NotImplementedError: if a spec is neither an input nor an output tensor,
            or if an output tensor requires per-dataset statistics.
    """
    combine_tensors = {}
    for ts in tensor_specs:
        # Explicit check instead of `assert`, which is stripped when running with -O.
        if ts.name in combine_tensors:
            raise ValueError(f"Duplicate tensor name in tensor specs: {ts.name}")

        # There is a difference between pre- and postprocessing:
        # Preprocessing always returns float32, because its output is consumed by the model.
        # Postprocessing, however, should return the dtype that is specified in the model spec.
        # todo: cast dtype for inputs before preprocessing? or check dtype?
        if isinstance(ts, nodes.InputTensor):
            # todo: move preprocessing ensure_dtype here as data_type_after
            # `or []`: a spec without preprocessing must still yield an iterable of steps
            combine_tensors[ts.name] = TensorProcessingInfo(ts.preprocessing or [])
        elif isinstance(ts, nodes.OutputTensor):
            # `or []`: a spec without postprocessing must still yield an iterable of steps
            combine_tensors[ts.name] = TensorProcessingInfo(ts.postprocessing or [], None, ts.data_type)
        else:
            raise NotImplementedError(type(ts))

    inst = cls(combine_tensors)

    per_dataset_measures = inst.required_measures[PER_DATASET]
    for ts in tensor_specs:
        if isinstance(ts, nodes.OutputTensor) and ts.name in per_dataset_measures:
            raise NotImplementedError("computing statistics for output tensors per dataset is not yet implemented")

    return inst
5166

5267
def apply(self, sample: Sample, computed_measures: ComputedMeasures) -> None:
5368
for proc in self._procs:

bioimageio/core/prediction_pipeline/_prediction_pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ def create_prediction_pipeline(
213213
ipts = [resolve_raw_node(s, nodes) for s in bioimageio_model.inputs]
214214
outs = [resolve_raw_node(s, nodes) for s in bioimageio_model.outputs]
215215

216-
preprocessing = CombinedProcessing(ipts)
216+
preprocessing = CombinedProcessing.from_tensor_specs(ipts)
217217

218218
def sample_dataset():
219219
for tensors in dataset_for_initial_statistics:
@@ -225,7 +225,7 @@ def sample_dataset():
225225
update_dataset_stats_after_n_samples=update_dataset_stats_after_n_samples,
226226
update_dataset_stats_for_n_samples=update_dataset_stats_for_n_samples,
227227
)
228-
postprocessing = CombinedProcessing(outs)
228+
postprocessing = CombinedProcessing.from_tensor_specs(outs)
229229
out_stats = StatsState(
230230
postprocessing.required_measures,
231231
dataset=tuple(),

tests/prediction_pipeline/test_combined_processing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def test_postprocessing_dtype():
2525
postprocessing=[nodes.Postprocessing("binarize", dict(threshold=threshold))],
2626
)
2727
]
28-
com_proc = CombinedProcessing(outputs)
28+
com_proc = CombinedProcessing.from_tensor_specs(outputs)
2929

3030
sample = {"out1": data}
3131
com_proc.apply(sample, {})

0 commit comments

Comments
 (0)