Skip to content

Commit e161f56

Browse files
committed
more get_reader updates
1 parent d79ca05 commit e161f56

File tree

6 files changed

+48
-46
lines changed

6 files changed

+48
-46
lines changed

bioimageio/core/cli.py

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,12 @@
5757
update_hashes,
5858
)
5959
from bioimageio.spec._internal.io import is_yaml_value
60-
from bioimageio.spec._internal.io_basics import ZipPath
6160
from bioimageio.spec._internal.io_utils import open_bioimageio_yaml
6261
from bioimageio.spec._internal.types import NotEmpty
6362
from bioimageio.spec.dataset import DatasetDescr
6463
from bioimageio.spec.model import ModelDescr, v0_4, v0_5
6564
from bioimageio.spec.notebook import NotebookDescr
66-
from bioimageio.spec.utils import download, ensure_description_is_model, write_yaml
65+
from bioimageio.spec.utils import ensure_description_is_model, get_reader, write_yaml
6766

6867
from .commands import WeightFormatArgAll, WeightFormatArgAny, package, test
6968
from .common import MemberId, SampleId, SupportedWeightsFormat
@@ -487,16 +486,12 @@ def _example(self):
487486
example_path.mkdir(exist_ok=True)
488487

489488
for t, src in zip(input_ids, example_inputs):
490-
local = download(src).path
491-
dst = Path(f"{example_path}/{t}/001{''.join(local.suffixes)}")
489+
reader = get_reader(src)
490+
dst = Path(f"{example_path}/{t}/001{reader.suffix}")
492491
dst.parent.mkdir(parents=True, exist_ok=True)
493492
inputs001.append(dst.as_posix())
494-
if isinstance(local, Path):
495-
shutil.copy(local, dst)
496-
elif isinstance(local, ZipPath):
497-
_ = local.root.extract(local.at, path=dst)
498-
else:
499-
assert_never(local)
493+
with dst.open("wb") as f:
494+
shutil.copyfileobj(reader, f)
500495

501496
inputs = [tuple(inputs001)]
502497
output_pattern = f"{example_path}/outputs/{{output_id}}/{{sample_id}}.tif"

