11import numpy as np
2+ import pytest
23import xarray as xr
34
5+ from bioimageio .core .prediction_pipeline ._measure_groups import compute_measures
46
5- def test_binarize_postprocessing ():
7+
8+ def test_binarize ():
69 from bioimageio .core .prediction_pipeline ._processing import Binarize
710
811 shape = (3 , 32 , 32 )
@@ -16,3 +19,52 @@ def test_binarize_postprocessing():
1619 binarize = Binarize ("data_name" , threshold = threshold )
1720 res = binarize (data )
1821 xr .testing .assert_allclose (res , exp )
22+
23+
24+ @pytest .mark .parametrize ("axes" , [None , tuple ("cy" ), tuple ("cyx" ), tuple ("x" )])
25+ def test_scale_mean_variance (axes ):
26+ from bioimageio .core .prediction_pipeline ._processing import ScaleMeanVariance
27+
28+ shape = (3 , 32 , 46 )
29+ ipt_axes = ("c" , "y" , "x" )
30+ np_data = np .random .rand (* shape )
31+ ipt_data = xr .DataArray (np_data , dims = ipt_axes )
32+ ref_data = xr .DataArray ((np_data * 2 ) + 3 , dims = ipt_axes )
33+
34+ scale_mean_variance = ScaleMeanVariance ("data_name" , reference_tensor = "ref_name" , axes = axes )
35+ required = scale_mean_variance .get_required_measures ()
36+ computed = compute_measures (required , sample = {"data_name" : ipt_data , "ref_name" : ref_data })
37+ scale_mean_variance .set_computed_measures (computed )
38+
39+ res = scale_mean_variance (ipt_data )
40+ xr .testing .assert_allclose (res , ref_data )
41+
42+
43+ @pytest .mark .parametrize ("axes" , [None , tuple ("cy" ), tuple ("y" ), tuple ("yx" )])
44+ def test_scale_mean_variance_per_channel (axes ):
45+ from bioimageio .core .prediction_pipeline ._processing import ScaleMeanVariance
46+
47+ shape = (3 , 32 , 46 )
48+ ipt_axes = ("c" , "y" , "x" )
49+ np_data = np .random .rand (* shape )
50+ ipt_data = xr .DataArray (np_data , dims = ipt_axes )
51+
52+ # set different mean, std per channel
53+ np_ref_data = np .stack ([d * i + i for i , d in enumerate (np_data , start = 2 )])
54+ print (np_ref_data .shape )
55+ ref_data = xr .DataArray (np_ref_data , dims = ipt_axes )
56+
57+ scale_mean_variance = ScaleMeanVariance ("data_name" , reference_tensor = "ref_name" , axes = axes )
58+ required = scale_mean_variance .get_required_measures ()
59+ computed = compute_measures (required , sample = {"data_name" : ipt_data , "ref_name" : ref_data })
60+ scale_mean_variance .set_computed_measures (computed )
61+
62+ res = scale_mean_variance (ipt_data )
63+
64+ if axes is not None and "c" not in axes :
65+ # mean,std per channel should match exactly
66+ xr .testing .assert_allclose (res , ref_data )
67+ else :
68+ # mean,std across channels should not match
69+ with pytest .raises (AssertionError ):
70+ xr .testing .assert_allclose (res , ref_data )
0 commit comments