Skip to content

Commit dab09b5

Browse files
committed
add AssertDtype Processing
1 parent 256399b commit dab09b5

File tree

3 files changed

+71
-14
lines changed

3 files changed

+71
-14
lines changed

bioimageio/core/prediction_pipeline/_combined_processing.py

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import dataclasses
2-
from typing import Dict, Iterable, List, NamedTuple, Optional, Sequence, TypedDict, Union
2+
from typing import Any, Dict, List, Optional, Sequence, Union
33

44
from bioimageio.core.resource_io import nodes
5-
from ._processing import EnsureDtype, KNOWN_PROCESSING, Processing, TensorName
5+
from ._processing import AssertDtype, EnsureDtype, KNOWN_PROCESSING, Processing, TensorName
66
from ._utils import ComputedMeasures, PER_DATASET, PER_SAMPLE, RequiredMeasures, Sample
77

88
try:
@@ -11,11 +11,19 @@
1111
from typing_extensions import Literal # type: ignore
1212

1313

14+
@dataclasses.dataclass
15+
class Processing:
16+
name: str
17+
kwargs: Dict[str, Any]
18+
19+
1420
@dataclasses.dataclass
1521
class TensorProcessingInfo:
16-
processing_steps: Union[List[nodes.Preprocessing], List[nodes.Postprocessing]]
17-
data_type_before: Optional[str] = None
18-
data_type_after: Optional[str] = None
22+
processing_steps: List[Processing]
23+
assert_dtype_before: Optional[Union[str, Sequence[str]]] = None # throw AssertionError if data type doesn't match
24+
ensure_dtype_before: Optional[str] = None # cast data type if needed
25+
assert_dtype_after: Optional[Union[str, Sequence[str]]] = None # throw AssertionError if data type doesn't match
26+
ensure_dtype_after: Optional[str] = None # cast data type if needed
1927

2028

2129
class CombinedProcessing:
@@ -26,16 +34,22 @@ def __init__(self, combine_tensors: Dict[TensorName, TensorProcessingInfo]):
2634

2735
# ensure all tensors have correct data type before any processing
2836
for tensor_name, info in combine_tensors.items():
29-
if info.data_type_before is not None:
30-
self._procs.append(EnsureDtype(tensor_name=tensor_name, dtype=info.data_type_before))
37+
if info.assert_dtype_before is not None:
38+
self._procs.append(AssertDtype(tensor_name=tensor_name, dtype=info.assert_dtype_before))
39+
40+
if info.ensure_dtype_before is not None:
41+
self._procs.append(EnsureDtype(tensor_name=tensor_name, dtype=info.ensure_dtype_before))
3142

3243
for tensor_name, info in combine_tensors.items():
3344
for step in info.processing_steps:
3445
self._procs.append(known[step.name](tensor_name=tensor_name, **step.kwargs))
3546

47+
if info.assert_dtype_after is not None:
48+
self._procs.append(AssertDtype(tensor_name=tensor_name, dtype=info.assert_dtype_after))
49+
3650
# ensure tensor has correct data type right after its processing
37-
if info.data_type_after is not None:
38-
self._procs.append(EnsureDtype(tensor_name=tensor_name, dtype=info.data_type_after))
51+
if info.ensure_dtype_after is not None:
52+
self._procs.append(EnsureDtype(tensor_name=tensor_name, dtype=info.ensure_dtype_after))
3953

