88from bioimageio .core .resource_io .nodes import Model
99
1010
11- def test_test_model (unet2d_nuclei_broad_model ):
11+ def test_test_model (any_model ):
1212 from bioimageio .core .resource_tests import test_model
1313
14- assert test_model (unet2d_nuclei_broad_model )
14+ assert test_model (any_model )
1515
1616
17- def test_test_resource (unet2d_nuclei_broad_model ):
17+ def test_test_resource (any_model ):
1818 from bioimageio .core .resource_tests import test_resource
1919
20- assert test_resource (unet2d_nuclei_broad_model )
20+ assert test_resource (any_model )
2121
2222
23- def test_predict_image (unet2d_fixed_shape_or_not , tmpdir ):
24- any_model = unet2d_fixed_shape_or_not # todo: replace 'unet2d_fixed_shape_or_not' with 'any_model'
23+ def test_predict_image (any_model , tmpdir ):
2524 from bioimageio .core .prediction import predict_image
2625
2726 spec = load_resource_description (any_model )
@@ -57,46 +56,81 @@ def test_predict_image_with_weight_format(unet2d_fixed_shape_or_not, tmpdir):
5756 assert_array_almost_equal (res , exp , decimal = 4 )
5857
5958
60- def test_predict_image_with_padding (unet2d_fixed_shape_or_not , tmp_path ):
61- any_model = unet2d_fixed_shape_or_not # todo: replace 'unet2d_fixed_shape_or_not' with 'any_model'
59+ def _test_predict_with_padding (model , tmp_path ):
6260 from bioimageio .core .prediction import predict_image
6361
64- spec = load_resource_description (any_model )
62+ spec = load_resource_description (model )
6563 assert isinstance (spec , Model )
66- image = np .load (str (spec .test_inputs [0 ]))[0 , 0 ]
64+
65+ input_spec , output_spec = spec .inputs [0 ], spec .outputs [0 ]
66+ channel_axis = input_spec .axes .index ("c" )
67+ channel_first = channel_axis == 1
68+
69+ image = np .load (str (spec .test_inputs [0 ]))
70+ assert image .shape [channel_axis ] == 1
71+ if channel_first :
72+ image = image [0 , 0 ]
73+ else :
74+ image = image [0 , ..., 0 ]
6775 original_shape = image .shape
6876 assert image .ndim == 2
6977
78+ if isinstance (output_spec .shape , list ):
79+ n_channels = output_spec .shape [channel_axis ]
80+ else :
81+ scale = output_spec .shape .scale [channel_axis ]
82+ offset = output_spec .shape .offset [channel_axis ]
83+ in_channels = 1
84+ n_channels = int (2 * offset + scale * in_channels )
85+
7086 # write the padded image
7187 image = image [3 :- 2 , 1 :- 12 ]
7288 in_path = tmp_path / "in.tif"
7389 out_path = tmp_path / "out.tif"
7490 imageio .imwrite (in_path , image )
7591
7692 def check_result ():
77- assert out_path .exists ()
78- res = imageio .imread (out_path )
79- assert res .shape == image .shape
93+ if n_channels == 1 :
94+ assert out_path .exists ()
95+ res = imageio .imread (out_path )
96+ assert res .shape == image .shape
97+ else :
98+ path = str (out_path )
99+ for c in range (n_channels ):
100+ channel_out_path = Path (path .replace (".tif" , f"-c{ c } .tif" ))
101+ assert channel_out_path .exists ()
102+ res = imageio .imread (channel_out_path )
103+ assert res .shape == image .shape
80104
81105 # test with dynamic padding
82- predict_image (any_model , in_path , out_path , padding = {"x" : 8 , "y" : 8 , "mode" : "dynamic" })
106+ predict_image (model , in_path , out_path , padding = {"x" : 16 , "y" : 16 , "mode" : "dynamic" })
83107 check_result ()
84108
85109 # test with fixed padding
86110 predict_image (
87- any_model , in_path , out_path , padding = {"x" : original_shape [0 ], "y" : original_shape [1 ], "mode" : "fixed" }
111+ model , in_path , out_path , padding = {"x" : original_shape [0 ], "y" : original_shape [1 ], "mode" : "fixed" }
88112 )
89113 check_result ()
90114
91115 # test with automated padding
92- predict_image (any_model , in_path , out_path , padding = True )
116+ predict_image (model , in_path , out_path , padding = True )
93117 check_result ()
94118
95119
96- def test_predict_image_with_tiling (unet2d_nuclei_broad_model , tmp_path ):
120+ # prediction with padding with the parameters above may not be suited for any model
121+ # so we only run it for the pytorch unet2d here
122+ def test_predict_image_with_padding (unet2d_fixed_shape_or_not , tmp_path ):
123+ _test_predict_with_padding (unet2d_fixed_shape_or_not , tmp_path )
124+
125+
126+ def test_predict_image_with_padding_channel_last (stardist , tmp_path ):
127+ _test_predict_with_padding (stardist , tmp_path )
128+
129+
130+ def _test_predict_image_with_tiling (model , tmp_path ):
97131 from bioimageio .core .prediction import predict_image
98132
99- spec = load_resource_description (unet2d_nuclei_broad_model )
133+ spec = load_resource_description (model )
100134 assert isinstance (spec , Model )
101135 inputs = spec .test_inputs
102136 assert len (inputs ) == 1
@@ -114,14 +148,24 @@ def check_result():
114148
115149 # with tiling config
116150 tiling = {"halo" : {"x" : 32 , "y" : 32 }, "tile" : {"x" : 256 , "y" : 256 }}
117- predict_image (unet2d_nuclei_broad_model , inputs , [out_path ], tiling = tiling )
151+ predict_image (model , inputs , [out_path ], tiling = tiling )
118152 check_result ()
119153
120154 # with tiling determined from spec
121- predict_image (unet2d_nuclei_broad_model , inputs , [out_path ], tiling = True )
155+ predict_image (model , inputs , [out_path ], tiling = True )
122156 check_result ()
123157
124158
159+ # prediction with tiling with the parameters above may not be suited for any model
160+ # so we only run it for the pytorch unet2d here
161+ def test_predict_image_with_tiling (unet2d_nuclei_broad_model , tmp_path ):
162+ _test_predict_image_with_tiling (unet2d_nuclei_broad_model , tmp_path )
163+
164+
165+ def test_predict_image_with_tiling_channel_last (stardist , tmp_path ):
166+ _test_predict_image_with_tiling (stardist , tmp_path )
167+
168+
125169def test_predict_images (unet2d_nuclei_broad_model , tmp_path ):
126170 from bioimageio .core .prediction import predict_images
127171
0 commit comments