Skip to content

Commit 890e671

Browse files
committed
improve Processing
1 parent 561ee7d commit 890e671

File tree

2 files changed

+120
-116
lines changed

2 files changed

+120
-116
lines changed
Lines changed: 119 additions & 116 deletions
Original file line number · Diff line number · Diff line change
@@ -1,23 +1,10 @@
1-
from dataclasses import dataclass
2-
from typing import (
3-
Any,
4-
Dict,
5-
List,
6-
Literal,
7-
Optional,
8-
Sequence,
9-
Set,
10-
Type,
11-
Union,
12-
get_args,
13-
)
1+
from dataclasses import dataclass, field
2+
from typing import Any, Dict, Literal, Optional, Sequence, Set, get_args
143

154
import numpy as np
165
import xarray as xr
176

18-
from bioimageio.core.resource_io import nodes
197
from bioimageio.core.statistical_measures import Mean, Measure, Percentile, Std
20-
from bioimageio.spec.model.v0_3.raw_nodes import PreprocessingName
218

229

2310
def ensure_dtype(tensor: xr.DataArray, *, dtype) -> xr.DataArray:
@@ -29,8 +16,11 @@ def ensure_dtype(tensor: xr.DataArray, *, dtype) -> xr.DataArray:
2916

3017
@dataclass
3118
class Processing:
32-
apply_to: str
33-
computed_statistics: Dict[str, Dict[Measure, Any]]
19+
"""base class for all Pre- and Postprocessing transformations"""
20+
21+
tensor_name: str
22+
computed_dataset_statistics: Dict[str, Dict[Measure, Any]] = field(init=False)
23+
computed_sample_statistics: Dict[str, Dict[Measure, Any]] = field(init=False)
3424

3525
def get_required_dataset_statistics(self) -> Dict[str, Set[Measure]]:
3626
"""
@@ -39,30 +29,48 @@ def get_required_dataset_statistics(self) -> Dict[str, Set[Measure]]:
3929
"""
4030
return {}
4131

42-
def set_computed_statistics(self, computed: Dict[str, Dict[Measure, Any]]):
32+
def get_required_sample_statistics(self) -> Dict[str, Set[Measure]]:
33+
"""
34+
Specifies which sample measures are required from what tensor.
35+
Returns: sample measures required to apply this processing indexed by <tensor_name>.
36+
"""
37+
38+
def set_computed_dataset_statistics(self, computed: Dict[str, Dict[Measure, Any]]):
4339
"""helper to set computed statistics and check if they match the requirements"""
4440
for tensor_name, req_measures in self.get_required_dataset_statistics():
4541
comp_measures = computed.get(tensor_name, {})
4642
for req_measure in req_measures:
4743
if req_measure not in comp_measures:
48-
raise ValueError("Missing required measure {req_measure} for {tensor_name}")
49-
self.computed_statistics = computed
44+
raise ValueError(f"Missing required measure {req_measure} for {tensor_name}")
45+
self.computed_dataset_statistics = computed
46+
47+
def set_computed_sample_statistics(self, computed: Dict[str, Dict[Measure, Any]]):
48+
"""helper to set computed statistics and check if they match the requirements"""
49+
for tensor_name, req_measures in self.get_required_sample_statistics():
50+
comp_measures = computed.get(tensor_name, {})
51+
for req_measure in req_measures:
52+
if req_measure not in comp_measures:
53+
raise ValueError(f"Missing required measure {req_measure} for {tensor_name}")
54+
self.computed_sample_statistics = computed
5055

51-
def get_computed_statistics(self, tensor_name: str, measure: Measure):
52-
"""helper to unpack self.computed_statistics"""
53-
ret = self.computed_statistics.get(tensor_name, {}).get(measure)
56+
def get_computed_dataset_statistics(self, tensor_name: str, measure: Measure):
57+
"""helper to unpack self.computed_dataset_statistics"""
58+
ret = self.computed_dataset_statistics.get(tensor_name, {}).get(measure)
5459
if ret is None:
5560
raise RuntimeError(f"Missing computed {measure} for {tensor_name} dataset.")
5661

5762
return ret
5863

59-
def apply(self, **tensors: xr.DataArray) -> Dict[str, xr.DataArray]:
60-
"""apply processing to named tensors; call 'apply_simple' as default"""
61-
tensors[self.apply_to] = self.apply_simple(tensors[self.apply_to])
62-
return tensors
64+
def get_computed_sample_statistics(self, tensor_name: str, measure: Measure):
65+
"""helper to unpack self.computed_sample_statistics"""
66+
ret = self.computed_sample_statistics.get(tensor_name, {}).get(measure)
67+
if ret is None:
68+
raise RuntimeError(f"Missing computed {measure} for {tensor_name} sample.")
6369

