Skip to content

Commit ddc2f6d

Browse files
Fix test_predict_image_with_padding_channel_last
1 parent 7f51b19 commit ddc2f6d

File tree

1 file changed

+23
-5
lines changed

1 file changed

+23
-5
lines changed

tests/test_prediction.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,17 @@ def _test_predict_with_padding(model, tmp_path):
6161

6262
spec = load_resource_description(model)
6363
assert isinstance(spec, Model)
64-
image = np.load(str(spec.test_inputs[0]))[0, 0]
64+
65+
input_spec, output_spec = spec.inputs[0], spec.outputs[0]
66+
channel_axis = input_spec.axes.index("c")
67+
channel_first = channel_axis == 1
68+
assert output_spec.shape.scale[channel_axis] == 0
69+
n_channels = int(2 * output_spec.shape.offset[channel_axis])
70+
71+
if channel_first:
72+
image = np.load(str(spec.test_inputs[0]))[0, 0]
73+
else:
74+
image = np.load(str(spec.test_inputs[0]))[0, ..., 0]
6575
original_shape = image.shape
6676
assert image.ndim == 2
6777

@@ -72,12 +82,20 @@ def _test_predict_with_padding(model, tmp_path):
7282
imageio.imwrite(in_path, image)
7383

7484
def check_result():
75-
assert out_path.exists()
76-
res = imageio.imread(out_path)
77-
assert res.shape == image.shape
85+
if n_channels == 1:
86+
assert out_path.exists()
87+
res = imageio.imread(out_path)
88+
assert res.shape == image.shape
89+
else:
90+
path = str(out_path)
91+
for c in range(n_channels):
92+
channel_out_path = Path(path.replace(".tif", f"-c{c}.tif"))
93+
assert channel_out_path.exists()
94+
res = imageio.imread(channel_out_path)
95+
assert res.shape == image.shape
7896

7997
# test with dynamic padding
80-
predict_image(model, in_path, out_path, padding={"x": 8, "y": 8, "mode": "dynamic"})
98+
predict_image(model, in_path, out_path, padding={"x": 16, "y": 16, "mode": "dynamic"})
8199
check_result()
82100

83101
# test with fixed padding

0 commit comments

Comments
 (0)