Skip to content

Commit 85663d0

Browse files
committed
update _import_from_file_impl
1 parent 39e3da7 commit 85663d0

File tree

1 file changed

+31
-34
lines changed

1 file changed

+31
-34
lines changed

bioimageio/core/digest_spec.py

Lines changed: 31 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
from __future__ import annotations
22

33
import collections.abc
4-
import hashlib
54
import importlib.util
65
import sys
76
from itertools import chain
87
from pathlib import Path
8+
from tempfile import NamedTemporaryFile, TemporaryDirectory
99
from typing import (
1010
Any,
1111
Callable,
@@ -19,6 +19,7 @@
1919
Tuple,
2020
Union,
2121
)
22+
from zipfile import ZipFile, is_zipfile
2223

2324
import numpy as np
2425
import xarray as xr
@@ -35,9 +36,8 @@
3536
ArchitectureFromLibraryDescr,
3637
ParameterizedSize_N,
3738
)
38-
from bioimageio.spec.utils import download, load_array
39+
from bioimageio.spec.utils import load_array
3940

40-
from ._settings import settings
4141
from .axis import Axis, AxisId, AxisInfo, AxisLike, PerAxis
4242
from .block_meta import split_multiple_shapes_into_blocks
4343
from .common import Halo, MemberId, PerMember, SampleId, TotalNumberOfBlocks
@@ -89,41 +89,46 @@ def _import_from_file_impl(
8989
):
9090
src_descr = FileDescr(source=source, **kwargs)
9191
# ensure sha is valid even if perform_io_checks=False
92-
src_descr.validate_sha256()
92+
# or the source has changed since last sha computation
93+
src_descr.validate_sha256(force_recompute=True)
9394
assert src_descr.sha256 is not None
95+
source_sha = src_descr.sha256
9496

95-
local_source = src_descr.download()
96-
97-
source_bytes = local_source.path.read_bytes()
98-
assert isinstance(source_bytes, bytes)
99-
source_sha = hashlib.sha256(source_bytes).hexdigest()
100-
97+
reader = src_descr.get_reader()
10198
# make sure we have unique module name
102-
module_name = f"{local_source.path.stem}_{source_sha}"
99+
module_name = f"{reader.original_file_name.split('.')[0]}_{source_sha}"
103100

104-
# make sure we have a valid module name
101+
# make sure we have a unique and valid module name
105102
if not module_name.isidentifier():
106103
module_name = f"custom_module_{source_sha}"
107104
assert module_name.isidentifier(), module_name
108105

106+
source_bytes = reader.read()
107+
# with NamedTemporaryFile(
108+
109109
module = sys.modules.get(module_name)
110110
if module is None:
111111
try:
112-
if isinstance(local_source.path, Path):
113-
module_path = local_source.path
114-
elif isinstance(local_source.path, ZipPath):
115-
# save extract source to cache
116-
# loading from a file from disk ensure we get readable tracebacks
117-
# if any errors occur
118-
module_path = (
119-
settings.cache_path / f"{source_sha}-{local_source.path.name}"
112+
# local_source.write(source_bytes)
113+
# local_source.flush()
114+
if reader.original_file_name.endswith(".zip") or is_zipfile(reader):
115+
module_path = TemporaryDirectory(
116+
prefix=module_name,
117+
delete=False,
118+
ignore_cleanup_errors=True,
120119
)
121-
_ = module_path.write_bytes(source_bytes)
120+
ZipFile(reader).extractall(path=module_path.name)
122121
else:
123-
assert_never(local_source.path)
122+
module_path = NamedTemporaryFile(
123+
mode="wb",
124+
suffix=reader.suffix,
125+
prefix=f"{module_name}_",
126+
delete=False,
127+
)
128+
_ = module_path.write(source_bytes)
124129

125130
importlib_spec = importlib.util.spec_from_file_location(
126-
module_name, module_path
131+
module_name, module_path.name
127132
)
128133

129134
if importlib_spec is None:
@@ -378,21 +383,13 @@ def get_tensor(
378383

379384
if isinstance(src, Tensor):
380385
return src
381-
382-
if isinstance(src, xr.DataArray):
386+
elif isinstance(src, xr.DataArray):
383387
return Tensor.from_xarray(src)
384-
385-
if isinstance(src, np.ndarray):
388+
elif isinstance(src, np.ndarray):
386389
return Tensor.from_numpy(src, dims=get_axes_infos(ipt))
387-
388-
if isinstance(src, FileDescr):
389-
src = download(src).path
390-
391-
if isinstance(src, (ZipPath, Path, str)):
390+
else:
392391
return load_tensor(src, axes=get_axes_infos(ipt))
393392

394-
assert_never(src)
395-
396393

397394
def create_sample_for_model(
398395
model: AnyModelDescr,

0 commit comments

Comments
 (0)