
Commit 545ac54

improve test error handling

1 parent f22438f commit 545ac54

4 files changed (+66 -66 lines)

bioimageio/core/_resource_tests.py

Lines changed: 28 additions & 29 deletions

@@ -2,7 +2,6 @@
 import os
 import platform
 import subprocess
-import traceback
 import warnings
 from io import StringIO
 from itertools import product
@@ -47,7 +46,10 @@
     MismatchedElementsPerMillion,
     RelativeTolerance,
 )
-from bioimageio.spec._internal.validation_context import validation_context_var
+from bioimageio.spec._internal.validation_context import (
+    get_validation_context,
+    validation_context_var,
+)
 from bioimageio.spec.common import BioimageioYamlContent, PermissiveFileSource, Sha256
 from bioimageio.spec.model import v0_4, v0_5
 from bioimageio.spec.model.v0_5 import WeightsFormat
@@ -589,8 +591,17 @@ def _test_model_inference(
 ) -> None:
     test_name = f"Reproduce test outputs from test inputs ({weight_format})"
     logger.debug("starting '{}'", test_name)
-    error: Optional[str] = None
-    tb: List[str] = []
+    errors: List[ErrorEntry] = []
+
+    def add_error_entry(msg: str, with_traceback: bool = False):
+        errors.append(
+            ErrorEntry(
+                loc=("weights", weight_format),
+                msg=msg,
+                type="bioimageio.core",
+                with_traceback=with_traceback,
+            )
+        )

     try:
         inputs = get_test_inputs(model)
@@ -602,13 +613,15 @@ def _test_model_inference(
             results = prediction_pipeline.predict_sample_without_blocking(inputs)

         if len(results.members) != len(expected.members):
-            error = f"Expected {len(expected.members)} outputs, but got {len(results.members)}"
+            add_error_entry(
+                f"Expected {len(expected.members)} outputs, but got {len(results.members)}"
+            )

         else:
             for m, expected in expected.members.items():
                 actual = results.members.get(m)
                 if actual is None:
-                    error = "Output tensors for test case may not be None"
+                    add_error_entry("Output tensors for test case may not be None")
                     break

                 rtol, atol, mismatched_tol = _get_tolerance(
@@ -627,7 +640,7 @@ def _test_model_inference(
                     a_max = abs_diff[a_max_idx].item()
                     a_actual = actual[a_max_idx].item()
                     a_expected = expected[a_max_idx].item()
-                    error = (
+                    add_error_entry(
                         f"Output '{m}' disagrees with {mismatched_elements} of"
                         + f" {expected.size} expected values."
                         + f"\n Max relative difference: {r_max:.2e}"
@@ -638,30 +651,18 @@ def _test_model_inference(
                     )
                     break
     except Exception as e:
-        if validation_context_var.get().raise_errors:
+        if get_validation_context().raise_errors:
             raise e

-        error = str(e)
-        tb = traceback.format_exception(type(e), e, e.__traceback__, chain=True)
+        add_error_entry(str(e), with_traceback=True)

     model.validation_summary.add_detail(
         ValidationDetail(
             name=test_name,
             loc=("weights", weight_format),
-            status="passed" if error is None else "failed",
+            status="failed" if errors else "passed",
             recommended_env=get_conda_env(entry=dict(model.weights)[weight_format]),
-            errors=(
-                []
-                if error is None
-                else [
-                    ErrorEntry(
-                        loc=("weights", weight_format),
-                        msg=error,
-                        type="bioimageio.core",
-                        traceback=tb,
-                    )
-                ]
-            ),
+            errors=errors,
         )
     )

@@ -816,11 +817,9 @@ def get_ns(n: int):
                 if stop_early and error is not None:
                     break
     except Exception as e:
-        if validation_context_var.get().raise_errors:
+        if get_validation_context().raise_errors:
             raise e

-        error = str(e)
-        tb = traceback.format_tb(e.__traceback__)
         model.validation_summary.add_detail(
             ValidationDetail(
                 name=f"Run {weight_format} inference for parametrized inputs",
@@ -829,9 +828,9 @@ def get_ns(n: int):
                 errors=[
                     ErrorEntry(
                         loc=("weights", weight_format),
-                        msg=error,
+                        msg=str(e),
                         type="bioimageio.core",
-                        traceback=tb,
+                        with_traceback=True,
                     )
                 ],
             )
@@ -854,7 +853,7 @@ def _test_expected_resource_type(
                 ErrorEntry(
                     loc=("type",),
                     type="type",
-                    msg=f"expected type {expected_type}, found {rd.type}",
+                    msg=f"Expected type {expected_type}, found {rd.type}",
                 )
             ]
         ),
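
The change above replaces the single `error`/`tb` pair with a list of error entries, so one validation detail can report several failures at once. A minimal, self-contained sketch of that pattern follows; the `ErrorEntry` dataclass and `run_checks` function here are illustrative stand-ins, not the bioimageio.spec API.

from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class ErrorEntry:  # stand-in for bioimageio.spec's ErrorEntry
    loc: Tuple[str, ...]
    msg: str
    type: str = "example.core"
    with_traceback: bool = False


def run_checks(weight_format: str) -> Tuple[str, List[ErrorEntry]]:
    errors: List[ErrorEntry] = []

    def add_error_entry(msg: str, with_traceback: bool = False):
        errors.append(
            ErrorEntry(
                loc=("weights", weight_format),
                msg=msg,
                with_traceback=with_traceback,
            )
        )

    try:
        # run checks here; every failed check records an entry instead of
        # overwriting a single error string
        add_error_entry("Expected 2 outputs, but got 1")
        add_error_entry("Output 'mask' disagrees with 5 of 100 expected values.")
    except Exception as e:
        add_error_entry(str(e), with_traceback=True)

    # status is derived from whether the list is empty, mirroring
    # status="failed" if errors else "passed" in the diff
    status = "failed" if errors else "passed"
    return status, errors


if __name__ == "__main__":
    status, errors = run_checks("onnx")
    print(status, [e.msg for e in errors])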

bioimageio/core/backends/_model_adapter.py

Lines changed: 14 additions & 16 deletions

@@ -1,3 +1,4 @@
+import sys
 import warnings
 from abc import ABC, abstractmethod
 from typing import (
@@ -87,7 +88,7 @@ def create(
         )

         weights = model_description.weights
-        errors: List[Tuple[SupportedWeightsFormat, Exception]] = []
+        errors: List[Exception] = []
         weight_format_priority_order = (
             DEFAULT_WEIGHT_FORMAT_PRIORITY_ORDER
             if weight_format_priority_order is None
@@ -112,7 +113,7 @@ def create(
                         model_description=model_description, devices=devices
                     )
                 except Exception as e:
-                    errors.append((wf, e))
+                    errors.append(e)
             elif wf == "tensorflow_saved_model_bundle":
                 assert weights.tensorflow_saved_model_bundle is not None
                 try:
@@ -122,7 +123,7 @@ def create(
                         model_description=model_description, devices=devices
                    )
                 except Exception as e:
-                    errors.append((wf, e))
+                    errors.append(e)
             elif wf == "onnx":
                 assert weights.onnx is not None
                 try:
@@ -132,7 +133,7 @@ def create(
                         model_description=model_description, devices=devices
                     )
                 except Exception as e:
-                    errors.append((wf, e))
+                    errors.append(e)
             elif wf == "torchscript":
                 assert weights.torchscript is not None
                 try:
@@ -142,7 +143,7 @@ def create(
                         model_description=model_description, devices=devices
                     )
                 except Exception as e:
-                    errors.append((wf, e))
+                    errors.append(e)
             elif wf == "keras_hdf5":
                 assert weights.keras_hdf5 is not None
                 # keras can either be installed as a separate package or used as part of tensorflow
@@ -158,27 +159,24 @@ def create(
                         model_description=model_description, devices=devices
                     )
                 except Exception as e:
-                    errors.append((wf, e))
+                    errors.append(e)
             else:
                 assert_never(wf)

         assert errors
         if len(weight_format_priority_order) == 1:
             assert len(errors) == 1
-            wf, e = errors[0]
-            raise ValueError(
-                f"The '{wf}' model adapter could not be created"
-                + f" in this environment:\n{e.__class__.__name__}({e}).\n\n"
-            ) from e
+            raise errors[0]

         else:
-            error_list = "\n - ".join(
-                f"{wf}: {e.__class__.__name__}({e})" for wf, e in errors
-            )
-            raise ValueError(
+            msg = (
                 "None of the weight format specific model adapters could be created"
-                + f" in this environment. Errors are:\n\n{error_list}.\n\n"
+                + " in this environment."
             )
+            if sys.version_info[:2] >= (3, 11):
+                raise ExceptionGroup(msg, errors)
+            else:
+                raise ValueError(msg) from Exception(errors)

     @final
     def load(self, *, devices: Optional[Sequence[str]] = None) -> None:
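
The adapter-creation failures are now collected as plain exceptions and raised together: on Python 3.11+ they are wrapped in an ExceptionGroup, otherwise a chained ValueError is raised. A standalone sketch of that version gate; `raise_collected` is a hypothetical helper, not part of bioimageio.core.

import sys
from typing import List


def raise_collected(errors: List[Exception]) -> None:
    """Raise every collected adapter-creation failure at once."""
    assert errors
    if len(errors) == 1:
        raise errors[0]  # a single candidate failed: re-raise it unchanged

    msg = "None of the weight format specific model adapters could be created in this environment."
    if sys.version_info[:2] >= (3, 11):
        # PEP 654 exception groups keep every original exception and traceback
        raise ExceptionGroup(msg, errors)
    else:
        # pre-3.11 fallback: chain a summary error that carries the list
        raise ValueError(msg) from Exception(errors)


if __name__ == "__main__":
    try:
        raise_collected([ImportError("no torch"), ImportError("no onnxruntime")])
    except Exception as e:
        print(type(e).__name__, "-", e)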

bioimageio/core/backends/pytorch_backend.py

Lines changed: 17 additions & 15 deletions

@@ -97,7 +97,7 @@ def load_torch_model(
     load_state: bool = True,
     devices: Optional[Sequence[Union[str, torch.device]]] = None,
 ) -> nn.Module:
-    arch = import_callable(
+    custom_callable = import_callable(
         weight_spec.architecture,
         sha256=(
             weight_spec.architecture_sha256
@@ -110,27 +110,29 @@ def load_torch_model(
         if isinstance(weight_spec, v0_4.PytorchStateDictWeightsDescr)
         else weight_spec.architecture.kwargs
     )
-    try:
-        # calling custom user code
-        network = arch(**model_kwargs)
-    except Exception as e:
-        raise RuntimeError("Failed to initialize PyTorch model") from e
-
-    if not isinstance(network, nn.Module):
-        raise ValueError(
-            f"calling {weight_spec.architecture.callable_name if isinstance(weight_spec.architecture, (v0_4.CallableFromFile, v0_4.CallableFromDepencency)) else weight_spec.architecture.callable} did not return a torch.nn.Module"
-        )
+    torch_model = custom_callable(**model_kwargs)
+
+    if not isinstance(torch_model, nn.Module):
+        if isinstance(
+            weight_spec.architecture,
+            (v0_4.CallableFromFile, v0_4.CallableFromDepencency),
+        ):
+            callable_name = weight_spec.architecture.callable_name
+        else:
+            callable_name = weight_spec.architecture.callable
+
+        raise ValueError(f"Calling {callable_name} did not return a torch.nn.Module.")

     if load_state or devices:
         use_devices = get_devices(devices)
-        network = network.to(use_devices[0])
+        torch_model = torch_model.to(use_devices[0])
         if load_state:
-            network = load_torch_state_dict(
-                network,
+            torch_model = load_torch_state_dict(
+                torch_model,
                 path=download(weight_spec).path,
                 devices=use_devices,
             )
-    return network
+    return torch_model


 def load_torch_state_dict(
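
The rewritten load_torch_model calls the user-supplied architecture directly and only afterwards checks the result type. A rough sketch of that check, assuming torch is installed; `instantiate_model` and `build_model` are made-up names for illustration.

import torch.nn as nn


def instantiate_model(custom_callable, **model_kwargs) -> nn.Module:
    torch_model = custom_callable(**model_kwargs)  # calling custom user code
    if not isinstance(torch_model, nn.Module):
        name = getattr(custom_callable, "__name__", repr(custom_callable))
        raise ValueError(f"Calling {name} did not return a torch.nn.Module.")

    return torch_model


if __name__ == "__main__":
    def build_model(in_features: int) -> nn.Module:  # hypothetical architecture callable
        return nn.Linear(in_features, 1)

    model = instantiate_model(build_model, in_features=4)
    print(type(model).__name__)  # Linear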

bioimageio/core/digest_spec.py

Lines changed: 7 additions & 6 deletions

@@ -111,22 +111,23 @@ def _import_from_file_impl(
         module_spec = importlib.util.spec_from_loader(module_name, loader=None)
         assert module_spec is not None
         module = importlib.util.module_from_spec(module_spec)
-        exec(source_code, module.__dict__)
+        source_compiled = compile(
+            source_code, str(local_source.path), "exec"
+        )  # compile source to attach file name
+        exec(source_compiled, module.__dict__)
         sys.modules[module_spec.name] = module  # cache this module
     except Exception as e:
-        raise ImportError(
-            f"Failed to import {module_name[:-58]}... from {source}"
-        ) from e
+        raise ImportError(f"Failed to import {source} .") from e

     try:
         callable_attr = getattr(module, callable_name)
     except AttributeError as e:
         raise AttributeError(
-            f"Imported custom module `{module_name[:-58]}...` has no `{callable_name}` attribute"
+            f"Imported custom module from {source} has no `{callable_name}` attribute."
         ) from e
     except Exception as e:
         raise AttributeError(
-            f"Failed to access `{callable_name}` attribute from imported custom module `{module_name[:-58]}...`"
+            f"Failed to access `{callable_name}` attribute from custom module imported from {source} ."
         ) from e

     else:
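
Compiling the downloaded source with an explicit file name before exec makes tracebacks from the custom module point at that file instead of an anonymous string. A standalone illustration; the module name and file name below are made up.

import importlib.util
import sys

source_code = "def answer():\n    return 42\n"
file_name = "my_architecture.py"  # made-up path of the downloaded source file

module_spec = importlib.util.spec_from_loader("my_architecture_module", loader=None)
assert module_spec is not None
module = importlib.util.module_from_spec(module_spec)

source_compiled = compile(
    source_code, file_name, "exec"
)  # compile source to attach the file name
exec(source_compiled, module.__dict__)
sys.modules[module_spec.name] = module  # cache this module

print(module.answer())  # 42; tracebacks from the module would cite my_architecture.py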
