Commit 3cf3611

Merge pull request #300 from bioimage-io/impl_scale_mean_var
Implement scale_mean_var
2 parents b0ceac8 + e25b27a commit 3cf3611

File tree

6 files changed: +213 additions, -59 deletions

bioimageio/core/prediction_pipeline/_combined_processing.py

Lines changed: 68 additions & 33 deletions

@@ -1,7 +1,8 @@
-from typing import List, Optional, Sequence, Union
+import dataclasses
+from typing import Any, Dict, List, Optional, Sequence, Union
 
 from bioimageio.core.resource_io import nodes
-from ._processing import EnsureDtype, KNOWN_PROCESSING, Processing
+from ._processing import AssertDtype, EnsureDtype, KNOWN_PROCESSING, Processing, TensorName
 from ._utils import ComputedMeasures, PER_DATASET, PER_SAMPLE, RequiredMeasures, Sample
 
 try:
@@ -10,44 +11,78 @@
     from typing_extensions import Literal  # type: ignore
 
 
+@dataclasses.dataclass
+class ProcessingInfoStep:
+    name: str
+    kwargs: Dict[str, Any]
+
+
+@dataclasses.dataclass
+class ProcessingInfo:
+    steps: List[ProcessingInfoStep]
+    assert_dtype_before: Optional[Union[str, Sequence[str]]] = None  # throw AssertionError if data type doesn't match
+    ensure_dtype_before: Optional[str] = None  # cast data type if needed
+    assert_dtype_after: Optional[Union[str, Sequence[str]]] = None  # throw AssertionError if data type doesn't match
+    ensure_dtype_after: Optional[str] = None  # cast data type if needed
+
+
 class CombinedProcessing:
-    def __init__(self, tensor_specs: Union[List[nodes.InputTensor], List[nodes.OutputTensor]]):
-        PRE: Literal["pre"] = "pre"
-        POST: Literal["post"] = "post"
-        proc_prefix: Optional[Literal["pre", "post"]] = None
+    def __init__(self, combine_tensors: Dict[TensorName, ProcessingInfo]):
         self._procs = []
-        for t in tensor_specs:
-            if isinstance(t, nodes.InputTensor):
-                steps = t.preprocessing or []
-                if proc_prefix is not None and proc_prefix != PRE:
-                    raise ValueError(f"Invalid mixed input/output tensor specs: {tensor_specs}")
-
-                proc_prefix = PRE
-            elif isinstance(t, nodes.OutputTensor):
-                steps = t.postprocessing or []
-                if proc_prefix is not None and proc_prefix != POST:
-                    raise ValueError(f"Invalid mixed input/output tensor specs: {tensor_specs}")
-
-                proc_prefix = POST
-            else:
-                raise NotImplementedError(t)
+        known = dict(KNOWN_PROCESSING["pre"])
+        known.update(KNOWN_PROCESSING["post"])
+
+        # ensure all tensors have correct data type before any processing
+        for tensor_name, info in combine_tensors.items():
+            if info.assert_dtype_before is not None:
+                self._procs.append(AssertDtype(tensor_name=tensor_name, dtype=info.assert_dtype_before))
 
-            for step in steps:
-                self._procs.append(KNOWN_PROCESSING[proc_prefix][step.name](tensor_name=t.name, **step.kwargs))
+            if info.ensure_dtype_before is not None:
+                self._procs.append(EnsureDtype(tensor_name=tensor_name, dtype=info.ensure_dtype_before))
 
-        # There is a difference between pre-and-postprocessing:
-        # Pre-processing always returns float32, because its output is consumed by the model.
-        # Post-processing, however, should return the dtype that is specified in the model spec.
-        # todo: cast dtype for inputs before preprocessing? or check dtype?
-        if proc_prefix == POST:
-            for t in tensor_specs:
-                self._procs.append(EnsureDtype(tensor_name=t.name, dtype=t.data_type))
+        for tensor_name, info in combine_tensors.items():
+            for step in info.steps:
+                self._procs.append(known[step.name](tensor_name=tensor_name, **step.kwargs))
+
+            if info.assert_dtype_after is not None:
+                self._procs.append(AssertDtype(tensor_name=tensor_name, dtype=info.assert_dtype_after))
+
+            # ensure tensor has correct data type right after its processing
+            if info.ensure_dtype_after is not None:
+                self._procs.append(EnsureDtype(tensor_name=tensor_name, dtype=info.ensure_dtype_after))
 
         self.required_measures: RequiredMeasures = self._collect_required_measures(self._procs)
-        if proc_prefix == POST and self.required_measures[PER_DATASET]:
-            raise NotImplementedError("computing statistics for output tensors per dataset is not yet implemented")
+        self.tensor_names = list(combine_tensors)
+
+    @classmethod
+    def from_tensor_specs(cls, tensor_specs: List[Union[nodes.InputTensor, nodes.OutputTensor]]):
+        combine_tensors = {}
+        for ts in tensor_specs:
+            # There is a difference between pre- and postprocessing:
+            # After preprocessing we ensure float32, because the output is consumed by the model.
+            # After postprocessing the dtype that is specified in the model spec needs to be ensured.
+            assert ts.name not in combine_tensors
+            if isinstance(ts, nodes.InputTensor):
+                # todo: assert nodes.InputTensor.dtype with assert_dtype_before?
+                # todo: in the long run we do not want to limit model inputs to float32...
+                combine_tensors[ts.name] = ProcessingInfo(
+                    [ProcessingInfoStep(p.name, kwargs=p.kwargs) for p in ts.preprocessing or []],
+                    ensure_dtype_after="float32",
+                )
+            elif isinstance(ts, nodes.OutputTensor):
+                combine_tensors[ts.name] = ProcessingInfo(
+                    [ProcessingInfoStep(p.name, kwargs=p.kwargs) for p in ts.postprocessing or []],
+                    ensure_dtype_after=ts.data_type,
+                )
+            else:
+                raise NotImplementedError(type(ts))
+
+        inst = cls(combine_tensors)
+        for ts in tensor_specs:
+            if isinstance(ts, nodes.OutputTensor) and ts.name in inst.required_measures[PER_DATASET]:
+                raise NotImplementedError("computing statistics for output tensors per dataset is not yet implemented")
 
-        self.tensor_names = [t.name for t in tensor_specs]
+        return inst
 
     def apply(self, sample: Sample, computed_measures: ComputedMeasures) -> None:
         for proc in self._procs:
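For context, a minimal usage sketch (not part of this commit) of the new ProcessingInfo-based construction path; the tensor name "input0" and the zero_mean_unit_variance kwargs are made up for illustration, and it assumes a bioimageio.core version that already contains this change:

# Illustrative sketch only: build a CombinedProcessing by hand from ProcessingInfo objects.
# Tensor name and processing kwargs are hypothetical examples.
from bioimageio.core.prediction_pipeline._combined_processing import (
    CombinedProcessing,
    ProcessingInfo,
    ProcessingInfoStep,
)

pre = CombinedProcessing(
    {
        "input0": ProcessingInfo(
            steps=[ProcessingInfoStep("zero_mean_unit_variance", kwargs={"mode": "per_sample"})],
            ensure_dtype_after="float32",  # model inputs are currently kept as float32
        )
    }
)

# required_measures lists the statistics (e.g. per-sample mean/std) that must be computed
# before `pre.apply(sample, computed_measures)` can run.
print(pre.required_measures)

create_prediction_pipeline keeps using the from_tensor_specs classmethod (see _prediction_pipeline.py below), which builds the same kind of ProcessingInfo entries from the model's input and output tensor specs.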

bioimageio/core/prediction_pipeline/_prediction_pipeline.py

Lines changed: 2 additions & 2 deletions

@@ -213,7 +213,7 @@ def create_prediction_pipeline(
     ipts = [resolve_raw_node(s, nodes) for s in bioimageio_model.inputs]
     outs = [resolve_raw_node(s, nodes) for s in bioimageio_model.outputs]
 
-    preprocessing = CombinedProcessing(ipts)
+    preprocessing = CombinedProcessing.from_tensor_specs(ipts)
 
     def sample_dataset():
         for tensors in dataset_for_initial_statistics:
@@ -225,7 +225,7 @@ def sample_dataset():
         update_dataset_stats_after_n_samples=update_dataset_stats_after_n_samples,
         update_dataset_stats_for_n_samples=update_dataset_stats_for_n_samples,
     )
-    postprocessing = CombinedProcessing(outs)
+    postprocessing = CombinedProcessing.from_tensor_specs(outs)
     out_stats = StatsState(
         postprocessing.required_measures,
         dataset=tuple(),

bioimageio/core/prediction_pipeline/_processing.py

Lines changed: 73 additions & 22 deletions

@@ -1,6 +1,12 @@
-from dataclasses import dataclass, field, fields
-from typing import Mapping, Optional, Sequence, Type, Union
-
+"""Here pre- and postprocessing operations are implemented according to their definitions in bioimageio.spec:
+see https://github.com/bioimage-io/spec-bioimage-io/blob/gh-pages/preprocessing_spec_latest.md
+and https://github.com/bioimage-io/spec-bioimage-io/blob/gh-pages/postprocessing_spec_latest.md
+"""
+import numbers
+from dataclasses import InitVar, dataclass, field, fields
+from typing import List, Mapping, Optional, Sequence, Tuple, Type, Union
+
+import numpy
 import numpy as np
 import xarray as xr
 
@@ -33,7 +39,7 @@ def _get_fixed(
 
 @dataclass
 class Processing:
-    """base class for all Pre- and Postprocessing transformations"""
+    """base class for all Pre- and Postprocessing transformations."""
 
     tensor_name: str
     # todo: in python>=3.10 we should use dataclasses.KW_ONLY instead of MISSING (see child classes) to make inheritance work properly
@@ -87,48 +93,64 @@ def __post_init__(self):
 
 
 #
-# helpers
+# Pre- and Postprocessing implementations
 #
-def ensure_dtype(tensor: xr.DataArray, *, dtype) -> xr.DataArray:
-    """
-    Convert array to a given datatype
-    """
-    return tensor.astype(dtype)
 
 
-#
-# Pre- and Postprocessing implementations
-#
+@dataclass
+class AssertDtype(Processing):
+    """Helper Processing to assert dtype."""
+
+    dtype: Union[str, Sequence[str]] = MISSING
+    assert_with: Tuple[Type[numpy.dtype], ...] = field(init=False)
+
+    def __post_init__(self):
+        if isinstance(self.dtype, str):
+            dtype = [self.dtype]
+        else:
+            dtype = self.dtype
+
+        self.assert_with = tuple(type(numpy.dtype(dt)) for dt in dtype)
+
+    def apply(self, tensor: xr.DataArray) -> xr.DataArray:
+        assert isinstance(tensor.dtype, self.assert_with)
+        return tensor
 
 
 @dataclass
 class Binarize(Processing):
+    """'output = tensor > threshold'."""
+
     threshold: float = MISSING  # make dataclass inheritance work for py<3.10 by using an explicit MISSING value.
 
     def apply(self, tensor: xr.DataArray) -> xr.DataArray:
-        return ensure_dtype(tensor > self.threshold, dtype="float32")
+        return tensor > self.threshold
 
 
 @dataclass
 class Clip(Processing):
+    """Limit tensor values to [min, max]."""
+
     min: float = MISSING
     max: float = MISSING
 
     def apply(self, tensor: xr.DataArray) -> xr.DataArray:
-        return ensure_dtype(tensor.clip(min=self.min, max=self.max), dtype="float32")
+        return tensor.clip(min=self.min, max=self.max)
 
 
 @dataclass
 class EnsureDtype(Processing):
+    """Helper Processing to cast dtype if needed."""
+
     dtype: str = MISSING
 
     def apply(self, tensor: xr.DataArray) -> xr.DataArray:
-        return ensure_dtype(tensor, dtype=self.dtype)
+        return tensor.astype(self.dtype)
 
 
 @dataclass
 class ScaleLinear(Processing):
-    """scale the tensor with a fixed multiplicative and additive factor"""
+    """Scale the tensor with a fixed multiplicative and additive factor."""
 
     gain: Union[float, Sequence[float]] = MISSING
     offset: Union[float, Sequence[float]] = MISSING
@@ -143,7 +165,7 @@ def apply(self, tensor: xr.DataArray) -> xr.DataArray:
             gain = self.gain
             offset = self.offset
 
-        return ensure_dtype(tensor * gain + offset, dtype="float32")
+        return tensor * gain + offset
 
     def __post_init__(self):
         super().__post_init__()
@@ -154,11 +176,37 @@ def __post_init__(self):
 
 @dataclass
 class ScaleMeanVariance(Processing):
-    ...
+    """Scale the tensor s.t. its mean and variance match a reference tensor."""
+
+    mode: Literal[SampleMode, DatasetMode] = PER_SAMPLE
+    reference_tensor: TensorName = MISSING
+    axes: Optional[Sequence[str]] = None
+    eps: float = 1e-6
+
+    def get_required_measures(self) -> RequiredMeasures:
+        axes = None if self.axes is None else tuple(self.axes)
+        return {
+            self.mode: {
+                self.tensor_name: {Mean(axes=axes), Std(axes=axes)},
+                self.reference_tensor: {Mean(axes=axes), Std(axes=axes)},
+            }
+        }
+
+    def apply(self, tensor: xr.DataArray) -> xr.DataArray:
+        axes = None if self.axes is None else tuple(self.axes)
+        assert self.mode in (PER_SAMPLE, PER_DATASET)
+        mean = self.get_computed_measure(self.tensor_name, Mean(axes), mode=self.mode)
+        std = self.get_computed_measure(self.tensor_name, Std(axes), mode=self.mode)
+        ref_mean = self.get_computed_measure(self.reference_tensor, Mean(axes), mode=self.mode)
+        ref_std = self.get_computed_measure(self.reference_tensor, Std(axes), mode=self.mode)
+
+        return (tensor - mean) / (std + self.eps) * (ref_std + self.eps) + ref_mean
 
 
 @dataclass
 class ScaleRange(Processing):
+    """Scale with percentiles."""
+
     mode: Literal[SampleMode, DatasetMode] = PER_SAMPLE
     axes: Optional[Sequence[str]] = None
     min_percentile: float = 0.0
@@ -177,7 +225,7 @@ def apply(self, tensor: xr.DataArray) -> xr.DataArray:
         v_lower = self.get_computed_measure(ref_name, Percentile(self.min_percentile, axes=axes))
        v_upper = self.get_computed_measure(ref_name, Percentile(self.max_percentile, axes=axes))
 
-        return ensure_dtype((tensor - v_lower) / (v_upper - v_lower + self.eps), dtype="float32")
+        return (tensor - v_lower) / (v_upper - v_lower + self.eps)
 
     def __post_init__(self):
         super().__post_init__()
@@ -186,12 +234,16 @@ def __post_init__(self):
 
 @dataclass
 class Sigmoid(Processing):
+    """1 / (1 + e^(-tensor))."""
+
     def apply(self, tensor: xr.DataArray) -> xr.DataArray:
         return 1.0 / (1.0 + np.exp(-tensor))
 
 
 @dataclass
 class ZeroMeanUnitVariance(Processing):
+    """normalize to zero mean, unit variance."""
+
     mode: Mode = PER_SAMPLE
     mean: Optional[Union[float, Sequence[float]]] = None
     std: Optional[Union[float, Sequence[float]]] = None
@@ -218,8 +270,7 @@ def apply(self, tensor: xr.DataArray) -> xr.DataArray:
         else:
             raise ValueError(self.mode)
 
-        tensor = (tensor - mean) / (std + self.eps)
-        return ensure_dtype(tensor, dtype="float32")
+        return (tensor - mean) / (std + self.eps)
 
 
 _KnownProcessing = TypedDict(
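As a quick numeric sanity check of the new ScaleMeanVariance.apply formula, the following standalone numpy sketch (not part of the commit; seeds, shapes, and statistics are arbitrary) shows that the scaled tensor takes on the reference tensor's mean and standard deviation:

# Standalone check of: (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean
import numpy as np

eps = 1e-6
tensor = np.random.default_rng(0).normal(loc=4.0, scale=3.0, size=(1, 1, 32, 32))
reference = np.random.default_rng(1).normal(loc=-1.0, scale=0.5, size=(1, 1, 32, 32))

mean, std = tensor.mean(), tensor.std()
ref_mean, ref_std = reference.mean(), reference.std()

scaled = (tensor - mean) / (std + eps) * (ref_std + eps) + ref_mean

# after scaling, the tensor's statistics match the reference (up to eps)
assert np.isclose(scaled.mean(), ref_mean)
assert np.isclose(scaled.std(), ref_std, atol=1e-4)

The eps terms keep the division stable for (near-)constant tensors, at the cost of a deviation from an exact match on the order of eps.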

tests/prediction_pipeline/test_combined_processing.py

Lines changed: 1 addition & 1 deletion

@@ -25,7 +25,7 @@ def test_postprocessing_dtype():
             postprocessing=[nodes.Postprocessing("binarize", dict(threshold=threshold))],
         )
     ]
-    com_proc = CombinedProcessing(outputs)
+    com_proc = CombinedProcessing.from_tensor_specs(outputs)
 
     sample = {"out1": data}
     com_proc.apply(sample, {})
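The dtype behaviour this test relies on changes slightly with this commit: Binarize no longer casts to float32 itself, so its output stays boolean until the EnsureDtype step that CombinedProcessing appends per output tensor casts it to the dtype declared in the model spec. A standalone xarray sketch (not from the commit; the float32 target dtype is assumed for illustration):

# Binarize now returns a bool tensor; EnsureDtype restores the spec dtype afterwards.
import numpy as np
import xarray as xr

data = xr.DataArray(np.random.rand(4, 4), dims=("x", "y"))
binarized = data > 0.5          # what Binarize.apply now returns
assert binarized.dtype == bool

restored = binarized.astype("float32")  # what the appended EnsureDtype step does
assert restored.dtype == np.float32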
