|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | 3 | import collections.abc |
4 | | -import hashlib |
5 | 4 | import importlib.util |
6 | 5 | import sys |
7 | 6 | from itertools import chain |
8 | 7 | from pathlib import Path |
| 8 | +from tempfile import NamedTemporaryFile, TemporaryDirectory |
9 | 9 | from typing import ( |
10 | 10 | Any, |
11 | 11 | Callable, |
|
19 | 19 | Tuple, |
20 | 20 | Union, |
21 | 21 | ) |
| 22 | +from zipfile import ZipFile, is_zipfile |
22 | 23 |
|
23 | 24 | import numpy as np |
24 | 25 | import xarray as xr |
|
35 | 36 | ArchitectureFromLibraryDescr, |
36 | 37 | ParameterizedSize_N, |
37 | 38 | ) |
38 | | -from bioimageio.spec.utils import download, load_array |
| 39 | +from bioimageio.spec.utils import load_array |
39 | 40 |
|
40 | | -from ._settings import settings |
41 | 41 | from .axis import Axis, AxisId, AxisInfo, AxisLike, PerAxis |
42 | 42 | from .block_meta import split_multiple_shapes_into_blocks |
43 | 43 | from .common import Halo, MemberId, PerMember, SampleId, TotalNumberOfBlocks |
@@ -89,41 +89,46 @@ def _import_from_file_impl( |
89 | 89 | ): |
90 | 90 | src_descr = FileDescr(source=source, **kwargs) |
91 | 91 | # ensure sha is valid even if perform_io_checks=False |
92 | | - src_descr.validate_sha256() |
| 92 | + # or the source has changed since last sha computation |
| 93 | + src_descr.validate_sha256(force_recompute=True) |
93 | 94 | assert src_descr.sha256 is not None |
| 95 | + source_sha = src_descr.sha256 |
94 | 96 |
|
95 | | - local_source = src_descr.download() |
96 | | - |
97 | | - source_bytes = local_source.path.read_bytes() |
98 | | - assert isinstance(source_bytes, bytes) |
99 | | - source_sha = hashlib.sha256(source_bytes).hexdigest() |
100 | | - |
| 97 | + reader = src_descr.get_reader() |
101 | 98 | # make sure we have unique module name |
102 | | - module_name = f"{local_source.path.stem}_{source_sha}" |
| 99 | + module_name = f"{reader.original_file_name.split('.')[0]}_{source_sha}" |
103 | 100 |
|
104 | | - # make sure we have a valid module name |
| 101 | + # make sure we have a unique and valid module name |
105 | 102 | if not module_name.isidentifier(): |
106 | 103 | module_name = f"custom_module_{source_sha}" |
107 | 104 | assert module_name.isidentifier(), module_name |
108 | 105 |
|
| 106 | + source_bytes = reader.read() |
| 107 | + # with NamedTemporaryFile( |
| 108 | + |
109 | 109 | module = sys.modules.get(module_name) |
110 | 110 | if module is None: |
111 | 111 | try: |
112 | | - if isinstance(local_source.path, Path): |
113 | | - module_path = local_source.path |
114 | | - elif isinstance(local_source.path, ZipPath): |
115 | | - # save extract source to cache |
116 | | - # loading from a file from disk ensure we get readable tracebacks |
117 | | - # if any errors occur |
118 | | - module_path = ( |
119 | | - settings.cache_path / f"{source_sha}-{local_source.path.name}" |
| 112 | + # local_source.write(source_bytes) |
| 113 | + # local_source.flush() |
| 114 | + if reader.original_file_name.endswith(".zip") or is_zipfile(reader): |
| 115 | + module_path = TemporaryDirectory( |
| 116 | + prefix=module_name, |
| 117 | + delete=False, |
| 118 | + ignore_cleanup_errors=True, |
120 | 119 | ) |
121 | | - _ = module_path.write_bytes(source_bytes) |
| 120 | + ZipFile(reader).extractall(path=module_path.name) |
122 | 121 | else: |
123 | | - assert_never(local_source.path) |
| 122 | + module_path = NamedTemporaryFile( |
| 123 | + mode="wb", |
| 124 | + suffix=reader.suffix, |
| 125 | + prefix=f"{module_name}_", |
| 126 | + delete=False, |
| 127 | + ) |
| 128 | + _ = module_path.write(source_bytes) |
124 | 129 |
|
125 | 130 | importlib_spec = importlib.util.spec_from_file_location( |
126 | | - module_name, module_path |
| 131 | + module_name, module_path.name |
127 | 132 | ) |
128 | 133 |
|
129 | 134 | if importlib_spec is None: |
@@ -378,21 +383,13 @@ def get_tensor( |
378 | 383 |
|
379 | 384 | if isinstance(src, Tensor): |
380 | 385 | return src |
381 | | - |
382 | | - if isinstance(src, xr.DataArray): |
| 386 | + elif isinstance(src, xr.DataArray): |
383 | 387 | return Tensor.from_xarray(src) |
384 | | - |
385 | | - if isinstance(src, np.ndarray): |
| 388 | + elif isinstance(src, np.ndarray): |
386 | 389 | return Tensor.from_numpy(src, dims=get_axes_infos(ipt)) |
387 | | - |
388 | | - if isinstance(src, FileDescr): |
389 | | - src = download(src).path |
390 | | - |
391 | | - if isinstance(src, (ZipPath, Path, str)): |
| 390 | + else: |
392 | 391 | return load_tensor(src, axes=get_axes_infos(ipt)) |
393 | 392 |
|
394 | | - assert_never(src) |
395 | | - |
396 | 393 |
|
397 | 394 | def create_sample_for_model( |
398 | 395 | model: AnyModelDescr, |
|
0 commit comments