Skip to content

Commit 34dea7e

Browse files
committed
add test_postprocessing_dtype
1 parent 0cda1f0 commit 34dea7e

File tree

1 file changed

+36
-0
lines changed

1 file changed

+36
-0
lines changed
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import numpy as np
2+
import xarray as xr
3+
4+
from bioimageio.core.resource_io import nodes
5+
6+
7+
def test_postprocessing_dtype():
8+
from bioimageio.core.prediction_pipeline._combined_processing import CombinedProcessing
9+
10+
shape = [3, 32, 32]
11+
axes = ("c", "y", "x")
12+
np_data = np.random.rand(*shape)
13+
data = xr.DataArray(np_data, dims=axes)
14+
15+
inputs = []
16+
17+
threshold = 0.5
18+
exp = xr.DataArray(np_data > threshold, dims=axes)
19+
20+
for dtype in ("float32", "float64", "uint8", "uint16"):
21+
outputs = [
22+
nodes.OutputTensor(
23+
"out1",
24+
data_type=dtype,
25+
axes=axes,
26+
shape=shape,
27+
postprocessing=[nodes.Postprocessing("binarize", dict(threshold=threshold))],
28+
)
29+
]
30+
com_proc = CombinedProcessing(inputs, outputs)
31+
32+
res, _ = com_proc.apply_postprocessing(data, input_sample_statistics={})
33+
assert len(res) == 1
34+
res = res[0]
35+
assert np.dtype(res.dtype) == np.dtype(dtype)
36+
xr.testing.assert_allclose(res, exp.astype(dtype))

0 commit comments

Comments
 (0)