
Commit 30ce61c

Merge branch 'main' into unload
2 parents 43687ef + 29e28d7

3 files changed: +126 additions, -18 deletions

bioimageio/core/resource_tests.py

Lines changed: 82 additions & 13 deletions
@@ -1,14 +1,22 @@
 import traceback
 from pathlib import Path
-from typing import List, Optional, Union
+from typing import List, Optional, Tuple, Union
 
+import numpy
 import numpy as np
 import xarray as xr
+from marshmallow import ValidationError
 
 from bioimageio.core import load_resource_description
 from bioimageio.core.prediction import predict
 from bioimageio.core.prediction_pipeline import create_prediction_pipeline
-from bioimageio.core.resource_io.nodes import Model, ResourceDescription, URI
+from bioimageio.core.resource_io.nodes import (
+    ImplicitOutputShape,
+    Model,
+    ParametrizedInputShape,
+    ResourceDescription,
+    URI,
+)
 from bioimageio.spec.model.raw_nodes import WeightsFormat
 from bioimageio.spec.shared.raw_nodes import ResourceDescription as RawResourceDescription

@@ -45,32 +53,93 @@ def test_resource(
     tb: Optional = None
 
     try:
-        model = load_resource_description(model_rdf)
+        rd = load_resource_description(model_rdf)
     except Exception as e:
         error = str(e)
         tb = traceback.format_tb(e.__traceback__)
     else:
-        if isinstance(model, Model):
+        if isinstance(rd, Model):
+            model = rd
             try:
-                prediction_pipeline = create_prediction_pipeline(
-                    bioimageio_model=model, devices=devices, weight_format=weight_format
-                )
                 inputs = [np.load(str(in_path)) for in_path in model.test_inputs]
-                results = predict(prediction_pipeline, inputs)
-                if isinstance(results, (np.ndarray, xr.DataArray)):
-                    results = [results]
-
                 expected = [np.load(str(out_path)) for out_path in model.test_outputs]
+
+                # check if test data shapes match their description
+                input_shapes = [ipt.shape for ipt in inputs]
+                output_shapes = [out.shape for out in expected]
+
+                def input_shape_is_valid(shape: Tuple[int, ...], shape_spec) -> bool:
+                    if isinstance(shape_spec, list):
+                        if shape != tuple(shape_spec):
+                            return False
+                    elif isinstance(shape_spec, ParametrizedInputShape):
+                        assert len(shape_spec.min) == len(shape_spec.step)
+                        if len(shape) != len(shape_spec.min):
+                            return False
+
+                        valid_shape = numpy.array(shape_spec.min)
+                        step = numpy.array(shape_spec.step)
+                        if (step == 0).all():
+                            return shape == tuple(valid_shape)
+
+                        shape = numpy.array(shape)
+                        while (shape <= valid_shape).all():
+                            if (shape == valid_shape).all():
+                                break
+
+                            shape += step
+                        else:
+                            return False
+
+                    else:
+                        raise TypeError(f"Encountered unexpected shape description of type {type(shape_spec)}")
+
+                    return True
+
+                assert len(inputs) == len(model.inputs)  # should be checked by validation
+                input_shapes = {}
+                for idx, (ipt, ipt_spec) in enumerate(zip(inputs, model.inputs)):
+                    if not input_shape_is_valid(ipt, ipt_spec.shape):
+                        raise ValidationError(
+                            f"Shape of test input {idx} '{ipt_spec.name}' does not match "
+                            f"input shape description: {ipt_spec.shape}"
+                        )
+                    input_shapes[ipt_spec.name] = ipt.shape
+
+                def output_shape_is_valid(shape: Tuple[int, ...], shape_spec) -> bool:
+                    if isinstance(shape_spec, list):
+                        return shape == tuple(shape_spec)
+                    elif isinstance(shape_spec, ImplicitOutputShape):
+                        ipt_shape = numpy.array(input_shapes[shape_spec.reference_tensor])
+                        scale = numpy.array(shape_spec.scale)
+                        offset = numpy.array(shape_spec.offset)
+                        exp_shape = numpy.round_(ipt_shape * scale) + 2 * offset
+
+                        return shape == tuple(exp_shape)
+
+                assert len(expected) == len(model.outputs)  # should be checked by validation
+                for idx, (out, out_spec) in enumerate(zip(expected, model.outputs)):
+                    if not output_shape_is_valid(out, out_spec.shape):
+                        error = (error or "") + (
+                            f"Shape of test output {idx} '{out_spec.name}' does not match "
+                            f"output shape description: {out_spec.shape}.\n"
+                        )
+
+                with create_prediction_pipeline(
+                    bioimageio_model=model, devices=devices, weight_format=weight_format
+                ) as prediction_pipeline:
+                    results = predict(prediction_pipeline, inputs)
+
                 if len(results) != len(expected):
-                    error = (
+                    error = (error or "") + (
                         f"Number of outputs and number of expected outputs disagree: {len(results)} != {len(expected)}"
                     )
                 else:
                     for res, exp in zip(results, expected):
                         try:
                             np.testing.assert_array_almost_equal(res, exp, decimal=decimal)
                         except AssertionError as e:
-                            error = f"Output and expected output disagree:\n {e}"
+                            error = (error or "") + f"Output and expected output disagree:\n {e}"
             except Exception as e:
                 error = str(e)
                 tb = traceback.format_tb(e.__traceback__)
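
Note: the ImplicitOutputShape check added above computes the expected output shape as round(input_shape * scale) + 2 * offset, using the shape of the referenced input tensor. A minimal standalone sketch of that rule, with hypothetical shape, scale, and offset values chosen only for illustration (not taken from any of the test models):

import numpy

# Hypothetical values; the rule mirrors the ImplicitOutputShape branch above:
# expected = round(input * scale) + 2 * offset, applied element-wise per axis.
input_shape = numpy.array([1, 256, 256, 1])
scale = numpy.array([1.0, 0.5, 0.5, 1.0])
offset = numpy.array([0, 8, 8, 0])

expected_output_shape = tuple(int(s) for s in numpy.round_(input_shape * scale) + 2 * offset)
print(expected_output_shape)  # (1, 144, 144, 1)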

tests/conftest.py

Lines changed: 36 additions & 5 deletions
@@ -1,11 +1,17 @@
+import logging
+
 import pytest
+
 from bioimageio.core import export_resource_package
 
+logger = logging.getLogger(__name__)
+
 # test models for various frameworks
-torch_models = ["unet2d_fixed_shape", "unet2d_multi_tensor", "unet2d_nuclei_broad_model"]
+torch_models = []
+torch_models_pre_3_10 = ["unet2d_fixed_shape", "unet2d_multi_tensor", "unet2d_nuclei_broad_model"]
 torchscript_models = ["unet2d_multi_tensor", "unet2d_nuclei_broad_model"]
 onnx_models = ["unet2d_multi_tensor", "unet2d_nuclei_broad_model", "hpa_densenet"]
-tensorflow1_models = ["FruNet_model"]
+tensorflow1_models = ["FruNet_model", "stardist"]
 tensorflow2_models = []
 keras_models = ["FruNet_model"]
 tensorflow_js_models = ["FruNet_model"]
@@ -28,12 +34,23 @@
     "hpa_densenet": (
         "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/hpa-densenet/rdf.yaml"
     ),
+    "stardist": (
+        "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/stardist_example_model/rdf.yaml"
+    ),
+    "stardist_wrong_shape": (
+        "https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/stardist_example_model/rdf_wrong_shape.yaml"
+    ),
 }
 
 try:
     import torch
+
+    torch_version = tuple(map(int, torch.__version__.split(".")[:2]))
+    logger.warning(f"detected torch version {torch_version}.x")
 except ImportError:
     torch = None
+    torch_version = None
+
 skip_torch = torch is None
 
 try:
@@ -64,6 +81,9 @@
 # load all model packages we need for testing
 load_model_packages = set()
 if not skip_torch:
+    if torch_version < (3, 10):
+        torch_models += torch_models_pre_3_10
+
     load_model_packages |= set(torch_models + torchscript_models)
 
 if not skip_onnx:
@@ -74,6 +94,7 @@
     load_model_packages |= set(tensorflow_js_models)
     if tf_major_version == 1:
         load_model_packages |= set(tensorflow1_models)
+        load_model_packages.add("stardist_wrong_shape")
     elif tf_major_version == 2:
         load_model_packages |= set(tensorflow2_models)
 
@@ -93,7 +114,7 @@ def pytest_configure():
 #
 
 # written as model group to automatically skip on missing torch
-@pytest.fixture(params=[] if skip_torch else ["unet2d_nuclei_broad_model"])
+@pytest.fixture(params=[] if skip_torch or torch_version >= (3, 10) else ["unet2d_nuclei_broad_model"])
 def unet2d_nuclei_broad_model(request):
     return pytest.model_packages[request.param]
 
@@ -104,6 +125,12 @@ def FruNet_model(request):
     return pytest.model_packages[request.param]
 
 
+# written as model group to automatically skip on missing tensorflow 1
+@pytest.fixture(params=[] if skip_tensorflow or tf_major_version != 1 else ["stardist_wrong_shape"])
+def stardist_wrong_shape(request):
+    return pytest.model_packages[request.param]
+
+
 @pytest.fixture(params=[] if skip_torch else torch_models)
 def any_torch_model(request):
     return pytest.model_packages[request.param]
@@ -146,11 +173,15 @@ def any_model(request):
 # temporary fixtures to test not with all, but only a manual selection of models
 # (models/functionality should be improved to get rid of this specific model group)
 #
-@pytest.fixture(params=[] if skip_torch else ["unet2d_nuclei_broad_model", "unet2d_fixed_shape"])
+@pytest.fixture(
+    params=[] if skip_torch or torch_version >= (3, 10) else ["unet2d_nuclei_broad_model", "unet2d_fixed_shape"]
+)
 def unet2d_fixed_shape_or_not(request):
     return pytest.model_packages[request.param]
 
 
-@pytest.fixture(params=[] if skip_torch else ["unet2d_nuclei_broad_model", "unet2d_multi_tensor"])
+@pytest.fixture(
+    params=[] if skip_torch or torch_version >= (3, 10) else ["unet2d_nuclei_broad_model", "unet2d_multi_tensor"]
+)
 def unet2d_multi_tensor_or_not(request):
     return pytest.model_packages[request.param]
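
Note: the fixtures above rely on pytest's parametrized-fixture behavior; with an empty params list, any test requesting the fixture is reported as skipped ("got empty parameter set"), which is what the "written as model group to automatically skip" comments refer to. A minimal sketch of that pattern, where the flag and names are illustrative and not taken from conftest.py:

import pytest

# Hypothetical flag standing in for e.g. "not skip_torch" in conftest.py.
dependency_available = False


# With an empty params list, pytest reports tests using this fixture as
# skipped ("got empty parameter set") instead of failing on a missing backend.
@pytest.fixture(params=["some_model"] if dependency_available else [])
def gated_model(request):
    return request.param


def test_uses_gated_model(gated_model):
    # only runs when dependency_available is True
    assert gated_model == "some_model"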
Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
+def test_error_for_wrong_shape(stardist_wrong_shape):
+    from bioimageio.core.resource_tests import test_model
+
+    summary = test_model(stardist_wrong_shape)
+    assert (
+        summary["error"]
+        == "Shape of test input 0 'input' does not match input shape description: ParametrizedInputShape(min=[1, 16, 16, 1], step=[0, 16, 16, 0])"
+    )
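
Note: as the new test suggests, the summary returned by test_model exposes the failure message under the "error" key. A rough usage sketch, where the RDF path is a placeholder and the success/failure handling is illustrative:

from bioimageio.core.resource_tests import test_model

summary = test_model("path/to/model/rdf.yaml")  # placeholder path
if summary["error"]:
    print("model test failed:", summary["error"])
else:
    print("model test passed")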