64-
def apply_simple(self, tensor: xr.DataArray) -> xr.DataArray:
65-
"""apply processing to single tensor"""
70+
return ret
71+
72+
def apply(self, tensor: xr.DataArray) -> xr.DataArray:
73+
"""apply processing to named tensors"""
6674
raise NotImplementedError
6775

6876
def __post_init__(self):
@@ -72,6 +80,28 @@ def __post_init__(self):
7280
raise NotImplementedError(f"Unsupported mode {self.mode} for {self.__class__.__name__}: {self.mode}")
7381

7482

83+
#
84+
# Pre- and Postprocessing implementations
85+
#
86+
87+
88+
@dataclass
89+
class Binarize(Processing):
90+
threshold: float
91+
92+
def apply(self, tensor: xr.DataArray) -> xr.DataArray:
93+
return ensure_dtype(tensor > self.threshold, dtype="float32")
94+
95+
96+
@dataclass
97+
class Clip(Processing):
98+
min: float
99+
max: float
100+
101+
def apply(self, tensor: xr.DataArray) -> xr.DataArray:
102+
return ensure_dtype(tensor.clip(min=self.min, max=self.max), dtype="float32")
103+
104+
75105
@dataclass
76106
class ScaleLinear(Processing):
77107
"""scale the tensor with a fixed multiplicative and additive factor"""
@@ -93,48 +123,8 @@ def apply(self, tensor: xr.DataArray) -> xr.DataArray:
93123

94124

95125
@dataclass
96-
class ZeroMeanUnitVariance(Processing):
97-
mode: Literal["fixed", "per_sample", "per_dataset"] = "per_sample"
98-
mean: Optional[float] = None
99-
std: Optional[float] = None
100-
axes: Optional[Sequence[str]] = None
101-
eps: float = 1.0e-6
102-
103-
def get_required_dataset_statistics(self) -> Dict[str, Set[Measure]]:
104-
if self.mode == "per_dataset":
105-
return {self.apply_to: {Mean(), Std()}}
106-
else:
107-
return {}
108-
109-
def apply(self, **tensors: xr.DataArray) -> Dict[str, xr.DataArray]:
110-
tensor = tensors[self.apply_to]
111-
if self.mode == "fixed":
112-
assert self.mean is not None and self.std is not None
113-
mean, std = self.mean, self.std
114-
elif self.mode == "per_sample":
115-
if self.axes:
116-
axes = tuple(self.axes)
117-
mean, std = tensor.mean(axes), tensor.std(axes)
118-
else:
119-
mean, std = tensor.mean(), tensor.std()
120-
elif self.mode == "per_dataset":
121-
mean = self.get_computed_statistics(self.apply_to, "mean")
122-
std = self.get_computed_statistics(self.apply_to, "std")
123-
else:
124-
raise ValueError(self.mode)
125-
126-
tensor = (tensor - mean) / (std + self.eps)
127-
tensors[self.apply_to] = ensure_dtype(tensor, dtype="float32")
128-
129-
return tensors
130-
131-
132-
@dataclass
133-
class Binarize(Processing):
134-
threshold: float
135-
136-
def apply_simple(self, tensor: xr.DataArray) -> xr.DataArray:
137-
return ensure_dtype(tensor > self.threshold, dtype="float32")
126+
class ScaleMeanVariance(Processing):
127+
...
138128

139129

140130
@dataclass
@@ -150,39 +140,37 @@ def get_required_dataset_statistics(self) -> Dict[str, Set[Measure]]:
150140
return {}
151141
elif self.mode == "per_dataset":
152142
measures = {Percentile(self.min_percentile), Percentile(self.max_percentile)}
153-
return {self.reference_tensor or self.apply_to: measures}
143+
return {self.reference_tensor or self.tensor_name: measures}
154144
else:
155145
raise ValueError(self.mode)
156146

157-
def apply(self, **tensors: xr.DataArray) -> Dict[str, xr.DataArray]:
158-
ref_name = self.reference_tensor or self.apply_to
147+
def get_required_sample_statistics(self) -> Dict[str, Set[Measure]]:
159148
if self.mode == "per_sample":
160-
ref_tensor = tensors[ref_name]
161-
if self.axes:
162-
axes = tuple(self.axes)
163-
else:
164-
axes = None
165-
166-
v_lower = ref_tensor.quantile(self.min_percentile / 100.0, dim=axes)
167-
v_upper = ref_tensor.quantile(self.max_percentile / 100.0, dim=axes)
149+
measures = {Percentile(self.min_percentile), Percentile(self.max_percentile)}
150+
return {self.reference_tensor or self.tensor_name: measures}
168151
elif self.mode == "per_dataset":
169-
v_lower = self.get_computed_statistics(ref_name, Percentile(self.min_percentile))
170-
v_upper = self.get_computed_statistics(ref_name, Percentile(self.max_percentile))
152+
return {}
171153
else:
172154
raise ValueError(self.mode)
173155

