Skip to content

Commit 13f1435

Browse files
Update pre-and-postprocessing WIP
1 parent 5d02803 commit 13f1435

File tree

3 files changed

+71
-8
lines changed

3 files changed

+71
-8
lines changed

bioimageio/core/prediction_pipeline/_postprocessing.py

Lines changed: 50 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,60 @@
33
import xarray as xr
44
from bioimageio.core.resource_io.nodes import Postprocessing
55

6-
from ._preprocessing import binarize, chain
6+
from . import _preprocessing as ops
77
from ._types import Transform
88

99

10-
def sigmoid(tensor: xr.DataArray, **kwargs):
11-
if kwargs:
12-
raise NotImplementedError(f"Passed kwargs for sigmoid {kwargs}")
13-
return 1 / (1 + xr.ufuncs.exp(-tensor))
10+
def scale_range(
11+
tensor: xr.DataArray,
12+
*,
13+
reference_tensor=None,
14+
mode="per_sample",
15+
axes=None,
16+
min_percentile=0.0,
17+
max_percentile=100.0,
18+
) -> xr.DataArray:
1419

20+
# TODO if reference tensor is passed, we need to use it to compute quantiles instead of 'tensor'
21+
if reference_tensor is None:
22+
tensor_ = tensor
23+
else:
24+
raise NotImplementedError
1525

16-
KNOWN_POSTPROCESSING = {"binarize": binarize, "sigmoid": sigmoid}
26+
# valid modes according to spec: "per_sample", "per_dataset"
27+
# TODO implement per_dataset
28+
if mode != "per_sample":
29+
raise NotImplementedError(f"Unsupported mode for scale_range: {mode}")
30+
31+
if axes:
32+
axes = tuple(axes)
33+
v_lower = tensor_.quantile(min_percentile / 100.0, dim=axes)
34+
v_upper = tensor_.quantile(max_percentile / 100.0, dim=axes)
35+
else:
36+
v_lower = tensor_.quantile(min_percentile / 100.0)
37+
v_upper = tensor_.quantile(max_percentile / 100.0)
38+
39+
return ops.ensure_dtype((tensor - v_lower) / v_upper, dtype="float32")
40+
41+
42+
# TODO scale the tensor s.t. it matches the mean and variance of the reference tensor
43+
def scale_mean_variance(tensor: xr.DataArray, *, reference_tensor, mode="per_sample"):
44+
raise NotImplementedError
45+
46+
47+
# NOTE there is a subtle difference between pre-and-postprocessing:
48+
# pre-processing always returns float32, because the post-processing output is consumed
49+
# by the model. Post-processing, however, should return the dtype that is specified in the model spec
50+
# TODO I think the easiest way to implement this is to add dtype is an option to 'make_postprocessing'
51+
# and then apply 'ensure_dtype' to the result of the postprocessing chain
52+
KNOWN_POSTPROCESSING = {
53+
"binarize": ops.binarize,
54+
"clip": ops.clip,
55+
"scale_linear": ops.scale_linear,
56+
"scale_range": ops.scale_range,
57+
"sigmoid": ops.sigmoid,
58+
"zero_mean_unit_variance": ops.zero_mean_unit_variance
59+
}
1760

1861

1962
def make_postprocessing(spec: List[Postprocessing]) -> Transform:
@@ -32,4 +75,4 @@ def make_postprocessing(spec: List[Postprocessing]) -> Transform:
3275

3376
functions.append((fn, kwargs))
3477

35-
return chain(*functions)
78+
return ops.chain(*functions)

bioimageio/core/prediction_pipeline/_preprocessing.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,19 @@ def ensure_dtype(tensor: xr.DataArray, *, dtype):
7272
return tensor.astype(dtype)
7373

7474

75+
def sigmoid(tensor: xr.DataArray, **kwargs):
76+
if kwargs:
77+
raise NotImplementedError(f"Passed kwargs for sigmoid {kwargs}")
78+
return 1.0 / (1.0 + xr.ufuncs.exp(-tensor))
79+
80+
7581
KNOWN_PREPROCESSING: Dict[PreprocessingName, Transform] = {
7682
"scale_linear": scale_linear,
7783
"zero_mean_unit_variance": zero_mean_unit_variance,
7884
"binarize": binarize,
7985
"clip": clip,
80-
"scale_range": scale_range
86+
"scale_range": scale_range,
87+
"sigmoid": sigmoid
8188
# "__tiktorch_ensure_dtype": ensure_dtype,
8289
}
8390

tests/prediction_pipeline/test_preprocessing.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,3 +151,16 @@ def test_scale_range_axes():
151151
preprocessing = make_preprocessing([scale_range_spec])
152152
result = preprocessing(data)
153153
xr.testing.assert_allclose(expected, result)
154+
155+
156+
def test_sigmoid():
157+
shape = (3, 32, 32)
158+
axes = ("c", "y", "x")
159+
np_data = np.random.rand(*shape)
160+
data = xr.DataArray(np_data, dims=axes)
161+
162+
sigmoid = make_preprocessing([Preprocessing("sigmoid")])
163+
res = sigmoid(data)
164+
165+
exp = 1. / (1 + np.exp(-np_data))
166+
xr.testing.assert_allclose(res, exp)

0 commit comments

Comments
 (0)