Skip to content

Commit 5926bab

Browse files
Add another test to ensure wrong shapes are caught by test_model
1 parent 8c78c56 commit 5926bab

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

tests/conftest.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@
4242
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/"
4343
"stardist_example_model/rdf_wrong_shape.yaml"
4444
),
45+
"stardist_wrong_shape2": (
46+
"https://raw.githubusercontent.com/bioimage-io/spec-bioimage-io/main/example_specs/models/"
47+
"stardist_example_model/rdf_wrong_shape2.yaml"
48+
),
4549
}
4650

4751
try:
@@ -94,6 +98,7 @@
9498
if tf_major_version == 1:
9599
load_model_packages |= set(tensorflow1_models)
96100
load_model_packages.add("stardist_wrong_shape")
101+
load_model_packages.add("stardist_wrong_shape2")
97102
elif tf_major_version == 2:
98103
load_model_packages |= set(tensorflow2_models)
99104

@@ -124,6 +129,12 @@ def stardist_wrong_shape(request):
124129
return pytest.model_packages[request.param]
125130

126131

132+
# written as model group to automatically skip on missing tensorflow 1
133+
@pytest.fixture(params=[] if skip_tensorflow or tf_major_version != 1 else ["stardist_wrong_shape2"])
134+
def stardist_wrong_shape2(request):
135+
return pytest.model_packages[request.param]
136+
137+
127138
# written as model group to automatically skip on missing tensorflow 1
128139
@pytest.fixture(params=[] if skip_tensorflow or tf_major_version != 1 else ["stardist"])
129140
def stardist(request):
@@ -165,7 +176,7 @@ def any_tensorflow_js_model(request):
165176
# fixture to test with all models that should run in the current environment
166177
# we exclude stardist_wrong_shape here because it is not a valid model
167178
# and included only to test that validation for this model fails
168-
@pytest.fixture(params=load_model_packages - {"stardist_wrong_shape"})
179+
@pytest.fixture(params=load_model_packages - {"stardist_wrong_shape", "stardist_wrong_shape2"})
169180
def any_model(request):
170181
return pytest.model_packages[request.param]
171182

tests/test_resource_tests/test_test_model.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,17 @@ def test_error_for_wrong_shape(stardist_wrong_shape):
1010
assert summary["error"] == expected_error_message
1111

1212

13+
def test_error_for_wrong_shape2(stardist_wrong_shape2):
14+
from bioimageio.core.resource_tests import test_model
15+
16+
summary = test_model(stardist_wrong_shape2)
17+
expected_error_message = (
18+
"Shape (1, 512, 512, 1) of test input 0 'input' does not match input shape description: "
19+
"ParametrizedInputShape(min=[1, 16, 16, 1], step=[0, 17, 17, 0])."
20+
)
21+
assert summary["error"] == expected_error_message
22+
23+
1324
def test_test_model(any_model):
1425
from bioimageio.core.resource_tests import test_model
1526

0 commit comments

Comments
 (0)