Skip to content

Commit f68150c

Browse files
committed
separate CombinedProcessing
1 parent 890e671 commit f68150c

File tree

1 file changed

+178
-0
lines changed

1 file changed

+178
-0
lines changed
Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,178 @@
1+
import collections
2+
import warnings
3+
from collections import defaultdict
4+
from typing import Any, Dict, List, Literal, Optional, Sequence, Set, Tuple, Type
5+
6+
import xarray as xr
7+
from marshmallow import missing
8+
9+
from bioimageio.core.resource_io import nodes
10+
from bioimageio.core.statistical_measures import Mean, Measure, Percentile, Std
11+
from bioimageio.spec.model.raw_nodes import PostprocessingName, PreprocessingName
12+
from ._processing import (
13+
Binarize,
14+
Clip,
15+
Processing,
16+
ScaleLinear,
17+
ScaleMeanVariance,
18+
ScaleRange,
19+
Sigmoid,
20+
ZeroMeanUnitVariance,
21+
)
22+
23+
24+
# Maps each preprocessing name from the model spec to the class implementing it.
KNOWN_PREPROCESSING: Dict[PreprocessingName, Type[Processing]] = {
    "binarize": Binarize,
    "clip": Clip,
    "scale_linear": ScaleLinear,
    "scale_range": ScaleRange,
    "sigmoid": Sigmoid,
    "zero_mean_unit_variance": ZeroMeanUnitVariance,
}
32+
33+
# Maps each postprocessing name from the model spec to the class implementing it.
# Same steps as preprocessing plus "scale_mean_variance".
KNOWN_POSTPROCESSING: Dict[PostprocessingName, Type[Processing]] = {
    "binarize": Binarize,
    "clip": Clip,
    "scale_linear": ScaleLinear,
    "scale_mean_variance": ScaleMeanVariance,
    "scale_range": ScaleRange,
    "sigmoid": Sigmoid,
    "zero_mean_unit_variance": ZeroMeanUnitVariance,
}
42+
43+
44+
# Statistics are required either per individual sample or over the whole dataset.
Scope = Literal["sample", "dataset"]
SAMPLE: Literal["sample"] = "sample"
DATASET: Literal["dataset"] = "dataset"
SCOPES: Set[Scope] = {SAMPLE, DATASET}
48+
49+
50+
class CombinedProcessing:
    """Apply all pre- and postprocessing steps declared for a model's tensors.

    Instantiates the concrete ``Processing`` steps from the input/output tensor
    descriptions, tracks which statistical measures each step requires (per sample
    or over the whole dataset) and applies the steps to xarray tensors.
    """

    def __init__(self, inputs: List[nodes.InputTensor], outputs: List[nodes.OutputTensor]):
        """
        Args:
            inputs: input tensor descriptions whose ``preprocessing`` steps are instantiated.
            outputs: output tensor descriptions whose ``postprocessing`` steps are instantiated.

        Raises:
            KeyError: if a step name is not listed in KNOWN_PREPROCESSING/KNOWN_POSTPROCESSING.
            NotImplementedError: if any postprocessing step requires computed statistics.
        """
        # NOTE: the `is not missing` guard must come before the inner `for` clause;
        # iterating a tensor's `preprocessing` while it is the marshmallow `missing`
        # sentinel would raise a TypeError before the guard is ever evaluated.
        self._prep = [
            KNOWN_PREPROCESSING[step.name](tensor_name=ipt.name, **step.kwargs)
            for ipt in inputs
            if ipt.preprocessing is not missing
            for step in ipt.preprocessing
        ]
        # Item access (not .get) so an unknown step name raises a clear KeyError
        # instead of "'NoneType' object is not callable".
        self._post = [
            KNOWN_POSTPROCESSING[step.name](tensor_name=out.name, **step.kwargs)
            for out in outputs
            if out.postprocessing is not missing
            for step in out.postprocessing
        ]

        self._req_input_stats = {s: self._collect_required_stats(self._prep, s) for s in SCOPES}
        self._req_output_stats = {s: self._collect_required_stats(self._post, s) for s in SCOPES}
        if any(self._req_output_stats[s] for s in SCOPES):
            raise NotImplementedError("computing statistics for output tensors not yet implemented")

        # Per scope, no tensor may be required by both input and output statistics.
        # (Compare tensor names inside each scope: the outer dicts are keyed by
        # scope, so testing membership on them directly would always trigger.)
        for s in SCOPES:
            assert not any(tn in self._req_output_stats[s] for tn in self._req_input_stats[s])
            assert not any(tn in self._req_input_stats[s] for tn in self._req_output_stats[s])

        self._computed_dataset_stats: Optional[Dict[str, Dict[Measure, Any]]] = None

        self.input_tensor_names = [ipt.name for ipt in inputs]
        self.output_tensor_names = [out.name for out in outputs]

    @property
    def required_input_dataset_statistics(self) -> Dict[str, Set[Measure]]:
        """Measures that must be computed over the whole dataset per input tensor name."""
        return self._req_input_stats[DATASET]

    @property
    def required_output_dataset_statistics(self) -> Dict[str, Set[Measure]]:
        """Measures that must be computed over the whole dataset per output tensor name."""
        return self._req_output_stats[DATASET]

    @property
    def computed_dataset_statistics(self) -> Dict[str, Dict[Measure, Any]]:
        """Externally computed dataset statistics; set them via `set_computed_dataset_statistics`."""
        if self._computed_dataset_stats is None:
            raise RuntimeError("Set computed dataset statistics first!")

        return self._computed_dataset_stats

    def apply_preprocessing(
        self, *input_tensors: xr.DataArray
    ) -> Tuple[List[xr.DataArray], Dict[str, Dict[Measure, Any]]]:
        """Apply all preprocessing steps in order.

        Args:
            input_tensors: one tensor per entry of `self.input_tensor_names`, same order.

        Returns:
            The processed input tensors and the per-sample statistics computed for them.
        """
        assert len(input_tensors) == len(self.input_tensor_names)
        tensors = dict(zip(self.input_tensor_names, input_tensors))
        sample_stats = self.compute_sample_statistics(tensors, self._req_input_stats[SAMPLE])
        for proc in self._prep:
            proc.set_computed_sample_statistics(sample_stats)
            tensors[proc.tensor_name] = proc.apply(tensors[proc.tensor_name])

        return [tensors[tn] for tn in self.input_tensor_names], sample_stats

    def apply_postprocessing(
        self, *output_tensors: xr.DataArray, input_sample_statistics: Dict[str, Dict[Measure, Any]]
    ) -> Tuple[List[xr.DataArray], Dict[str, Dict[Measure, Any]]]:
        """Apply all postprocessing steps in order.

        Args:
            output_tensors: one tensor per entry of `self.output_tensor_names`, same order.
            input_sample_statistics: sample statistics returned by `apply_preprocessing`;
                updated in place with the output tensors' sample statistics.

        Returns:
            The processed output tensors and the combined sample statistics.
        """
        assert len(output_tensors) == len(self.output_tensor_names)
        tensors = dict(zip(self.output_tensor_names, output_tensors))
        sample_stats = input_sample_statistics
        sample_stats.update(self.compute_sample_statistics(tensors, self._req_output_stats[SAMPLE]))
        # Iterate the POSTprocessing steps here; looping over self._prep (as the
        # original did) would re-apply preprocessing to the output tensors.
        for proc in self._post:
            proc.set_computed_sample_statistics(sample_stats)
            tensors[proc.tensor_name] = proc.apply(tensors[proc.tensor_name])

        return [tensors[tn] for tn in self.output_tensor_names], sample_stats

    def set_computed_dataset_statistics(self, computed: Dict[str, Dict[Measure, Any]]):
        """
        This method sets the externally computed dataset statistics.
        Which statistics are expected is specified by the `required_dataset_statistics` property.

        Raises:
            ValueError: if a required measure for an input tensor is missing from `computed`.
        """
        # always expect input tensor statistics
        # (.items() is required: iterating the dict directly yields only tensor names,
        # so the two-target unpacking would raise a ValueError)
        for tensor_name, req_measures in self.required_input_dataset_statistics.items():
            comp_measures = computed.get(tensor_name, {})
            for req_measure in req_measures:
                if req_measure not in comp_measures:
                    raise ValueError(f"Missing required measure {req_measure} for input tensor {tensor_name}")

        # as output tensor statistics may initially not be available, we only warn about their absence
        output_statistics_missing = False
        for tensor_name, req_measures in self.required_output_dataset_statistics.items():
            comp_measures = computed.get(tensor_name, {})
            for req_measure in req_measures:
                if req_measure not in comp_measures:
                    output_statistics_missing = True
                    warnings.warn(f"Missing required measure {req_measure} for output tensor {tensor_name}")

        self._computed_dataset_stats = computed

        # set dataset statistics for each processing step (postprocessing steps currently
        # require none — __init__ raises otherwise — so including them is a safe no-op)
        for proc in self._prep + self._post:
            proc.set_computed_dataset_statistics(self.computed_dataset_statistics)

    def compute_sample_statistics(
        self, tensors: Dict[str, xr.DataArray], measures: Dict[str, Set[Measure]]
    ) -> Dict[str, Dict[Measure, Any]]:
        """Compute the requested per-sample measures for every tensor listed in `measures`."""
        return {tname: self._compute_tensor_statistics(tensors[tname], ms) for tname, ms in measures.items()}

    def _compute_tensor_statistics(self, tensor: xr.DataArray, measures: Set[Measure]) -> Dict[Measure, Any]:
        """Compute each requested measure on `tensor`.

        xarray reductions take the dimension(s) via the `dim` keyword (the original
        passed `dims=` to mean/std, which raises a TypeError).
        """
        ret = {}
        for measure in measures:
            if isinstance(measure, Mean):
                v = tensor.mean(dim=measure.axes)
            elif isinstance(measure, Std):
                v = tensor.std(dim=measure.axes)
            elif isinstance(measure, Percentile):
                v = tensor.quantile(measure.n / 100.0, dim=measure.axes)
            else:
                raise NotImplementedError(measure)

            ret[measure] = v

        return ret

    @staticmethod
    def _collect_required_stats(proc: Sequence[Processing], scope: Literal["sample", "dataset"]):
        """Merge the statistics required by all steps in `proc` for the given scope.

        Returns:
            Dict mapping tensor name -> set of required measures.

        Raises:
            ValueError: if `scope` is neither "sample" nor "dataset".
        """
        stats = defaultdict(set)
        for p in proc:
            if scope == SAMPLE:
                req = p.get_required_sample_statistics()
            elif scope == DATASET:
                req = p.get_required_dataset_statistics()
            else:
                raise ValueError(scope)
            for tn, ms in req.items():
                stats[tn].update(ms)

        return dict(stats)

0 commit comments

Comments
 (0)