Skip to content

Commit 508b08f

Browse files
authored
Merge pull request #108 from bioimage-io/prepost_2
Processing with dataset statistics and reference_tensor
2 parents 9c73206 + f5dc588 commit 508b08f

File tree

10 files changed

+582
-299
lines changed

10 files changed

+582
-299
lines changed
Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
import warnings
2+
from collections import defaultdict
3+
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Type
4+
5+
import xarray as xr
6+
7+
from bioimageio.core.resource_io import nodes
8+
from bioimageio.core.statistical_measures import Mean, Measure, Percentile, Std
9+
from bioimageio.spec.model.raw_nodes import PostprocessingName, PreprocessingName
10+
from ._processing import (
11+
Binarize,
12+
Clip,
13+
EnsureDtype,
14+
Processing,
15+
ScaleLinear,
16+
ScaleMeanVariance,
17+
ScaleRange,
18+
Sigmoid,
19+
ZeroMeanUnitVariance,
20+
)
21+
22+
try:
23+
from typing import Literal
24+
except ImportError:
25+
from typing_extensions import Literal # type: ignore
26+
27+
KNOWN_PREPROCESSING: Dict[PreprocessingName, Type[Processing]] = {
28+
"binarize": Binarize,
29+
"clip": Clip,
30+
"scale_linear": ScaleLinear,
31+
"scale_range": ScaleRange,
32+
"sigmoid": Sigmoid,
33+
"zero_mean_unit_variance": ZeroMeanUnitVariance,
34+
}
35+
36+
KNOWN_POSTPROCESSING: Dict[PostprocessingName, Type[Processing]] = {
37+
"binarize": Binarize,
38+
"clip": Clip,
39+
"scale_linear": ScaleLinear,
40+
"scale_mean_variance": ScaleMeanVariance,
41+
"scale_range": ScaleRange,
42+
"sigmoid": Sigmoid,
43+
"zero_mean_unit_variance": ZeroMeanUnitVariance,
44+
}
45+
46+
47+
Scope = Literal["sample", "dataset"]
48+
SAMPLE: Literal["sample"] = "sample"
49+
DATASET: Literal["dataset"] = "dataset"
50+
SCOPES: Set[Scope] = {SAMPLE, DATASET}
51+
52+
53+
class CombinedProcessing:
54+
def __init__(self, inputs: List[nodes.InputTensor], outputs: List[nodes.OutputTensor]):
55+
self._prep = [
56+
KNOWN_PREPROCESSING[step.name](tensor_name=ipt.name, **step.kwargs)
57+
for ipt in inputs
58+
for step in ipt.preprocessing or []
59+
]
60+
self._post = [
61+
KNOWN_POSTPROCESSING.get(step.name)(tensor_name=out.name, **step.kwargs)
62+
for out in outputs
63+
for step in out.postprocessing or []
64+
]
65+
66+
# There is a difference between pre-and-postprocessing:
67+
# Pre-processing always returns float32, because its output is consumed by the model.
68+
# Post-processing, however, should return the dtype that is specified in the model spec.
69+
# todo: cast dtype for inputs before preprocessing? or check dtype?
70+
for out in outputs:
71+
self._post.append(EnsureDtype(tensor_name=out.name, dtype=out.data_type))
72+
73+
self._req_input_stats = {s: self._collect_required_stats(self._prep, s) for s in SCOPES}
74+
self._req_output_stats = {s: self._collect_required_stats(self._post, s) for s in SCOPES}
75+
if any(self._req_output_stats[s] for s in SCOPES):
76+
raise NotImplementedError("computing statistics for output tensors not yet implemented")
77+
78+
self._computed_dataset_stats: Optional[Dict[str, Dict[Measure, Any]]] = None
79+
80+
self.input_tensor_names = [ipt.name for ipt in inputs]
81+
self.output_tensor_names = [out.name for out in outputs]
82+
assert not any(name in self.output_tensor_names for name in self.input_tensor_names)
83+
assert not any(name in self.input_tensor_names for name in self.output_tensor_names)
84+
85+
@property
86+
def required_input_dataset_statistics(self) -> Dict[str, Set[Measure]]:
87+
return self._req_input_stats[DATASET]
88+
89+
@property
90+
def required_output_dataset_statistics(self) -> Dict[str, Set[Measure]]:
91+
return self._req_output_stats[DATASET]
92+
93+
@property
94+
def computed_dataset_statistics(self) -> Dict[str, Dict[Measure, Any]]:
95+
return self._computed_dataset_stats
96+
97+
def apply_preprocessing(
98+
self, *input_tensors: xr.DataArray
99+
) -> Tuple[List[xr.DataArray], Dict[str, Dict[Measure, Any]]]:
100+
assert len(input_tensors) == len(self.input_tensor_names)
101+
tensors = dict(zip(self.input_tensor_names, input_tensors))
102+
sample_stats = self.compute_sample_statistics(tensors, self._req_input_stats[SAMPLE])
103+
for proc in self._prep:
104+
proc.set_computed_sample_statistics(sample_stats)
105+
tensors[proc.tensor_name] = proc.apply(tensors[proc.tensor_name])
106+
107+
return [tensors[tn] for tn in self.input_tensor_names], sample_stats
108+
109+
def apply_postprocessing(
110+
self, *output_tensors: xr.DataArray, input_sample_statistics: Dict[str, Dict[Measure, Any]]
111+
) -> Tuple[List[xr.DataArray], Dict[str, Dict[Measure, Any]]]:
112+
assert len(output_tensors) == len(self.output_tensor_names)
113+
tensors = dict(zip(self.output_tensor_names, output_tensors))
114+
sample_stats = input_sample_statistics
115+
sample_stats.update(self.compute_sample_statistics(tensors, self._req_output_stats[SAMPLE]))
116+
for proc in self._post:
117+
proc.set_computed_sample_statistics(sample_stats)
118+
tensors[proc.tensor_name] = proc.apply(tensors[proc.tensor_name])
119+
120+
return [tensors[tn] for tn in self.output_tensor_names], sample_stats
121+
122+
def set_computed_dataset_statistics(self, computed: Dict[str, Dict[Measure, Any]]):
123+
"""
124+
This method sets the externally computed dataset statistics.
125+
Which statistics are expected is specified by the `required_dataset_statistics` property.
126+
"""
127+
# always expect input tensor statistics
128+
for tensor_name, req_measures in self.required_input_dataset_statistics:
129+
comp_measures = computed.get(tensor_name, {})
130+
for req_measure in req_measures:
131+
if req_measure not in comp_measures:
132+
raise ValueError(f"Missing required measure {req_measure} for input tensor {tensor_name}")
133+
134+
# as output tensor statistics may initially not be available, we only warn about their absence
135+
output_statistics_missing = False
136+
for tensor_name, req_measures in self.required_output_dataset_statistics:
137+
comp_measures = computed.get(tensor_name, {})
138+
for req_measure in req_measures:
139+
if req_measure not in comp_measures:
140+
output_statistics_missing = True
141+
warnings.warn(f"Missing required measure {req_measure} for output tensor {tensor_name}")
142+
143+
self._computed_dataset_stats = computed
144+
145+
# set dataset statistics for each processing step
146+
for proc in self._prep:
147+
proc.set_computed_dataset_statistics(self.computed_dataset_statistics)
148+
149+
@classmethod
150+
def compute_sample_statistics(
151+
cls, tensors: Dict[str, xr.DataArray], measures: Dict[str, Set[Measure]]
152+
) -> Dict[str, Dict[Measure, Any]]:
153+
return {tname: cls._compute_tensor_statistics(tensors[tname], ms) for tname, ms in measures.items()}
154+
155+
@staticmethod
156+
def _compute_tensor_statistics(tensor: xr.DataArray, measures: Set[Measure]) -> Dict[Measure, Any]:
157+
ret = {}
158+
for measure in measures:
159+
if isinstance(measure, Mean):
160+
v = tensor.mean(dim=measure.axes)
161+
elif isinstance(measure, Std):
162+
v = tensor.std(dim=measure.axes)
163+
elif isinstance(measure, Percentile):
164+
v = tensor.quantile(measure.n / 100.0, dim=measure.axes)
165+
else:
166+
raise NotImplementedError(measure)
167+
168+
ret[measure] = v
169+
170+
return ret
171+
172+
@staticmethod
173+
def _collect_required_stats(proc: Sequence[Processing], scope: Literal["sample", "dataset"]):
174+
stats = defaultdict(set)
175+
for p in proc:
176+
if scope == SAMPLE:
177+
req = p.get_required_sample_statistics()
178+
elif scope == DATASET:
179+
req = p.get_required_dataset_statistics()
180+
else:
181+
raise ValueError(scope)
182+
for tn, ms in req.items():
183+
stats[tn].update(ms)
184+
185+
return dict(stats)

