|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | 3 | import collections.abc |
| 4 | +import hashlib |
4 | 5 | import importlib.util |
5 | 6 | import sys |
6 | 7 | from itertools import chain |
|
36 | 37 | ) |
37 | 38 | from bioimageio.spec.utils import download, load_array |
38 | 39 |
|
| 40 | +from ._settings import settings |
39 | 41 | from .axis import Axis, AxisId, AxisInfo, AxisLike, PerAxis |
40 | 42 | from .block_meta import split_multiple_shapes_into_blocks |
41 | 43 | from .common import Halo, MemberId, PerMember, SampleId, TotalNumberOfBlocks |
@@ -91,33 +93,50 @@ def _import_from_file_impl( |
91 | 93 | assert src_descr.sha256 is not None |
92 | 94 |
|
93 | 95 | local_source = src_descr.download() |
94 | | - source_code = local_source.path.read_text(encoding="utf-8") |
95 | 96 |
|
96 | | - module_name = local_source.original_file_name.replace("-", "_") |
97 | | - if module_name.endswith(".py"): |
98 | | - module_name = module_name[:-3] |
| 97 | + source_bytes = local_source.path.read_bytes() |
| 98 | + assert isinstance(source_bytes, bytes) |
| 99 | + source_sha = hashlib.sha256(source_bytes).hexdigest() |
99 | 100 |
|
100 | | - # make sure we have a unique module name to avoid conflicts and confusion |
101 | | - module_name = f"{module_name}_{src_descr.sha256}" |
| 101 | + # make sure we have unique module name |
| 102 | + module_name = f"{local_source.path.stem}_{source_sha}" |
102 | 103 |
|
103 | 104 | # make sure we have a valid module name |
104 | 105 | if not module_name.isidentifier(): |
105 | | - module_name = f"custom_module_{src_descr.sha256}" |
| 106 | + module_name = f"custom_module_{source_sha}" |
106 | 107 | assert module_name.isidentifier(), module_name |
107 | 108 |
|
108 | 109 | module = sys.modules.get(module_name) |
109 | 110 | if module is None: |
110 | 111 | try: |
111 | | - module_spec = importlib.util.spec_from_loader(module_name, loader=None) |
112 | | - assert module_spec is not None |
113 | | - module = importlib.util.module_from_spec(module_spec) |
114 | | - source_compiled = compile( |
115 | | - source_code, str(local_source.path), "exec" |
116 | | - ) # compile source to attach file name |
117 | | - exec(source_compiled, module.__dict__) |
118 | | - sys.modules[module_spec.name] = module # cache this module |
| 112 | + if isinstance(local_source.path, Path): |
| 113 | + module_path = local_source.path |
| 114 | + elif isinstance(local_source.path, ZipPath): |
| 115 | + # save extract source to cache |
| 116 | + # loading from a file from disk ensure we get readable tracebacks |
| 117 | + # if any errors occur |
| 118 | + module_path = ( |
| 119 | + settings.cache_path / f"{source_sha}-{local_source.path.name}" |
| 120 | + ) |
| 121 | + _ = module_path.write_bytes(source_bytes) |
| 122 | + else: |
| 123 | + assert_never(local_source.path) |
| 124 | + |
| 125 | + importlib_spec = importlib.util.spec_from_file_location( |
| 126 | + module_name, module_path |
| 127 | + ) |
| 128 | + |
| 129 | + if importlib_spec is None: |
| 130 | + raise ImportError(f"Failed to import {source}.") |
| 131 | + |
| 132 | + module = importlib.util.module_from_spec(importlib_spec) |
| 133 | + assert importlib_spec.loader is not None |
| 134 | + importlib_spec.loader.exec_module(module) |
| 135 | + |
119 | 136 | except Exception as e: |
120 | 137 | raise ImportError(f"Failed to import {source} .") from e |
| 138 | + else: |
| 139 | + sys.modules[module_name] = module # cache this module |
121 | 140 |
|
122 | 141 | try: |
123 | 142 | callable_attr = getattr(module, callable_name) |
|
0 commit comments