Skip to content

Commit 5f13387

Browse files
committed
append EnsureDtype for postprocessing
1 parent f68150c commit 5f13387

File tree

2 files changed

+56
-22
lines changed

2 files changed

+56
-22
lines changed

bioimageio/core/prediction_pipeline/_combined_processing.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from ._processing import (
1313
Binarize,
1414
Clip,
15+
EnsureDtype,
1516
Processing,
1617
ScaleLinear,
1718
ScaleMeanVariance,
@@ -62,6 +63,13 @@ def __init__(self, inputs: List[nodes.InputTensor], outputs: List[nodes.OutputTe
6263
if out.postprocessing is not missing
6364
]
6465

66+
# There is a difference between pre- and post-processing:
67+
# Pre-processing always returns float32, because its output is consumed by the model.
68+
# Post-processing, however, should return the dtype that is specified in the model spec.
69+
# todo: cast dtype for inputs before preprocessing? or check dtype?
70+
for out in outputs:
71+
self._post.append(EnsureDtype(tensor_name=out.name, dtype=out.data_type))
72+
6573
self._req_input_stats = {s: self._collect_required_stats(self._prep, s) for s in SCOPES}
6674
self._req_output_stats = {s: self._collect_required_stats(self._post, s) for s in SCOPES}
6775
if any(self._req_output_stats[s] for s in SCOPES):
@@ -84,9 +92,6 @@ def required_output_dataset_statistics(self) -> Dict[str, Set[Measure]]:
8492

8593
@property
8694
def computed_dataset_statistics(self) -> Dict[str, Dict[Measure, Any]]:
87-
if self._computed_dataset_stats is None:
88-
raise RuntimeError("Set computed dataset statistics first!")
89-
9095
return self._computed_dataset_stats
9196

9297
def apply_preprocessing(
@@ -141,12 +146,14 @@ def set_computed_dataset_statistics(self, computed: Dict[str, Dict[Measure, Any]
141146
for proc in self._prep:
142147
proc.set_computed_dataset_statistics(self.computed_dataset_statistics)
143148

149+
@classmethod
144150
def compute_sample_statistics(
145-
self, tensors: Dict[str, xr.DataArray], measures: Dict[str, Set[Measure]]
151+
cls, tensors: Dict[str, xr.DataArray], measures: Dict[str, Set[Measure]]
146152
) -> Dict[str, Dict[Measure, Any]]:
147-
return {tname: self._compute_tensor_statistics(tensors[tname], ms) for tname, ms in measures.items()}
153+
return {tname: cls._compute_tensor_statistics(tensors[tname], ms) for tname, ms in measures.items()}
148154

149-
def _compute_tensor_statistics(self, tensor: xr.DataArray, measures: Set[Measure]) -> Dict[Measure, Any]:
155+
@staticmethod
156+
def _compute_tensor_statistics(tensor: xr.DataArray, measures: Set[Measure]) -> Dict[Measure, Any]:
150157
ret = {}
151158
for measure in measures:
152159
if isinstance(measure, Mean):

bioimageio/core/prediction_pipeline/_processing.py

Lines changed: 43 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,12 @@
1-
from dataclasses import dataclass, field
2-
from typing import Any, Dict, Literal, Optional, Sequence, Set, get_args
1+
from dataclasses import dataclass, field, fields
2+
from typing import Any, Dict, Literal, Optional, Sequence, Set, Union, get_args
33

44
import numpy as np
55
import xarray as xr
66

77
from bioimageio.core.statistical_measures import Mean, Measure, Percentile, Std
88

99

10-
def ensure_dtype(tensor: xr.DataArray, *, dtype) -> xr.DataArray:
11-
"""
12-
Convert array to a given datatype
13-
"""
14-
return tensor.astype(dtype)
15-
16-
1710
@dataclass
1811
class Processing:
1912
"""base class for all Pre- and Postprocessing transformations"""
@@ -37,7 +30,7 @@ def get_required_sample_statistics(self) -> Dict[str, Set[Measure]]:
3730

3831
def set_computed_dataset_statistics(self, computed: Dict[str, Dict[Measure, Any]]):
3932
"""helper to set computed statistics and check if they match the requirements"""
40-
for tensor_name, req_measures in self.get_required_dataset_statistics():
33+
for tensor_name, req_measures in self.get_required_dataset_statistics().items():
4134
comp_measures = computed.get(tensor_name, {})
4235
for req_measure in req_measures:
4336
if req_measure not in comp_measures:
@@ -46,7 +39,7 @@ def set_computed_dataset_statistics(self, computed: Dict[str, Dict[Measure, Any]
4639

4740
def set_computed_sample_statistics(self, computed: Dict[str, Dict[Measure, Any]]):
4841
"""helper to set computed statistics and check if they match the requirements"""
49-
for tensor_name, req_measures in self.get_required_sample_statistics():
42+
for tensor_name, req_measures in self.get_required_sample_statistics().items():
5043
comp_measures = computed.get(tensor_name, {})
5144
for req_measure in req_measures:
5245
if req_measure not in comp_measures:
@@ -69,15 +62,35 @@ def get_computed_sample_statistics(self, tensor_name: str, measure: Measure):
6962

7063
return ret
7164

65+
def __call__(self, tensor: xr.DataArray) -> xr.DataArray:
66+
return self.apply(tensor)
67+
7268
def apply(self, tensor: xr.DataArray) -> xr.DataArray:
7369
"""apply processing to named tensors"""
7470
raise NotImplementedError
7571

7672
def __post_init__(self):
7773
"""validate common kwargs by their annotations"""
78-
if hasattr(self, "mode"):
79-
if self.mode not in get_args(self.mode):
80-
raise NotImplementedError(f"Unsupported mode {self.mode} for {self.__class__.__name__}: {self.mode}")
74+
self.computed_dataset_statistics = {}
75+
self.computed_sample_statistics = {}
76+
77+
for f in fields(self):
78+
if f.name == "mode":
79+
assert hasattr(self, "mode")
80+
if self.mode not in get_args(f.type):
81+
raise NotImplementedError(
82+
f"Unsupported mode {self.mode} for {self.__class__.__name__}: {self.mode}"
83+
)
84+
85+
86+
#
87+
# helpers
88+
#
89+
def ensure_dtype(tensor: xr.DataArray, *, dtype) -> xr.DataArray:
90+
"""
91+
Convert array to a given datatype
92+
"""
93+
return tensor.astype(dtype)
8194

8295

8396
#
@@ -102,12 +115,20 @@ def apply(self, tensor: xr.DataArray) -> xr.DataArray:
102115
return ensure_dtype(tensor.clip(min=self.min, max=self.max), dtype="float32")
103116

104117

118+
@dataclass
119+
class EnsureDtype(Processing):
120+
dtype: str
121+
122+
def apply(self, tensor: xr.DataArray) -> xr.DataArray:
123+
return ensure_dtype(tensor, dtype=self.dtype)
124+
125+
105126
@dataclass
106127
class ScaleLinear(Processing):
107128
"""scale the tensor with a fixed multiplicative and additive factor"""
108129

109-
gain: float
110-
offset: float
130+
gain: Union[float, Sequence[float]]
131+
offset: Union[float, Sequence[float]]
111132
axes: Optional[Sequence[str]] = None
112133

113134
def apply(self, tensor: xr.DataArray) -> xr.DataArray:
@@ -121,6 +142,12 @@ def apply(self, tensor: xr.DataArray) -> xr.DataArray:
121142

122143
return ensure_dtype(tensor * gain + offset, dtype="float32")
123144

145+
def __post_init__(self):
146+
super().__post_init__()
147+
if self.axes is None:
148+
assert isinstance(self.gain, (int, float))
149+
assert isinstance(self.offset, (int, float))
150+
124151

125152
@dataclass
126153
class ScaleMeanVariance(Processing):

0 commit comments

Comments
 (0)