Skip to content

Commit d1a27cd

Browse files
committed
improve error messages
1 parent 0c35b86 commit d1a27cd

File tree

3 files changed

+55
-7
lines changed

3 files changed

+55
-7
lines changed

bioimageio/core/_resource_tests.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -642,7 +642,7 @@ def _test_model_inference(
642642
raise e
643643

644644
error = str(e)
645-
tb = traceback.format_tb(e.__traceback__)
645+
tb = traceback.format_exception(type(e), e, e.__traceback__, chain=True)
646646

647647
model.validation_summary.add_detail(
648648
ValidationDetail(

bioimageio/core/backends/pytorch_backend.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,12 @@ def load_torch_model(
110110
if isinstance(weight_spec, v0_4.PytorchStateDictWeightsDescr)
111111
else weight_spec.architecture.kwargs
112112
)
113-
network = arch(**model_kwargs)
113+
try:
114+
# calling custom user code
115+
network = arch(**model_kwargs)
116+
except Exception as e:
117+
raise RuntimeError("Failed to initialize PyTorch model") from e
118+
114119
if not isinstance(network, nn.Module):
115120
raise ValueError(
116121
f"calling {weight_spec.architecture.callable_name if isinstance(weight_spec.architecture, (v0_4.CallableFromFile, v0_4.CallableFromDepencency)) else weight_spec.architecture.callable} did not return a torch.nn.Module"

bioimageio/core/digest_spec.py

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import collections.abc
44
import importlib.util
5+
import sys
56
from itertools import chain
67
from pathlib import Path
78
from typing import (
@@ -24,7 +25,8 @@
2425
from numpy.typing import NDArray
2526
from typing_extensions import Unpack, assert_never
2627

27-
from bioimageio.spec._internal.io import HashKwargs, resolve
28+
from bioimageio.spec import get_validation_context
29+
from bioimageio.spec._internal.io import HashKwargs
2830
from bioimageio.spec.common import FileDescr, FileSource, ZipPath
2931
from bioimageio.spec.model import AnyModelDescr, v0_4, v0_5
3032
from bioimageio.spec.model.v0_4 import CallableFromDepencency, CallableFromFile
@@ -84,10 +86,51 @@ def import_callable(
8486
def _import_from_file_impl(
8587
source: FileSource, callable_name: str, **kwargs: Unpack[HashKwargs]
8688
):
87-
code = resolve(source, **kwargs).path.read_text(encoding="utf-8")
88-
module_globals: Dict[str, Any] = {}
89-
exec(code, module_globals)
90-
return module_globals[callable_name]
89+
with get_validation_context().replace(perform_io_checks=True):
90+
src_descr = FileDescr(source=source, **kwargs)
91+
assert src_descr.sha256 is not None
92+
93+
local_source = src_descr.download()
94+
source_code = local_source.path.read_text(encoding="utf-8")
95+
96+
module_name = local_source.original_file_name.replace("-", "_")
97+
if module_name.endswith(".py"):
98+
module_name = module_name[:-3]
99+
100+
# make sure we have a unique module name to avoid conflicts and confusion
101+
module_name = f"{module_name}_{src_descr.sha256}"
102+
103+
# make sure we have a valid module name
104+
if not module_name.isidentifier():
105+
module_name = f"custom_module_{src_descr.sha256}"
106+
assert module_name.isidentifier(), module_name
107+
108+
module = sys.modules.get(module_name)
109+
if module is None:
110+
try:
111+
module_spec = importlib.util.spec_from_loader(module_name, loader=None)
112+
assert module_spec is not None
113+
module = importlib.util.module_from_spec(module_spec)
114+
exec(source_code, module.__dict__)
115+
sys.modules[module_spec.name] = module # cache this module
116+
except Exception as e:
117+
raise ImportError(
118+
f"Failed to import {module_name[:-58]}... from {source}"
119+
) from e
120+
121+
try:
122+
callable_attr = getattr(module, callable_name)
123+
except AttributeError as e:
124+
raise AttributeError(
125+
f"Imported custom module `{module_name[:-58]}...` has no `{callable_name}` attribute"
126+
) from e
127+
except Exception as e:
128+
raise AttributeError(
129+
f"Failed to access `{callable_name}` attribute from imported custom module `{module_name[:-58]}...`"
130+
) from e
131+
132+
else:
133+
return callable_attr
91134

92135

93136
def get_axes_infos(

0 commit comments

Comments
 (0)