Skip to content

Commit 6f80afe

Browse files
authored
Merge pull request #107 from bioimage-io/prepost
Update pre-and-postprocessing
2 parents 1f5e99e + 03f4b70 commit 6f80afe

File tree

5 files changed

+96
-10
lines changed

5 files changed

+96
-10
lines changed

bioimageio/core/prediction_pipeline/_postprocessing.py

Lines changed: 54 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,61 @@
33
import xarray as xr
44
from bioimageio.core.resource_io.nodes import Postprocessing
55

6-
from ._preprocessing import binarize, chain
6+
from . import _preprocessing as ops
77
from ._types import Transform
88

99

10-
def sigmoid(tensor: xr.DataArray, **kwargs):
11-
if kwargs:
12-
raise NotImplementedError(f"Passed kwargs for sigmoid {kwargs}")
13-
return 1 / (1 + xr.ufuncs.exp(-tensor))
10+
# TODO how do we implement reference_tensor?
1411

1512

16-
KNOWN_POSTPROCESSING = {"binarize": binarize, "sigmoid": sigmoid}
13+
def scale_range(
14+
tensor: xr.DataArray,
15+
*,
16+
reference_tensor=None,
17+
mode="per_sample",
18+
axes=None,
19+
min_percentile=0.0,
20+
max_percentile=100.0,
21+
) -> xr.DataArray:
1722

23+
# TODO if reference tensor is passed, we need to use it to compute quantiles instead of 'tensor'
24+
if reference_tensor is None:
25+
tensor_ = tensor
26+
else:
27+
raise NotImplementedError
1828

19-
def make_postprocessing(spec: List[Postprocessing]) -> Transform:
29+
# valid modes according to spec: "per_sample", "per_dataset"
30+
# TODO implement per_dataset
31+
if mode != "per_sample":
32+
raise NotImplementedError(f"Unsupported mode for scale_range: {mode}")
33+
34+
if axes:
35+
axes = tuple(axes)
36+
v_lower = tensor_.quantile(min_percentile / 100.0, dim=axes)
37+
v_upper = tensor_.quantile(max_percentile / 100.0, dim=axes)
38+
else:
39+
v_lower = tensor_.quantile(min_percentile / 100.0)
40+
v_upper = tensor_.quantile(max_percentile / 100.0)
41+
42+
return ops.ensure_dtype((tensor - v_lower) / v_upper, dtype="float32")
43+
44+
45+
# TODO scale the tensor s.t. it matches the mean and variance of the reference tensor
46+
def scale_mean_variance(tensor: xr.DataArray, *, reference_tensor, mode="per_sample"):
47+
raise NotImplementedError
48+
49+
50+
KNOWN_POSTPROCESSING = {
51+
"binarize": ops.binarize,
52+
"clip": ops.clip,
53+
"scale_linear": ops.scale_linear,
54+
"scale_range": ops.scale_range,
55+
"sigmoid": ops.sigmoid,
56+
"zero_mean_unit_variance": ops.zero_mean_unit_variance,
57+
}
58+
59+
60+
def make_postprocessing(spec: List[Postprocessing], dtype: str) -> Transform:
2061
"""
2162
:param spec: bioimage-io postprocessing spec nodes
2263
"""
@@ -32,4 +73,9 @@ def make_postprocessing(spec: List[Postprocessing]) -> Transform:
3273

3374
functions.append((fn, kwargs))
3475

35-
return chain(*functions)
76+
# There is a difference between pre-and-postprocessing:
77+
# Pre-processing always returns float32, because its output is consumed by the model.
78+
# Post-processing, however, should return the dtype that is specified in the model spec.
79+
functions.append((ops.ensure_dtype, {"dtype": dtype}))
80+
81+
return ops.chain(*functions)

bioimageio/core/prediction_pipeline/_prediction_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def create_prediction_pipeline(
172172
postprocessing: List[Transform] = []
173173
for out in bioimageio_model.outputs:
174174
postprocessing_spec = [] if out.postprocessing is missing else out.postprocessing.copy()
175-
postprocessing.append(make_postprocessing(postprocessing_spec))
175+
postprocessing.append(make_postprocessing(postprocessing_spec, out.data_type))
176176

177177
return _PredictionPipelineImpl(
178178
name=bioimageio_model.name,

bioimageio/core/prediction_pipeline/_preprocessing.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,19 @@ def ensure_dtype(tensor: xr.DataArray, *, dtype):
7272
return tensor.astype(dtype)
7373

7474

75+
def sigmoid(tensor: xr.DataArray, **kwargs):
76+
if kwargs:
77+
raise NotImplementedError(f"Passed kwargs for sigmoid {kwargs}")
78+
return 1.0 / (1.0 + xr.ufuncs.exp(-tensor))
79+
80+
7581
KNOWN_PREPROCESSING: Dict[PreprocessingName, Transform] = {
7682
"scale_linear": scale_linear,
7783
"zero_mean_unit_variance": zero_mean_unit_variance,
7884
"binarize": binarize,
7985
"clip": clip,
80-
"scale_range": scale_range
86+
"scale_range": scale_range,
87+
"sigmoid": sigmoid
8188
# "__tiktorch_ensure_dtype": ensure_dtype,
8289
}
8390

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import numpy as np
2+
import xarray as xr
3+
from bioimageio.core.resource_io.nodes import Postprocessing
4+
from bioimageio.core.prediction_pipeline._postprocessing import make_postprocessing
5+
6+
7+
def test_binarize_postprocessing():
8+
shape = (3, 32, 32)
9+
axes = ("c", "y", "x")
10+
np_data = np.random.rand(*shape)
11+
data = xr.DataArray(np_data, dims=axes)
12+
13+
threshold = 0.5
14+
exp = xr.DataArray(np_data > threshold, dims=axes)
15+
16+
for dtype in ("float32", "float64", "uint8", "uint16"):
17+
binarize = make_postprocessing(spec=[Postprocessing("binarize", kwargs={"threshold": threshold})], dtype=dtype)
18+
res = binarize(data)
19+
assert np.dtype(res.dtype) == np.dtype(dtype)
20+
xr.testing.assert_allclose(res, exp.astype(dtype))

tests/prediction_pipeline/test_preprocessing.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,3 +151,16 @@ def test_scale_range_axes():
151151
preprocessing = make_preprocessing([scale_range_spec])
152152
result = preprocessing(data)
153153
xr.testing.assert_allclose(expected, result)
154+
155+
156+
def test_sigmoid():
157+
shape = (3, 32, 32)
158+
axes = ("c", "y", "x")
159+
np_data = np.random.rand(*shape)
160+
data = xr.DataArray(np_data, dims=axes)
161+
162+
sigmoid = make_preprocessing([Preprocessing("sigmoid", kwargs={})])
163+
res = sigmoid(data)
164+
165+
exp = xr.DataArray(1.0 / (1 + np.exp(-np_data)), dims=axes)
166+
xr.testing.assert_allclose(res, exp)

0 commit comments

Comments
 (0)