bioimageio/core/digest_spec.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,13 +104,10 @@ def _import_from_file_impl(
104104
assert module_name.isidentifier(), module_name
105105

106106
source_bytes = reader.read()
107-
# with NamedTemporaryFile(
108107

109108
module = sys.modules.get(module_name)
110109
if module is None:
111110
try:
112-
# local_source.write(source_bytes)
113-
# local_source.flush()
114111
if reader.original_file_name.endswith(".zip") or is_zipfile(reader):
115112
module_path = TemporaryDirectory(
116113
prefix=module_name,

bioimageio/core/io.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from bioimageio.spec._internal.io import get_reader, interprete_file_source
2424
from bioimageio.spec._internal.type_guards import is_ndarray
2525
from bioimageio.spec.common import (
26+
BytesReader,
2627
FileSource,
2728
HttpUrl,
2829
PermissiveFileSource,
@@ -279,10 +280,16 @@ def load_dataset_stat(path: Path):
279280
return {e.measure: e.value for e in seq}
280281

281282

282-
def ensure_unzipped(source: Union[PermissiveFileSource, ZipPath], folder: Path):
283-
"""unzip a (downloaded) **source** to a file in **folder** if source is a zip archive.
284-
Always returns the path to the unzipped source (maybe source itself)"""
285-
weights_reader = get_reader(source)
283+
def ensure_unzipped(
284+
source: Union[PermissiveFileSource, ZipPath, BytesReader], folder: Path
285+
):
286+
"""unzip a (downloaded) **source** to a file in **folder** if source is a zip archive
287+
otherwise copy **source** to a file in **folder**."""
288+
if isinstance(source, BytesReader):
289+
weights_reader = source
290+
else:
291+
weights_reader = get_reader(source)
292+
286293
out_path = folder / (
287294
weights_reader.original_file_name or f"file{weights_reader.suffix}"
288295
)

bioimageio/core/weight_converters/keras_to_tensorflow.py

Lines changed: 29 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
import os
22
import shutil
33
from pathlib import Path
4+
from tempfile import TemporaryDirectory
45
from typing import Union, no_type_check
56
from zipfile import ZipFile
67

78
import tensorflow # pyright: ignore[reportMissingTypeStubs]
89

9-
from bioimageio.spec._internal.io import download
1010
from bioimageio.spec._internal.version_type import Version
1111
from bioimageio.spec.common import ZipPath
1212
from bioimageio.spec.model.v0_5 import (
@@ -70,7 +70,7 @@ def convert(
7070
raise ValueError("Missing Keras Hdf5 weights to convert from.")
7171

7272
weight_spec = model_descr.weights.keras_hdf5
73-
weight_path = download(weight_spec.source).path
73+
weight_reader = weight_spec.get_reader()
7474

7575
if weight_spec.tensorflow_version:
7676
model_tf_major_ver = int(weight_spec.tensorflow_version.major)
@@ -79,30 +79,34 @@ def convert(
7979
f"Tensorflow major versions of model {model_tf_major_ver} is not {tf_major_ver}"
8080
)
8181

82-
if tf_major_ver == 1:
83-
if len(model_descr.inputs) != 1 or len(model_descr.outputs) != 1:
84-
raise NotImplementedError(
85-
"Weight conversion for models with multiple inputs or outputs is not yet implemented."
86-
)
87-
88-
input_name = str(
89-
d.id
90-
if isinstance((d := model_descr.inputs[0]), InputTensorDescr)
91-
else d.name
92-
)
93-
output_name = str(
94-
d.id
95-
if isinstance((d := model_descr.outputs[0]), OutputTensorDescr)
96-
else d.name
97-
)
98-
return _convert_tf1(
99-
ensure_unzipped(weight_path, Path("bioimageio_unzipped_tf_weights")),
100-
output_path,
101-
input_name,
102-
output_name,
82+
with TemporaryDirectory(ignore_cleanup_errors=True) as temp_dir:
83+
local_weights = ensure_unzipped(
84+
weight_reader, Path(temp_dir) / "bioimageio_unzipped_tf_weights"
10385
)
104-
else:
105-
return _convert_tf2(weight_path, output_path)
86+
if tf_major_ver == 1:
87+
if len(model_descr.inputs) != 1 or len(model_descr.outputs) != 1:
88+
raise NotImplementedError(
89+
"Weight conversion for models with multiple inputs or outputs is not yet implemented."
90+
)
91+
92+
input_name = str(
93+
d.id
94+
if isinstance((d := model_descr.inputs[0]), InputTensorDescr)
95+
else d.name
96+
)
97+
output_name = str(
98+
d.id
99+
if isinstance((d := model_descr.outputs[0]), OutputTensorDescr)
100+
else d.name
101+
)
102+
return _convert_tf1(
103+
ensure_unzipped(local_weights, Path("bioimageio_unzipped_tf_weights")),
104+
output_path,
105+
input_name,
106+
output_name,
107+
)
108+
else:
109+
return _convert_tf2(local_weights, output_path)
106110

107111

108112
def _convert_tf2(

bioimageio/core/weight_converters/torchscript_to_onnx.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import torch.jit
44

55
from bioimageio.spec.model.v0_5 import ModelDescr, OnnxWeightsDescr
6-
from bioimageio.spec.utils import download
76

87
from .. import __version__
98
from ..digest_spec import get_member_id, get_test_inputs
@@ -55,8 +54,8 @@ def convert(
5554
]
5655
inputs_torch = [torch.from_numpy(ipt) for ipt in inputs_numpy]
5756

58-
weight_path = download(torchscript_descr).path
59-
model = torch.jit.load(weight_path) # type: ignore
57+
weight_reader = torchscript_descr.get_reader()
58+
model = torch.jit.load(weight_reader) # type: ignore
6059
model.to("cpu")
6160
model = model.eval() # type: ignore
6261

tests/test_bioimageio_collection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def _get_latest_rdf_sources():
1818
for entry in entries:
1919
version = entry["versions"][0]
2020
ret[f"{entry['concept']}/{version['v']}"] = (
21-
HttpUrl(version["source"]), # pyright: ignore[reportCallIssue]
21+
HttpUrl(version["source"]),
2222
Sha256(version["sha256"]),
2323
)
2424

0 commit comments

Comments
 (0)