bioimageio/core/prediction_pipeline/_postprocessing.py

Lines changed: 0 additions & 81 deletions
This file was deleted.

bioimageio/core/prediction_pipeline/_prediction_pipeline.py

Lines changed: 13 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,15 @@
11
import abc
22
import math
33
from dataclasses import dataclass
4-
from typing import List, Optional, Sequence, Tuple, Union
4+
from typing import List, Optional, Tuple
55

66
import xarray as xr
7-
from bioimageio.core.resource_io import nodes
87
from marshmallow import missing
98

9+
from bioimageio.core.resource_io import nodes
10+
from ._combined_processing import CombinedProcessing
1011
from ._model_adapters import ModelAdapter, create_model_adapter
11-
from ._postprocessing import make_postprocessing
12-
13-
from ._preprocessing import make_preprocessing
14-
from ._types import Transform
15-
from ..resource_io.nodes import ImplicitOutputShape, InputTensor, Model, OutputTensor
12+
from ..resource_io.nodes import InputTensor, Model, OutputTensor
1613

1714

1815
@dataclass
@@ -64,20 +61,13 @@ def output_specs(self) -> List[OutputTensor]:
6461

6562
class _PredictionPipelineImpl(PredictionPipeline):
6663
def __init__(
67-
self,
68-
*,
69-
name: str,
70-
bioimageio_model: Model,
71-
preprocessing: Sequence[Transform],
72-
model: ModelAdapter,
73-
postprocessing: Sequence[Transform],
64+
self, *, name: str, bioimageio_model: Model, processing: CombinedProcessing, model: ModelAdapter
7465
) -> None:
7566
self._name = name
7667
self._input_specs = bioimageio_model.inputs
7768
self._output_specs = bioimageio_model.outputs
78-
self._preprocessing = preprocessing
69+
self._processing = processing
7970
self._model: ModelAdapter = model
80-
self._postprocessing = postprocessing
8171

