Skip to content

Commit 5d1e2ce

Browse files
committed
expose test_description_in_conda_env
1 parent 00e6ba1 commit 5d1e2ce

File tree

3 files changed

+172
-0
lines changed

3 files changed

+172
-0
lines changed

README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,11 @@ The model specification and its validation tools can be found at <https://github
375375

376376
## Changelog
377377

378+
### 0.7.1 (to be released)
379+
380+
- New test function `bioimageio.core.test_description_in_conda_env` that uses conda
381+
in subprocesses to test a resource in a dedicated conda environment.
382+
378383
### 0.7.0
379384

380385
- breaking:

bioimageio/core/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
stat_measures,
3333
tensor,
3434
)
35+
from ._dynamic_conda_env import test_description_in_conda_env
3536
from ._prediction_pipeline import PredictionPipeline, create_prediction_pipeline
3637
from ._resource_tests import (
3738
enable_determinism,
@@ -104,6 +105,7 @@
104105
"Stat",
105106
"tensor",
106107
"Tensor",
108+
"test_description_in_conda_env",
107109
"test_description",
108110
"test_model",
109111
"test_resource",
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
import subprocess
2+
from hashlib import sha256
3+
from io import StringIO
4+
from pathlib import Path
5+
from tempfile import TemporaryDirectory
6+
from typing import (
7+
Callable,
8+
List,
9+
Literal,
10+
Optional,
11+
Sequence,
12+
assert_never,
13+
)
14+
15+
from loguru import logger
16+
from typing_extensions import get_args
17+
18+
from bioimageio.spec import (
19+
BioimageioCondaEnv,
20+
ValidationSummary,
21+
get_conda_env,
22+
load_description,
23+
)
24+
from bioimageio.spec._internal.io import is_yaml_value
25+
from bioimageio.spec._internal.io_utils import write_yaml
26+
from bioimageio.spec.common import PermissiveFileSource
27+
from bioimageio.spec.model import v0_4, v0_5
28+
from bioimageio.spec.model.v0_5 import WeightsFormat
29+
30+
31+
def default_run_command(args: Sequence[str]):
32+
logger.info("running '{}'...", " ".join(args))
33+
_ = subprocess.run(args, shell=True, text=True, check=True)
34+
35+
36+
def test_description_in_conda_env(
37+
source: PermissiveFileSource,
38+
*,
39+
weight_format: Optional[WeightsFormat] = None,
40+
conda_env: Optional[BioimageioCondaEnv] = None,
41+
devices: Optional[List[str]] = None,
42+
absolute_tolerance: float = 1.5e-4,
43+
relative_tolerance: float = 1e-4,
44+
determinism: Literal["seed_only", "full"] = "seed_only",
45+
run_command: Callable[[Sequence[str]], None] = default_run_command,
46+
) -> ValidationSummary:
47+
"""Run test_model in a dedicated conda env
48+
49+
Args:
50+
source: Path or URL to model description.
51+
weight_format: Weight format to test.
52+
Default: All weight formats present in **source**.
53+
conda_env: conda environment including bioimageio.core dependency.
54+
Default: Use `bioimageio.spec.get_conda_env` to obtain a model weight
55+
specific conda environment.
56+
devices: Devices to test with, e.g. 'cpu', 'cuda'.
57+
Default (may be weight format dependent): ['cuda'] if available, ['cpu'] otherwise.
58+
absolute_tolerance: Maximum absolute tolerance of reproduced output tensors.
59+
relative_tolerance: Maximum relative tolerance of reproduced output tensors.
60+
determinism: Modes to improve reproducibility of test outputs.
61+
run_command: Function to execute terminal commands.
62+
"""
63+
64+
try:
65+
run_command(["which", "conda"])
66+
except Exception as e:
67+
raise RuntimeError("Conda not available") from e
68+
69+
descr = load_description(source)
70+
if not isinstance(descr, (v0_4.ModelDescr, v0_5.ModelDescr)):
71+
raise NotImplementedError("Not yet implemented for non-model resources")
72+
73+
if weight_format is None:
74+
all_present_wfs = [
75+
wf for wf in get_args(WeightsFormat) if getattr(descr.weights, wf)
76+
]
77+
ignore_wfs = [wf for wf in all_present_wfs if wf in ["tensorflow_js"]]
78+
logger.info(
79+
"Found weight formats {}. Start testing all{}...",
80+
all_present_wfs,
81+
f" (except: {', '.join(ignore_wfs)}) " if ignore_wfs else "",
82+
)
83+
summary = test_description_in_env(
84+
source,
85+
weight_format=all_present_wfs[0],
86+
devices=devices,
87+
absolute_tolerance=absolute_tolerance,
88+
relative_tolerance=relative_tolerance,
89+
determinism=determinism,
90+
)
91+
for wf in all_present_wfs[1:]:
92+
additional_summary = test_description_in_env(
93+
source,
94+
weight_format=all_present_wfs[0],
95+
devices=devices,
96+
absolute_tolerance=absolute_tolerance,
97+
relative_tolerance=relative_tolerance,
98+
determinism=determinism,
99+
)
100+
for d in additional_summary.details:
101+
# TODO: filter reduntant details; group details
102+
summary.add_detail(d)
103+
return summary
104+
105+
if weight_format == "pytorch_state_dict":
106+
wf = descr.weights.pytorch_state_dict
107+
elif weight_format == "torchscript":
108+
wf = descr.weights.torchscript
109+
elif weight_format == "keras_hdf5":
110+
wf = descr.weights.keras_hdf5
111+
elif weight_format == "onnx":
112+
wf = descr.weights.onnx
113+
elif weight_format == "tensorflow_saved_model_bundle":
114+
wf = descr.weights.tensorflow_saved_model_bundle
115+
elif weight_format == "tensorflow_js":
116+
raise RuntimeError(
117+
"testing 'tensorflow_js' is not supported by bioimageio.core"
118+
)
119+
else:
120+
assert_never(weight_format)
121+
122+
assert wf is not None
123+
if conda_env is None:
124+
conda_env = get_conda_env(entry=wf)
125+
126+
# remove name as we crate a name based on the env description hash value
127+
conda_env.name = None
128+
129+
dumped_env = conda_env.model_dump(mode="json", exclude_none=True)
130+
if not is_yaml_value(dumped_env):
131+
raise ValueError(f"Failed to dump conda env to valid YAML {conda_env}")
132+
133+
env_io = StringIO()
134+
write_yaml(dumped_env, file=env_io)
135+
encoded_env = env_io.getvalue().encode()
136+
env_name = sha256(encoded_env).hexdigest()
137+
138+
with TemporaryDirectory() as _d:
139+
folder = Path(_d)
140+
try:
141+
run_command(["conda", "activate", env_name])
142+
except Exception:
143+
path = folder / "env.yaml"
144+
_ = path.write_bytes(encoded_env)
145+
146+
run_command(
147+
["conda", "env", "create", "--file", str(path), "--name", env_name]
148+
)
149+
run_command(["conda", "activate", env_name])
150+
151+
summary_path = folder / "summary.json"
152+
run_command(
153+
[
154+
"conda",
155+
"run",
156+
"-n",
157+
env_name,
158+
"bioimageio",
159+
"test",
160+
str(source),
161+
"--summary-path",
162+
str(summary_path),
163+
]
164+
)
165+
return ValidationSummary.model_validate_json(summary_path.read_bytes())

0 commit comments

Comments
 (0)