Skip to content

Commit 03f4b70

Browse files
Ensure output dtype in post-processing
1 parent 267a424 commit 03f4b70

File tree

3 files changed

+27
-7
lines changed

3 files changed

+27
-7
lines changed

bioimageio/core/prediction_pipeline/_postprocessing.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,6 @@ def scale_mean_variance(tensor: xr.DataArray, *, reference_tensor, mode="per_sam
4747
raise NotImplementedError
4848

4949

50-
# NOTE there is a subtle difference between pre-and-postprocessing:
51-
# pre-processing always returns float32, because the post-processing output is consumed
52-
# by the model. Post-processing, however, should return the dtype that is specified in the model spec
53-
# TODO I think the easiest way to implement this is to add dtype is an option to 'make_postprocessing'
54-
# and then apply 'ensure_dtype' to the result of the postprocessing chain
5550
KNOWN_POSTPROCESSING = {
5651
"binarize": ops.binarize,
5752
"clip": ops.clip,
@@ -62,7 +57,7 @@ def scale_mean_variance(tensor: xr.DataArray, *, reference_tensor, mode="per_sam
6257
}
6358

6459

65-
def make_postprocessing(spec: List[Postprocessing]) -> Transform:
60+
def make_postprocessing(spec: List[Postprocessing], dtype: str) -> Transform:
6661
"""
6762
:param preprocessing: bioimage-io spec node
6863
"""
@@ -78,4 +73,9 @@ def make_postprocessing(spec: List[Postprocessing]) -> Transform:
7873

7974
functions.append((fn, kwargs))
8075

76+
# There is a difference between pre-and-postprocessing:
77+
# Tre-processing always returns float32, because its output is consumed y the model.
78+
# Post-processing, however, should return the dtype that is specified in the model spec.
79+
functions.append((ops.ensure_dtype, {"dtype": dtype}))
80+
8181
return ops.chain(*functions)

bioimageio/core/prediction_pipeline/_prediction_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ def create_prediction_pipeline(
172172
postprocessing: List[Transform] = []
173173
for out in bioimageio_model.outputs:
174174
postprocessing_spec = [] if out.postprocessing is missing else out.postprocessing.copy()
175-
postprocessing.append(make_postprocessing(postprocessing_spec))
175+
postprocessing.append(make_postprocessing(postprocessing_spec, out.data_type))
176176

177177
return _PredictionPipelineImpl(
178178
name=bioimageio_model.name,
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import numpy as np
2+
import xarray as xr
3+
from bioimageio.core.resource_io.nodes import Postprocessing
4+
from bioimageio.core.prediction_pipeline._postprocessing import make_postprocessing
5+
6+
7+
def test_binarize_postprocessing():
8+
shape = (3, 32, 32)
9+
axes = ("c", "y", "x")
10+
np_data = np.random.rand(*shape)
11+
data = xr.DataArray(np_data, dims=axes)
12+
13+
threshold = 0.5
14+
exp = xr.DataArray(np_data > threshold, dims=axes)
15+
16+
for dtype in ("float32", "float64", "uint8", "uint16"):
17+
binarize = make_postprocessing(spec=[Postprocessing("binarize", kwargs={"threshold": threshold})], dtype=dtype)
18+
res = binarize(data)
19+
assert np.dtype(res.dtype) == np.dtype(dtype)
20+
xr.testing.assert_allclose(res, exp.astype(dtype))

0 commit comments

Comments
 (0)