174-
tensors[self.apply_to] = ensure_dtype((tensors[self.apply_to] - v_lower) / v_upper, dtype="float32")
175-
return tensors
156+
def apply(self, tensor: xr.DataArray) -> xr.DataArray:
157+
ref_name = self.reference_tensor or self.tensor_name
158+
if self.axes:
159+
axes = tuple(self.axes)
160+
else:
161+
axes = None
176162

163+
if self.mode == "per_sample":
164+
get_stat = self.get_computed_sample_statistics
165+
elif self.mode == "per_dataset":
166+
get_stat = self.get_computed_dataset_statistics
167+
else:
168+
raise ValueError(self.mode)
177169

178-
# todo: continue here....
179-
@dataclass
180-
class Clip(Processing):
181-
min: float
182-
max: float
170+
v_lower = get_stat(ref_name, Percentile(self.min_percentile, axes=axes))
171+
v_upper = get_stat(ref_name, Percentile(self.max_percentile, axes=axes))
183172

184-
def apply(self, tensor: xr.DataArray) -> xr.DataArray:
185-
return ensure_dtype(tensor.clip(min=self.min, max=self.max), dtype="float32")
173+
return ensure_dtype((tensor - v_lower) / v_upper, dtype="float32")
186174

187175

188176
@dataclass
@@ -191,26 +179,41 @@ def apply(self, tensor: xr.DataArray) -> xr.DataArray:
191179
return 1.0 / (1.0 + xr.ufuncs.exp(-tensor))
192180

193181

194-
KNOWN_PREPROCESSING: Dict[PreprocessingName, Type[Processing]] = {
195-
"scale_linear": ScaleLinear,
196-
"zero_mean_unit_variance": ZeroMeanUnitVariance,
197-
"binarize": Binarize,
198-
"clip": Clip,
199-
"scale_range": ScaleRange,
200-
"sigmoid": Sigmoid,
201-
}
202-
203-
204-
class CombinedProcessing:
205-
def __init__(
206-
self,
207-
processing_spec: Union[List[nodes.Preprocessing], List[nodes.Postprocessing]],
208-
input_tensor_names: Sequence[str],
209-
output_tensor_names: Sequence[str] = tuple(),
210-
):
211-
prep = all(isinstance(ps, nodes.Preprocessing) for ps in processing_spec)
212-
assert prep or all(isinstance(ps, nodes.Postprocessing) for ps in processing_spec)
213-
214-
self.tensor_names = input_tensor_names if prep else output_tensor_names
215-
self.tensor_names = input_tensor_names if prep else output_tensor_names
216-
self.procs = [KNOWN_PREPROCESSING.get(step.name)(**step.kwargs) for step in processing_spec]
182+
@dataclass
183+
class ZeroMeanUnitVariance(Processing):
184+
mode: Literal["fixed", "per_sample", "per_dataset"] = "per_sample"
185+
mean: Optional[float] = None
186+
std: Optional[float] = None
187+
axes: Optional[Sequence[str]] = None
188+
eps: float = 1.0e-6
189+
190+
def get_required_dataset_statistics(self) -> Dict[str, Set[Measure]]:
191+
if self.mode == "per_dataset":
192+
return {self.tensor_name: {Mean(), Std()}}
193+
else:
194+
return {}
195+
196+
def get_required_sample_statistics(self) -> Dict[str, Set[Measure]]:
197+
if self.mode == "per_sample":
198+
return {self.tensor_name: {Mean(), Std()}}
199+
else:
200+
return {}
201+
202+
def apply(self, tensor: xr.DataArray) -> xr.DataArray:
203+
axes = None if self.axes is None else tuple(self.axes)
204+
if self.mode == "fixed":
205+
assert self.mean is not None and self.std is not None
206+
mean, std = self.mean, self.std
207+
elif self.mode == "per_sample":
208+
if axes:
209+
mean, std = tensor.mean(axes), tensor.std(axes)
210+
else:
211+
mean, std = tensor.mean(), tensor.std()
212+
elif self.mode == "per_dataset":
213+
mean = self.get_computed_dataset_statistics(self.tensor_name, Mean(axes))
214+
std = self.get_computed_dataset_statistics(self.tensor_name, Std(axes))
215+
else:
216+
raise ValueError(self.mode)
217+
218+
tensor = (tensor - mean) / (std + self.eps)
219+
return ensure_dtype(tensor, dtype="float32")

bioimageio/core/statistical_measures.py

Lines changed: 1 addition & 0 deletions
Original file line number · Diff line number · Diff line change
@@ -20,6 +20,7 @@ class Std(Measure):
2020
@dataclass(frozen=True)
2121
class Percentile(Measure):
2222
n: float
23+
axes: Optional[Tuple[str]] = None
2324

2425
def __post_init__(self):
2526
assert self.n >= 0

0 commit comments

Comments (0)