Skip to content

Commit d1c3cbb

Browse files
Merge branch 'fix-build-spec' into code-examples
2 parents 4eb5020 + 9078f77 commit d1c3cbb

File tree

8 files changed

+89
-148
lines changed

8 files changed

+89
-148
lines changed

bioimageio/core/build_spec/build_model.py

Lines changed: 13 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -100,10 +100,7 @@ def _get_weights(weight_uri, weight_type, source, root, **kwargs):
100100

101101
elif weight_type == "onnx":
102102
weights = model_spec.raw_nodes.OnnxWeightsEntry(
103-
source=weight_source,
104-
sha256=weight_hash,
105-
opset_version=kwargs.get("opset_version", 12),
106-
**attachments
103+
source=weight_source, sha256=weight_hash, opset_version=kwargs.get("opset_version", 12), **attachments
107104
)
108105
language = None
109106
framework = None
@@ -197,9 +194,7 @@ def _get_input_tensor(path, name, step, min_shape, data_range, axes, preprocessi
197194

198195
kwargs = {}
199196
if preprocessing is not None:
200-
kwargs["preprocessing"] = [
201-
{"name": k, "kwargs": v} for k, v in preprocessing.items()
202-
]
197+
kwargs["preprocessing"] = [{"name": k, "kwargs": v} for k, v in preprocessing.items()]
203198

204199
inputs = model_spec.raw_nodes.InputTensor(
205200
name="input" if name is None else name,
@@ -229,9 +224,7 @@ def _get_output_tensor(path, name, reference_tensor, scale, offset, axes, data_r
229224

230225
kwargs = {}
231226
if postprocessing is not None:
232-
kwargs["postprocessing"] = [
233-
{"name": k, "kwargs": v} for k, v in postprocessing.items()
234-
]
227+
kwargs["postprocessing"] = [{"name": k, "kwargs": v} for k, v in postprocessing.items()]
235228
if halo is not None:
236229
kwargs["halo"] = halo
237230

@@ -389,8 +382,9 @@ def build_model(
389382

390383
inputs = [
391384
_get_input_tensor(test_in, name, step, min_shape, axes, data_range, preproc)
392-
for test_in, name, step, min_shape, axes, data_range, preproc in
393-
zip(test_inputs, input_name, input_step, input_min_shape, input_axes, input_data_range, preprocessing)
385+
for test_in, name, step, min_shape, axes, data_range, preproc in zip(
386+
test_inputs, input_name, input_step, input_min_shape, input_axes, input_data_range, preprocessing
387+
)
394388
]
395389

396390
n_outputs = len(test_outputs)
@@ -405,8 +399,7 @@ def build_model(
405399

406400
outputs = [
407401
_get_output_tensor(test_out, name, reference, scale, offset, axes, data_range, postproc, hal)
408-
for test_out, name, reference, scale, offset, axes, data_range, postproc, hal in
409-
zip(
402+
for test_out, name, reference, scale, offset, axes, data_range, postproc, hal in zip(
410403
test_outputs,
411404
output_name,
412405
output_reference,
@@ -415,7 +408,7 @@ def build_model(
415408
output_axes,
416409
output_data_range,
417410
postprocessing,
418-
halo
411+
halo,
419412
)
420413
]
421414

@@ -449,7 +442,7 @@ def build_model(
449442
"source": source,
450443
"sha256": source_hash,
451444
"kwargs": model_kwargs,
452-
"links": links
445+
"links": links,
453446
}
454447
kwargs = {k: v for k, v in optional_kwargs.items() if v is not None}
455448
if dependencies is not None:
@@ -495,13 +488,13 @@ def add_weights(
495488
weight_uri: Union[str, Path],
496489
weight_type: Optional[str] = None,
497490
output_path: Optional[Union[str, Path]] = None,
498-
**weight_kwargs
491+
**weight_kwargs,
499492
):
500493
"""Add weight entry to bioimage.io model."""
501494
# we need to pass the weight path as an absolute path to avoid confusion with different root directories
502-
new_weights = _get_weights(
503-
Path(weight_uri).absolute(), weight_type, source=None, root=Path("."), **weight_kwargs
504-
)[0]
495+
new_weights = _get_weights(Path(weight_uri).absolute(), weight_type, source=None, root=Path("."), **weight_kwargs)[
496+
0
497+
]
505498
model.weights.update(new_weights)
506499
if output_path is not None:
507500
model_package = export_resource_package(model, output_path=output_path)

bioimageio/core/resource_io/common.py

Lines changed: 0 additions & 12 deletions
This file was deleted.

bioimageio/core/resource_io/io_.py

Lines changed: 67 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -1,48 +1,97 @@
11
import os
22
import pathlib
33
import warnings
4+
import zipfile
45
from copy import deepcopy
5-
from typing import Any, Dict, Optional, Sequence, Tuple, Union
6+
from typing import Dict, IO, Optional, Sequence, Tuple, Union
67
from zipfile import ZIP_DEFLATED, ZipFile
78

8-
from marshmallow import ValidationError, missing
9+
from marshmallow import missing
910

1011
from bioimageio import spec
1112
from bioimageio.core.resource_io.nodes import ResourceDescription
13+
from bioimageio.spec.io_ import resolve_rdf_source
1214
from bioimageio.spec.shared import raw_nodes
13-
from bioimageio.spec.shared.common import get_class_name_from_type
15+
from bioimageio.spec.shared.common import BIOIMAGEIO_CACHE_PATH, get_class_name_from_type
1416
from bioimageio.spec.shared.raw_nodes import ResourceDescription as RawResourceDescription
1517
from bioimageio.spec.shared.utils import PathToRemoteUriTransformer
1618
from . import nodes
17-
from .common import BIOIMAGEIO_CACHE_PATH, yaml
18-
from .utils import _download_uri_to_local_path, resolve_local_uri, resolve_raw_resource_description, resolve_uri
19+
from .utils import resolve_raw_resource_description, resolve_uri
1920

20-
21-
ROOT_PATH = "root_path"
2221
serialize_raw_resource_description = spec.io_.serialize_raw_resource_description
2322
save_raw_resource_description = spec.io_.save_raw_resource_description
2423

2524

25+
def extract_resource_package(
26+
source: Union[os.PathLike, IO, str, bytes, raw_nodes.URI]
27+
) -> Tuple[dict, str, pathlib.Path]:
28+
"""extract a zip source to BIOIMAGEIO_CACHE_PATH"""
29+
source, source_name, root = resolve_rdf_source(source)
30+
if isinstance(root, bytes):
31+
raise NotImplementedError("package source was bytes")
32+
33+
cache_folder = BIOIMAGEIO_CACHE_PATH / "extracted_packages"
34+
cache_folder.mkdir(exist_ok=True, parents=True)
35+
36+
if isinstance(root, raw_nodes.URI):
37+
from urllib.request import urlretrieve
38+
39+
package_path = cache_folder / root.scheme / root.authority / root.path.strip("/") / root.query
40+
if (package_path / "rdf.yaml").exists():
41+
download = None
42+
else:
43+
download, header = urlretrieve(str(root))
44+
45+
local_source = download
46+
else:
47+
download = None
48+
local_source = root
49+
package_path = cache_folder / root.relative_to(list(root.parents)[-1])
50+
51+
if local_source is not None:
52+
with zipfile.ZipFile(local_source) as zf:
53+
zf.extractall(package_path)
54+
55+
if not (package_path / "rdf.yaml").exists():
56+
raise FileNotFoundError(f"missing 'rdf.yaml' in {root} extracted from {download}")
57+
58+
if download is not None:
59+
try:
60+
os.remove(download)
61+
except Exception as e:
62+
warnings.warn(f"Could not remove download {download} due to {e}")
63+
64+
assert isinstance(package_path, pathlib.Path)
65+
return source, source_name, package_path
66+
67+
2668
def _replace_relative_paths_for_remote_source(
27-
raw_rd: RawResourceDescription, source: Union[Any, str, raw_nodes.URI]
69+
raw_rd: RawResourceDescription, root: Union[pathlib.Path, raw_nodes.URI, bytes]
2870
) -> RawResourceDescription:
29-
if isinstance(source, raw_nodes.URI) or isinstance(source, str) and source.startswith("http"):
71+
if isinstance(root, raw_nodes.URI):
3072
# for a remote source relative paths are invalid; replace all relative file paths in source with URLs
31-
if isinstance(source, str):
32-
source = raw_nodes.URI(source)
33-
3473
warnings.warn(
35-
f"changing file paths in RDF to URIs due to a remote {source.scheme} source "
74+
f"changing file paths in RDF to URIs due to a remote {root.scheme} source "
3675
"(may result in an invalid node)"
3776
)
38-
raw_rd = PathToRemoteUriTransformer(remote_source=source).transform(raw_rd)
39-
raw_rd.root_path = pathlib.Path() # root_path cannot be URI
77+
raw_rd = PathToRemoteUriTransformer(remote_source=root).transform(raw_rd)
78+
root_path = pathlib.Path() # root_path cannot be URI
79+
elif isinstance(root, pathlib.Path):
80+
if zipfile.is_zipfile(root):
81+
_, _, root_path = extract_resource_package(root)
82+
else:
83+
root_path = root
84+
elif isinstance(root, bytes):
85+
raise NotImplementedError("root as bytes (io)")
86+
else:
87+
raise TypeError(root)
4088

89+
raw_rd.root_path = root_path
4190
return raw_rd
4291

4392

4493
def load_raw_resource_description(
45-
source: Union[os.PathLike, str, dict, raw_nodes.URI, RawResourceDescription]
94+
source: Union[dict, os.PathLike, IO, str, bytes, raw_nodes.URI]
4695
) -> RawResourceDescription:
4796
"""load a raw python representation from a BioImage.IO resource description file (RDF).
4897
Use `load_resource_description` for a more convenient representation.
@@ -53,12 +102,8 @@ def load_raw_resource_description(
53102
Returns:
54103
raw BioImage.IO resource
55104
"""
56-
if isinstance(source, RawResourceDescription):
57-
return source
58-
59-
data, type_ = resolve_rdf_source_and_type(source)
60-
raw_rd = spec.load_raw_resource_description(data, update_to_current_format=True)
61-
raw_rd = _replace_relative_paths_for_remote_source(raw_rd, source)
105+
raw_rd = spec.load_raw_resource_description(source, update_to_current_format=True)
106+
raw_rd = _replace_relative_paths_for_remote_source(raw_rd, raw_rd.root_path)
62107
return raw_rd
63108

64109

@@ -186,26 +231,6 @@ def _get_tmp_package_path(raw_rd: RawResourceDescription, weights_priority_order
186231
return package_path
187232

188233

189-
def extract_resource_package(source: Union[os.PathLike, str, raw_nodes.URI]) -> pathlib.Path:
190-
"""extract a zip source to BIOIMAGEIO_CACHE_PATH"""
191-
local_source = resolve_uri(source)
192-
assert isinstance(local_source, pathlib.Path)
193-
cache_folder = BIOIMAGEIO_CACHE_PATH / "extracted_packages"
194-
cache_folder.mkdir(exist_ok=True, parents=True)
195-
package_path = cache_folder / f"{local_source.stem}"
196-
with ZipFile(local_source) as zf:
197-
zf.extractall(package_path)
198-
199-
for rdf_name in ["rdf.yaml", "model.yaml", "rdf.yml", "model.yml"]:
200-
rdf_path = package_path / rdf_name
201-
if rdf_path.exists():
202-
break
203-
else:
204-
raise FileNotFoundError(local_source / "rdf.yaml")
205-
206-
return rdf_path
207-
208-
209234
def make_zip(
210235
path: os.PathLike, content: Dict[str, Union[str, pathlib.Path]], *, compression: int, compression_level: int
211236
) -> None:
@@ -225,53 +250,3 @@ def make_zip(
225250
myzip.writestr(arc_name, file_or_str_content)
226251
else:
227252
myzip.write(file_or_str_content, arcname=arc_name)
228-
229-
230-
def resolve_rdf_source_and_type(source: Union[os.PathLike, str, dict, raw_nodes.URI]) -> Tuple[dict, str]:
231-
if isinstance(source, dict):
232-
data = source
233-
if ROOT_PATH not in data:
234-
data[ROOT_PATH] = pathlib.Path()
235-
else:
236-
data = get_dict_from_yaml_source(source)
237-
238-
type_ = data.get("type", "model") # todo: remove default 'model' type
239-
240-
return data, type_
241-
242-
243-
def get_dict_from_yaml_source(source: Union[os.PathLike, str, raw_nodes.URI, dict]) -> dict:
244-
if isinstance(source, dict):
245-
if ROOT_PATH not in source:
246-
source[ROOT_PATH] = pathlib.Path()
247-
248-
return source
249-
elif isinstance(source, (str, os.PathLike, raw_nodes.URI)):
250-
source = resolve_local_uri(source, pathlib.Path())
251-
else:
252-
raise TypeError(source)
253-
254-
if isinstance(source, raw_nodes.URI): # remote uri
255-
local_source = _download_uri_to_local_path(source)
256-
root_path = pathlib.Path()
257-
else:
258-
local_source = source
259-
root_path = source.parent
260-
261-
assert isinstance(local_source, pathlib.Path)
262-
if local_source.suffix == ".zip":
263-
local_source = extract_resource_package(local_source)
264-
root_path = local_source.parent
265-
266-
if local_source.suffix == ".yml":
267-
warnings.warn(
268-
"suffix '.yml' is not recommended and will raise a ValidationError in the future. Use '.yaml' instead "
269-
"(https://yaml.org/faq.html)"
270-
)
271-
elif local_source.suffix != ".yaml":
272-
raise ValidationError(f"invalid suffix {local_source.suffix} for source {source}")
273-
274-
data = yaml.load(local_source)
275-
assert isinstance(data, dict)
276-
data[ROOT_PATH] = root_path
277-
return data

bioimageio/core/resource_io/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,9 @@
1414
from marshmallow import ValidationError
1515

1616
from bioimageio.spec.shared import fields, raw_nodes
17+
from bioimageio.spec.shared.common import BIOIMAGEIO_CACHE_PATH
1718
from bioimageio.spec.shared.utils import GenericRawNode, GenericRawRD, NodeTransformer, NodeVisitor
1819
from . import nodes
19-
from .common import BIOIMAGEIO_CACHE_PATH
2020

2121
GenericResolvedNode = typing.TypeVar("GenericResolvedNode", bound=nodes.Node)
2222
GenericNode = typing.Union[GenericRawNode, GenericResolvedNode]

conda-recipe/meta.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@ build:
1818

1919
requirements:
2020
host:
21-
- python >=3.7
21+
- python >=3.7,<3.10
2222
- pip
2323
run:
24-
- python >=3.7
24+
- python >=3.7,<3.10
2525
- tqdm
2626
- typer
2727
{% for dep in setup_py_data['install_requires'] %}

tests/build_spec/test_build_spec.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def _test_build_spec(spec_path, out_path, weight_type, tensorflow_version=None,
4444
cite=cite,
4545
root=model_spec.root_path,
4646
weight_type=weight_type_,
47-
output_path=out_path
47+
output_path=out_path,
4848
)
4949
if tensorflow_version is not None:
5050
kwargs["tensorflow_version"] = tensorflow_version

tests/resource_io/test_load_rdf.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -18,26 +18,6 @@ def test_load_non_existing_rdf():
1818
load_resource_description(spec_path)
1919

2020

21-
def test_load_non_valid_rdf_name_no_suffix():
22-
from bioimageio.core import load_resource_description
23-
24-
with NamedTemporaryFile() as f:
25-
spec_path = pathlib.Path(f.name)
26-
27-
with pytest.raises(ValidationError):
28-
load_resource_description(spec_path)
29-
30-
31-
def test_load_non_valid_rdf_name_invalid_suffix():
32-
from bioimageio.core import load_resource_description
33-
34-
with NamedTemporaryFile(suffix=".invalid_suffix") as f:
35-
spec_path = pathlib.Path(f.name)
36-
37-
with pytest.raises(ValidationError):
38-
load_resource_description(spec_path)
39-
40-
4121
def test_load_raw_model(any_model):
4222
from bioimageio.core import load_raw_resource_description
4323

tests/test_cli.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@
66
from bioimageio.core import load_resource_description
77

88

9+
def test_validate_model(unet2d_nuclei_broad_model):
10+
ret = subprocess.run(["bioimageio", "validate", unet2d_nuclei_broad_model])
11+
assert ret.returncode == 0
12+
13+
914
def test_cli_test_model(unet2d_nuclei_broad_model):
1015
ret = subprocess.run(["bioimageio", "test-model", unet2d_nuclei_broad_model])
1116
assert ret.returncode == 0

0 commit comments

Comments
 (0)