Skip to content

Commit 4dccc01

Browse files
committed
use test_resource fully in test_model
1 parent 06ea47c commit 4dccc01

File tree

1 file changed

+19
-32
lines changed

1 file changed

+19
-32
lines changed

bioimageio/core/resource_tests.py

Lines changed: 19 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -34,38 +34,10 @@ def test_model(
3434
devices: Optional[List[str]] = None,
3535
decimal: int = 4,
3636
) -> List[TestSummary]:
37-
"""Test whether the test output(s) of a model can be reproduced.
38-
39-
Returns: summary dict with keys: name, status, error, traceback, bioimageio_spec_version, bioimageio_core_version
40-
"""
41-
# todo: reuse more of 'test_resource'
42-
tb = None
43-
try:
44-
model = load_resource_description(
45-
model_rdf, weights_priority_order=None if weight_format is None else [weight_format]
46-
)
47-
except Exception as e:
48-
model = None
49-
error = str(e)
50-
tb = traceback.format_tb(e.__traceback__)
51-
else:
52-
error = None
53-
54-
if isinstance(model, Model):
55-
return test_resource(model, weight_format=weight_format, devices=devices, decimal=decimal)
56-
else:
57-
error = error or f"Expected RDF type Model, got {type(model)} instead."
58-
59-
return [
60-
dict(
61-
name="reproduced test outputs from test inputs",
62-
status="failed",
63-
error=error,
64-
traceback=tb,
65-
bioimageio_spec_version=bioimageio_spec_version,
66-
bioimageio_core_version=bioimageio_core_version,
67-
)
68-
]
37+
"""Test whether the test output(s) of a model can be reproduced."""
38+
return test_resource(
39+
model_rdf, weight_format=weight_format, devices=devices, decimal=decimal, expected_type="model"
40+
)
6941

7042

7143
def _validate_input_shape(shape: Tuple[int, ...], shape_spec) -> bool:
@@ -248,12 +220,24 @@ def _test_load_resource(
248220
return rd, load_summary
249221

250222

223+
def _test_expected_resource_type(rd: ResourceDescription, expected_type: str) -> TestSummary:
224+
has_expected_type = rd.type == expected_type
225+
return dict(
226+
name="has expected resource type",
227+
status="passed" if has_expected_type else "failed",
228+
error=f"expected type {expected_type}, found {rd.type}",
229+
traceback=None,
230+
source_name=rd.id if hasattr(rd, "id") else rd.name,
231+
)
232+
233+
251234
def test_resource(
252235
rdf: Union[RawResourceDescription, ResourceDescription, URI, Path, str],
253236
*,
254237
weight_format: Optional[WeightsFormat] = None,
255238
devices: Optional[List[str]] = None,
256239
decimal: int = 4,
240+
expected_type: Optional[str] = None,
257241
) -> List[TestSummary]:
258242
"""Test RDF dynamically
259243
@@ -262,6 +246,9 @@ def test_resource(
262246
rd, load_test = _test_load_resource(rdf, weight_format)
263247
tests: List[TestSummary] = [load_test]
264248
if rd is not None:
249+
if expected_type is not None:
250+
tests.append(_test_expected_resource_type(rd, expected_type))
251+
265252
tests.append(_test_resource_urls(rd))
266253

267254
if isinstance(rd, Model):

0 commit comments

Comments
 (0)