|
2 | 2 |
|
3 | 3 | import collections.abc |
4 | 4 | import importlib.util |
| 5 | +import sys |
5 | 6 | from itertools import chain |
6 | 7 | from pathlib import Path |
7 | 8 | from typing import ( |
|
24 | 25 | from numpy.typing import NDArray |
25 | 26 | from typing_extensions import Unpack, assert_never |
26 | 27 |
|
27 | | -from bioimageio.spec._internal.io import HashKwargs, resolve |
| 28 | +from bioimageio.spec import get_validation_context |
| 29 | +from bioimageio.spec._internal.io import HashKwargs |
28 | 30 | from bioimageio.spec.common import FileDescr, FileSource, ZipPath |
29 | 31 | from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5 |
30 | 32 | from bioimageio.spec.model.v0_4 import CallableFromDepencency, CallableFromFile |
@@ -84,10 +86,51 @@ def import_callable( |
84 | 86 | def _import_from_file_impl( |
85 | 87 | source: FileSource, callable_name: str, **kwargs: Unpack[HashKwargs] |
86 | 88 | ): |
87 | | - code = resolve(source, **kwargs).path.read_text(encoding="utf-8") |
88 | | - module_globals: Dict[str, Any] = {} |
89 | | - exec(code, module_globals) |
90 | | - return module_globals[callable_name] |
| 89 | + with get_validation_context().replace(perform_io_checks=True): |
| 90 | + src_descr = FileDescr(source=source, **kwargs) |
| 91 | + assert src_descr.sha256 is not None |
| 92 | + |
| 93 | + local_source = src_descr.download() |
| 94 | + source_code = local_source.path.read_text(encoding="utf-8") |
| 95 | + |
| 96 | + module_name = local_source.original_file_name.replace("-", "_") |
| 97 | + if module_name.endswith(".py"): |
| 98 | + module_name = module_name[:-3] |
| 99 | + |
| 100 | + # make sure we have a unique module name to avoid conflicts and confusion |
| 101 | + module_name = f"{module_name}_{src_descr.sha256}" |
| 102 | + |
| 103 | + # make sure we have a valid module name |
| 104 | + if not module_name.isidentifier(): |
| 105 | + module_name = f"custom_module_{src_descr.sha256}" |
| 106 | + assert module_name.isidentifier(), module_name |
| 107 | + |
| 108 | + module = sys.modules.get(module_name) |
| 109 | + if module is None: |
| 110 | + try: |
| 111 | + module_spec = importlib.util.spec_from_loader(module_name, loader=None) |
| 112 | + assert module_spec is not None |
| 113 | + module = importlib.util.module_from_spec(module_spec) |
| 114 | + exec(source_code, module.__dict__) |
| 115 | + sys.modules[module_spec.name] = module # cache this module |
| 116 | + except Exception as e: |
| 117 | + raise ImportError( |
| 118 | + f"Failed to import {module_name[:-58]}... from {source}" |
| 119 | + ) from e |
| 120 | + |
| 121 | + try: |
| 122 | + callable_attr = getattr(module, callable_name) |
| 123 | + except AttributeError as e: |
| 124 | + raise AttributeError( |
| 125 | + f"Imported custom module `{module_name[:-58]}...` has no `{callable_name}` attribute" |
| 126 | + ) from e |
| 127 | + except Exception as e: |
| 128 | + raise AttributeError( |
| 129 | + f"Failed to access `{callable_name}` attribute from imported custom module `{module_name[:-58]}...`" |
| 130 | + ) from e |
| 131 | + |
| 132 | + else: |
| 133 | + return callable_attr |
91 | 134 |
|
92 | 135 |
|
93 | 136 | def get_axes_infos( |
|
0 commit comments