@@ -61,7 +61,17 @@ def _test_predict_with_padding(model, tmp_path):
6161
6262 spec = load_resource_description (model )
6363 assert isinstance (spec , Model )
64- image = np .load (str (spec .test_inputs [0 ]))[0 , 0 ]
64+
65+ input_spec , output_spec = spec .inputs [0 ], spec .outputs [0 ]
66+ channel_axis = input_spec .axes .index ("c" )
67+ channel_first = channel_axis == 1
68+ assert output_spec .shape .scale [channel_axis ] == 0
69+ n_channels = int (2 * output_spec .shape .offset [channel_axis ])
70+
71+ if channel_first :
72+ image = np .load (str (spec .test_inputs [0 ]))[0 , 0 ]
73+ else :
74+ image = np .load (str (spec .test_inputs [0 ]))[0 , ..., 0 ]
6575 original_shape = image .shape
6676 assert image .ndim == 2
6777
@@ -72,12 +82,20 @@ def _test_predict_with_padding(model, tmp_path):
7282 imageio .imwrite (in_path , image )
7383
7484 def check_result ():
75- assert out_path .exists ()
76- res = imageio .imread (out_path )
77- assert res .shape == image .shape
85+ if n_channels == 1 :
86+ assert out_path .exists ()
87+ res = imageio .imread (out_path )
88+ assert res .shape == image .shape
89+ else :
90+ path = str (out_path )
91+ for c in range (n_channels ):
92+ channel_out_path = Path (path .replace (".tif" , f"-c{ c } .tif" ))
93+ assert channel_out_path .exists ()
94+ res = imageio .imread (channel_out_path )
95+ assert res .shape == image .shape
7896
7997 # test with dynamic padding
80- predict_image (model , in_path , out_path , padding = {"x" : 8 , "y" : 8 , "mode" : "dynamic" })
98+ predict_image (model , in_path , out_path , padding = {"x" : 16 , "y" : 16 , "mode" : "dynamic" })
8199 check_result ()
82100
83101 # test with fixed padding
0 commit comments