Skip to content

Commit a96a346

Browse files
Merge pull request #130 from bioimage-io/add_test_resource
Add test_resource
2 parents 7c94582 + b2efff4 commit a96a346

File tree

7 files changed

+150
-50
lines changed

7 files changed

+150
-50
lines changed

bioimageio/core/__main__.py

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import enum
12
import json
23
import os
34
from glob import glob
@@ -7,8 +8,14 @@
78

89
import typer
910

10-
from bioimageio.core import __version__, prediction, commands
11+
from bioimageio.core import __version__, prediction, commands, resource_tests
1112
from bioimageio.spec.__main__ import app
13+
from bioimageio.spec.model.raw_nodes import WeightsFormat
14+
15+
try:
16+
from typing import get_args
17+
except ImportError:
18+
from typing_extensions import get_args # type: ignore
1219

1320
try:
1421
from bioimageio.core.weight_converter import torch as torch_converter
@@ -40,28 +47,57 @@ def package(
4047

4148
# if we want to use something like "choice" for the weight formats, we need to use an enum, see:
4249
# https://github.com/tiangolo/typer/issues/182
50+
WeightFormatEnum = enum.Enum("WeightFormatEnum", get_args(WeightsFormat))
51+
52+
4353
@app.command()
4454
def test_model(
4555
model_rdf: str = typer.Argument(
4656
..., help="Path or URL to the model resource description file (rdf.yaml) or zipped model."
4757
),
48-
weight_format: Optional[str] = typer.Argument(None, help="The weight format to use."),
58+
weight_format: Optional[WeightFormatEnum] = typer.Argument(None, help="The weight format to use."),
4959
devices: Optional[List[str]] = typer.Argument(None, help="Devices for running the model."),
5060
decimal: int = typer.Argument(4, help="The test precision."),
5161
) -> int:
5262
# this is a weird typer bug: default devices are empty tuple although they should be None
5363
if len(devices) == 0:
5464
devices = None
55-
test_passed = prediction.test_model(model_rdf, weight_format=weight_format, devices=devices, decimal=decimal)
56-
if test_passed:
65+
summary = resource_tests.test_model(model_rdf, weight_format=weight_format, devices=devices, decimal=decimal)
66+
if summary["error"] is None:
5767
print(f"Model test for {model_rdf} has passed.")
68+
return 0
5869
else:
5970
print(f"Model test for {model_rdf} has FAILED!")
60-
ret_code = 0 if test_passed else 1
61-
return ret_code
71+
print(summary)
72+
return 1
73+
74+
75+
test_model.__doc__ = resource_tests.test_model.__doc__
76+
77+
78+
@app.command()
79+
def test_resource(
80+
rdf: str = typer.Argument(
81+
..., help="Path or URL to the resource description file (rdf.yaml) or zipped resource package."
82+
),
83+
weight_format: Optional[WeightFormatEnum] = typer.Argument(None, help="(for model only) The weight format to use."),
84+
devices: Optional[List[str]] = typer.Argument(None, help="(for model only) Devices for running the model."),
85+
decimal: int = typer.Argument(4, help="(for model only) The test precision."),
86+
) -> int:
87+
# this is a weird typer bug: default devices are empty tuple although they should be None
88+
if len(devices) == 0:
89+
devices = None
90+
summary = resource_tests.test_resource(rdf, weight_format=weight_format, devices=devices, decimal=decimal)
91+
if summary["error"] is None:
92+
print(f"Resource test for {rdf} has passed.")
93+
return 0
94+
else:
95+
print(f"Resource test for {rdf} has FAILED!")
96+
print(summary)
97+
return 1
6298

6399

64-
test_model.__doc__ = prediction.test_model.__doc__
100+
test_resource.__doc__ = resource_tests.test_resource.__doc__
65101

66102

67103
@app.command()

bioimageio/core/prediction.py

Lines changed: 2 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import collections
22
import os
3-
import warnings
43
from copy import deepcopy
54
from itertools import product
65
from pathlib import Path
@@ -9,19 +8,16 @@
98
import imageio
109
import numpy as np
1110
import xarray as xr
11+
from tqdm import tqdm
1212

1313
from bioimageio.core import load_resource_description
14-
from bioimageio.core.resource_io.nodes import InputTensor, Model, OutputTensor
1514
from bioimageio.core.prediction_pipeline import PredictionPipeline, create_prediction_pipeline
16-
from tqdm import tqdm
15+
from bioimageio.core.resource_io.nodes import ImplicitOutputShape, InputTensor, Model, OutputTensor
1716

1817

1918
#
2019
# utility functions for prediction
2120
#
22-
from bioimageio.core.resource_io.nodes import ImplicitOutputShape, URI
23-
24-
2521
def require_axes(im, axes):
2622
is_volume = "z" in axes
2723
# we assume images / volumes are loaded as one of
@@ -474,32 +470,3 @@ def predict_images(
474470
outp = [outp]
475471

476472
_predict_sample(prediction_pipeline, inp, outp, padding, tiling)
477-
478-
479-
def test_model(model_rdf: Union[URI, Path, str], weight_format=None, devices=None, decimal=4):
480-
"""Test whether the test output(s) of a model can be reproduced.
481-
482-
Returns True if the test passes, otherwise returns False and issues a warning.
483-
"""
484-
model = load_resource_description(model_rdf)
485-
assert isinstance(model, Model)
486-
prediction_pipeline = create_prediction_pipeline(
487-
bioimageio_model=model, devices=devices, weight_format=weight_format
488-
)
489-
inputs = [np.load(str(in_path)) for in_path in model.test_inputs]
490-
results = predict(prediction_pipeline, inputs)
491-
if isinstance(results, (np.ndarray, xr.DataArray)):
492-
results = [results]
493-
494-
expected = [np.load(str(out_path)) for out_path in model.test_outputs]
495-
if len(results) != len(expected):
496-
warnings.warn(f"Number of outputs and number of expected outputs disagree: {len(results)} != {len(expected)}")
497-
return False
498-
499-
for res, exp in zip(results, expected):
500-
try:
501-
np.testing.assert_array_almost_equal(res, exp, decimal=decimal)
502-
except AssertionError as e:
503-
warnings.warn(f"Output and expected output disagree:\n {e}")
504-
return False
505-
return True

bioimageio/core/prediction_pipeline/_model_adapters/_model_adapter.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,9 @@ def create_model_adapter(
5858
adapter_cls = _get_model_adapter(weight)
5959
return adapter_cls(bioimageio_model=bioimageio_model, devices=devices)
6060

61-
raise NotImplementedError(f"No supported weight_formats in {spec.weights.keys()}")
61+
raise RuntimeError(
62+
f"weight format {weight_format} not among weight formats listed in model: {list(spec.weights.keys())}"
63+
)
6264

6365

6466
def _get_model_adapter(weight_format: str) -> Type[ModelAdapter]:

bioimageio/core/resource_io/io_.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,10 @@ def extract_resource_package(
4040
if (package_path / "rdf.yaml").exists():
4141
download = None
4242
else:
43-
download, header = urlretrieve(str(root))
43+
try:
44+
download, header = urlretrieve(str(root))
45+
except Exception as e:
46+
raise RuntimeError(f"Failed to download {str(root)} ({e})")
4447

4548
local_source = download
4649
else:
@@ -91,7 +94,7 @@ def _replace_relative_paths_for_remote_source(
9194

9295

9396
def load_raw_resource_description(
94-
source: Union[dict, os.PathLike, IO, str, bytes, raw_nodes.URI]
97+
source: Union[dict, os.PathLike, IO, str, bytes, raw_nodes.URI, RawResourceDescription]
9598
) -> RawResourceDescription:
9699
"""load a raw python representation from a BioImage.IO resource description file (RDF).
97100
Use `load_resource_description` for a more convenient representation.
@@ -102,13 +105,16 @@ def load_raw_resource_description(
102105
Returns:
103106
raw BioImage.IO resource
104107
"""
108+
if isinstance(source, RawResourceDescription):
109+
return source
110+
105111
raw_rd = spec.load_raw_resource_description(source, update_to_current_format=True)
106112
raw_rd = _replace_relative_paths_for_remote_source(raw_rd, raw_rd.root_path)
107113
return raw_rd
108114

109115

110116
def load_resource_description(
111-
source: Union[RawResourceDescription, os.PathLike, str, dict, raw_nodes.URI],
117+
source: Union[RawResourceDescription, ResourceDescription, os.PathLike, str, dict, raw_nodes.URI],
112118
*,
113119
weights_priority_order: Optional[Sequence[str]] = None, # model only
114120
) -> ResourceDescription:
@@ -123,6 +129,9 @@ def load_resource_description(
123129
BioImage.IO resource
124130
"""
125131
source = deepcopy(source)
132+
if isinstance(source, ResourceDescription):
133+
return source
134+
126135
raw_rd = load_raw_resource_description(source)
127136

128137
if weights_priority_order is not None:

bioimageio/core/resource_io/utils.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -276,9 +276,8 @@ def _download_uri_to_local_path(uri: raw_nodes.URI) -> pathlib.Path:
276276
local_path.parent.mkdir(parents=True, exist_ok=True)
277277
try:
278278
urlretrieve(str(uri), str(local_path))
279-
except Exception:
280-
logging.getLogger("download").error("Failed to download %s", uri)
281-
raise
279+
except Exception as e:
280+
raise RuntimeError(f"Failed to download {uri} ({e})")
282281

283282
return local_path
284283

bioimageio/core/resource_tests.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import traceback
2+
import warnings
3+
from pathlib import Path
4+
from typing import List, Optional, Union
5+
6+
import numpy as np
7+
import xarray as xr
8+
9+
from bioimageio.core import load_resource_description
10+
from bioimageio.core.prediction import predict
11+
from bioimageio.core.prediction_pipeline import create_prediction_pipeline
12+
from bioimageio.core.resource_io.nodes import Model, ResourceDescription, URI
13+
from bioimageio.spec.model.raw_nodes import WeightsFormat
14+
from bioimageio.spec.shared.raw_nodes import ResourceDescription as RawResourceDescription
15+
16+
17+
def test_model(
18+
model_rdf: Union[URI, Path, str],
19+
weight_format: Optional[WeightsFormat] = None,
20+
devices: Optional[List[str]] = None,
21+
decimal: int = 4,
22+
) -> dict:
23+
"""Test whether the test output(s) of a model can be reproduced.
24+
25+
Returns summary dict with "error" and "traceback" key; summary["error"] is None if no errors were encountered.
26+
"""
27+
model = load_resource_description(model_rdf)
28+
if isinstance(model, Model):
29+
return test_resource(model, weight_format=weight_format, devices=devices, decimal=decimal)
30+
else:
31+
return {"error": f"Expected RDF type Model, got {type(model)} instead.", "traceback": None}
32+
33+
34+
def test_resource(
35+
model_rdf: Union[RawResourceDescription, ResourceDescription, URI, Path, str],
36+
*,
37+
weight_format: Optional[WeightsFormat] = None,
38+
devices: Optional[List[str]] = None,
39+
decimal: int = 4,
40+
):
41+
"""Test RDF dynamically
42+
43+
Returns summary dict with "error" and "traceback" key; summary["error"] is None if no errors were encountered.
44+
"""
45+
error: Optional[str] = None
46+
tb: Optional = None
47+
48+
try:
49+
model = load_resource_description(model_rdf)
50+
except Exception as e:
51+
error = str(e)
52+
tb = traceback.format_tb(e.__traceback__)
53+
else:
54+
if isinstance(model, Model):
55+
try:
56+
prediction_pipeline = create_prediction_pipeline(
57+
bioimageio_model=model, devices=devices, weight_format=weight_format
58+
)
59+
inputs = [np.load(str(in_path)) for in_path in model.test_inputs]
60+
results = predict(prediction_pipeline, inputs)
61+
if isinstance(results, (np.ndarray, xr.DataArray)):
62+
results = [results]
63+
64+
expected = [np.load(str(out_path)) for out_path in model.test_outputs]
65+
if len(results) != len(expected):
66+
error = (
67+
f"Number of outputs and number of expected outputs disagree: {len(results)} != {len(expected)}"
68+
)
69+
else:
70+
for res, exp in zip(results, expected):
71+
try:
72+
np.testing.assert_array_almost_equal(res, exp, decimal=decimal)
73+
except AssertionError as e:
74+
error = f"Output and expected output disagree:\n {e}"
75+
except Exception as e:
76+
error = str(e)
77+
tb = traceback.format_tb(e.__traceback__)
78+
79+
# todo: add tests for non-model resources
80+
81+
return {"error": error, "traceback": tb}

tests/test_prediction.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,17 @@
99

1010

1111
def test_test_model(unet2d_nuclei_broad_model):
12-
from bioimageio.core.prediction import test_model
12+
from bioimageio.core.resource_tests import test_model
1313

1414
assert test_model(unet2d_nuclei_broad_model)
1515

1616

17+
def test_test_resource(unet2d_nuclei_broad_model):
18+
from bioimageio.core.resource_tests import test_resource
19+
20+
assert test_resource(unet2d_nuclei_broad_model)
21+
22+
1723
def test_predict_image(unet2d_fixed_shape_or_not, tmpdir):
1824
any_model = unet2d_fixed_shape_or_not # todo: replace 'unet2d_fixed_shape_or_not' with 'any_model'
1925
from bioimageio.core.prediction import predict_image

0 commit comments

Comments
 (0)