8272
@property
8373
def name(self):
@@ -97,21 +87,17 @@ def predict(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]:
9787

9888
def forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]:
9989
"""Apply preprocessing, run prediction and apply postprocessing."""
100-
assert len(self._preprocessing) == len(input_tensors)
101-
preprocessed = [fn(x) for fn, x in zip(self._preprocessing, input_tensors)]
90+
preprocessed, sample_stats = self._processing.apply_preprocessing(*input_tensors)
10291
prediction = self.predict(*preprocessed)
103-
assert len(self._postprocessing) == len(prediction)
104-
return [fn(x) for fn, x in zip(self._postprocessing, prediction)]
92+
return self._processing.apply_postprocessing(*prediction, input_sample_statistics=sample_stats)[0]
10593

10694
def preprocess(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]:
10795
"""Apply preprocessing."""
108-
assert len(self._preprocessing) == len(input_tensors)
109-
return [fn(x) for fn, x in zip(self._preprocessing, input_tensors)]
96+
return self._processing.apply_preprocessing(*input_tensors)[0]
11097

111-
def postprocess(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]:
98+
def postprocess(self, *input_tensors: xr.DataArray, input_sample_statistics) -> List[xr.DataArray]:
11299
"""Apply postprocessing."""
113-
assert len(self._postprocessing) == len(input_tensors)
114-
return [fn(x) for fn, x in zip(self._postprocessing, input_tensors)]
100+
return self._processing.apply_postprocessing(*input_tensors, input_sample_statistics=input_sample_statistics)[0]
115101

116102
def __call__(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]:
117103
return self.forward(*input_tensors)
@@ -157,27 +143,8 @@ def create_prediction_pipeline(
157143
bioimageio_model=bioimageio_model, devices=devices, weight_format=weight_format
158144
)
159145

160-
preprocessing: List[Transform] = []
161-
for ipt in bioimageio_model.inputs:
162-
try:
163-
input_shape = ipt.shape.min
164-
step = ipt.shape.step
165-
input_shape = enforce_min_shape(input_shape, step, ipt.axes)
166-
except AttributeError:
167-
input_shape = ipt.shape
168-
169-
preprocessing_spec = [] if ipt.preprocessing is missing else ipt.preprocessing.copy()
170-
preprocessing.append(make_preprocessing(preprocessing_spec))
171-
172-
postprocessing: List[Transform] = []
173-
for out in bioimageio_model.outputs:
174-
postprocessing_spec = [] if out.postprocessing is missing else out.postprocessing.copy()
175-
postprocessing.append(make_postprocessing(postprocessing_spec, out.data_type))
146+
processing = CombinedProcessing(bioimageio_model.inputs, bioimageio_model.outputs)
176147

177148
return _PredictionPipelineImpl(
178-
name=bioimageio_model.name,
179-
bioimageio_model=bioimageio_model,
180-
preprocessing=preprocessing,
181-
model=model_adapter,
182-
postprocessing=postprocessing,
149+
name=bioimageio_model.name, bioimageio_model=bioimageio_model, model=model_adapter, processing=processing
183150
)

0 commit comments

Comments
 (0)