4054
self.required_measures: RequiredMeasures = self._collect_required_measures(self._procs)
4155
self.tensor_names = list(combine_tensors)
@@ -50,10 +64,15 @@ def from_tensor_specs(cls, tensor_specs: List[Union[nodes.InputTensor, nodes.Out
5064
# todo: cast dtype for inputs before preprocessing? or check dtype?
5165
assert ts.name not in combine_tensors
5266
if isinstance(ts, nodes.InputTensor):
53-
# todo: move preprocessing ensure_dtype here as data_type_after
54-
combine_tensors[ts.name] = TensorProcessingInfo(ts.preprocessing)
67+
# todo: assert nodes.InputTensor.dtype with assert_dtype_before?
68+
combine_tensors[ts.name] = TensorProcessingInfo(
69+
[Processing(p.name, kwargs=p.kwargs) for p in ts.preprocessing]
70+
)
5571
elif isinstance(ts, nodes.OutputTensor):
56-
combine_tensors[ts.name] = TensorProcessingInfo(ts.postprocessing, None, ts.data_type)
72+
combine_tensors[ts.name] = TensorProcessingInfo(
73+
[Processing(p.name, kwargs=p.kwargs) for p in ts.postprocessing],
74+
ensure_dtype_after=ts.data_type,
75+
)
5776
else:
5877
raise NotImplementedError(type(ts))
5978

bioimageio/core/prediction_pipeline/_processing.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,11 @@
22
see https://github.com/bioimage-io/spec-bioimage-io/blob/gh-pages/preprocessing_spec_latest.md
33
and https://github.com/bioimage-io/spec-bioimage-io/blob/gh-pages/postprocessing_spec_latest.md
44
"""
5-
from dataclasses import dataclass, field, fields
6-
from typing import Mapping, Optional, Sequence, Type, Union
5+
import numbers
6+
from dataclasses import InitVar, dataclass, field, fields
7+
from typing import List, Mapping, Optional, Sequence, Tuple, Type, Union
78

9+
import numpy
810
import numpy as np
911
import xarray as xr
1012

@@ -105,6 +107,26 @@ def ensure_dtype(tensor: xr.DataArray, *, dtype) -> xr.DataArray:
105107
#
106108

107109

110+
@dataclass
class AssertDtype(Processing):
    """Helper Processing to assert dtype.

    The tensor is returned as is; an AssertionError is raised if its dtype is
    not an instance of one of the dtype classes derived from `dtype`.
    """

    # a single dtype name or a sequence of accepted dtype names
    dtype: Union[str, Sequence[str]] = MISSING
    # dtype classes to check against; derived from `dtype` in __post_init__
    assert_with: Tuple[Type[numpy.dtype], ...] = field(init=False)

    def __post_init__(self):
        # treat a single dtype name like a one-element sequence
        if isinstance(self.dtype, str):
            dtype = [self.dtype]
        else:
            dtype = self.dtype

        # NOTE(review): this relies on numpy exposing a distinct class per dtype;
        # presumably older numpy releases use one shared class for all dtypes, which
        # would make the check below always pass -- confirm the minimum numpy version.
        self.assert_with = tuple(type(numpy.dtype(dt)) for dt in dtype)

    def apply(self, tensor: xr.DataArray) -> xr.DataArray:
        """Return `tensor` unchanged; raise AssertionError on unexpected dtype."""
        # Raise explicitly rather than via an `assert` statement: `assert` is
        # removed when Python runs with -O, which would silently disable the check.
        if not isinstance(tensor.dtype, self.assert_with):
            raise AssertionError(f"unexpected dtype {tensor.dtype}; expected {self.dtype}")

        return tensor
128+
129+
108130
@dataclass
109131
class Binarize(Processing):
110132
"""'output = tensor > threshold' (note: returns float array)."""

tests/prediction_pipeline/test_processing.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import dataclasses
22

3+
import numpy as np
34
import pytest
5+
import xarray as xr
46

57
from bioimageio.core.prediction_pipeline._processing import KNOWN_PROCESSING
68
from bioimageio.core.prediction_pipeline._utils import FIXED
@@ -11,6 +13,20 @@
1113
from typing_extensions import get_args # type: ignore
1214

1315

16+
def test_assert_dtype():
    """AssertDtype returns a tensor of matching dtype as is and raises on a mismatch."""
    from bioimageio.core.prediction_pipeline._processing import AssertDtype

    proc = AssertDtype("test_tensor", dtype="uint8")
    tensor = xr.DataArray(np.zeros((1,), dtype="uint8"), dims=tuple("c"))
    out = proc(tensor)
    assert out is tensor

    tensor = tensor.astype("uint16")
    # nothing may follow the call inside this block: statements after the
    # raising call would never be executed
    with pytest.raises(AssertionError):
        proc(tensor)
28+
29+
1430
@pytest.mark.parametrize(
1531
"proc",
1632
list(KNOWN_PROCESSING["pre"].values()) + list(KNOWN_PROCESSING["post"].values()),

0 commit comments

Comments
 (0)