Skip to content

Commit 57ea3ff

Browse files
committed
warn about tolerated mismatched elements
1 parent c10c677 commit 57ea3ff

File tree

2 files changed

+94
-32
lines changed

2 files changed

+94
-32
lines changed

bioimageio/core/_resource_tests.py

Lines changed: 66 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
overload,
2222
)
2323

24+
import xarray as xr
2425
from loguru import logger
2526
from typing_extensions import NotRequired, TypedDict, Unpack, assert_never, get_args
2627

@@ -55,6 +56,7 @@
5556
InstalledPackage,
5657
ValidationDetail,
5758
ValidationSummary,
59+
WarningEntry,
5860
)
5961

6062
from ._prediction_pipeline import create_prediction_pipeline
@@ -510,7 +512,7 @@ def load_description_and_test(
510512

511513
enable_determinism(determinism, weight_formats=weight_formats)
512514
for w in weight_formats:
513-
_test_model_inference(rd, w, devices, **deprecated)
515+
_test_model_inference(rd, w, devices, stop_early=stop_early, **deprecated)
514516
if stop_early and rd.validation_summary.status == "failed":
515517
break
516518

@@ -587,14 +589,16 @@ def _test_model_inference(
587589
model: Union[v0_4.ModelDescr, v0_5.ModelDescr],
588590
weight_format: SupportedWeightsFormat,
589591
devices: Optional[Sequence[str]],
592+
stop_early: bool,
590593
**deprecated: Unpack[DeprecatedKwargs],
591594
) -> None:
592595
test_name = f"Reproduce test outputs from test inputs ({weight_format})"
593596
logger.debug("starting '{}'", test_name)
594-
errors: List[ErrorEntry] = []
597+
error_entries: List[ErrorEntry] = []
598+
warning_entries: List[WarningEntry] = []
595599

596600
def add_error_entry(msg: str, with_traceback: bool = False):
597-
errors.append(
601+
error_entries.append(
598602
ErrorEntry(
599603
loc=("weights", weight_format),
600604
msg=msg,
@@ -603,6 +607,15 @@ def add_error_entry(msg: str, with_traceback: bool = False):
603607
)
604608
)
605609

610+
def add_warning_entry(msg: str):
611+
warning_entries.append(
612+
WarningEntry(
613+
loc=("weights", weight_format),
614+
msg=msg,
615+
type="bioimageio.core",
616+
)
617+
)
618+
606619
try:
607620
inputs = get_test_inputs(model)
608621
expected = get_test_outputs(model)
@@ -622,34 +635,58 @@ def add_error_entry(msg: str, with_traceback: bool = False):
622635
actual = results.members.get(m)
623636
if actual is None:
624637
add_error_entry("Output tensors for test case may not be None")
625-
break
638+
if stop_early:
639+
break
640+
else:
641+
continue
626642

627643
rtol, atol, mismatched_tol = _get_tolerance(
628644
model, wf=weight_format, m=m, **deprecated
629645
)
630-
mismatched = (abs_diff := abs(actual - expected)) > atol + rtol * abs(
631-
expected
632-
)
646+
rtol_value = rtol * abs(expected)
647+
abs_diff = abs(actual - expected)
648+
mismatched = abs_diff > atol + rtol_value
633649
mismatched_elements = mismatched.sum().item()
634-
if mismatched_elements / expected.size > mismatched_tol / 1e6:
635-
r_max_idx = (r_diff := (abs_diff / (abs(expected) + 1e-6))).argmax()
636-
r_max = r_diff[r_max_idx].item()
637-
r_actual = actual[r_max_idx].item()
638-
r_expected = expected[r_max_idx].item()
639-
a_max_idx = abs_diff.argmax()
640-
a_max = abs_diff[a_max_idx].item()
641-
a_actual = actual[a_max_idx].item()
642-
a_expected = expected[a_max_idx].item()
643-
add_error_entry(
644-
f"Output '{m}' disagrees with {mismatched_elements} of"
645-
+ f" {expected.size} expected values."
646-
+ f"\n Max relative difference: {r_max:.2e}"
647-
+ rf" (= \|{r_actual:.2e} - {r_expected:.2e}\|/\|{r_expected:.2e} + 1e-6\|)"
648-
+ f" at {r_max_idx}"
649-
+ f"\n Max absolute difference: {a_max:.2e}"
650-
+ rf" (= \|{a_actual:.7e} - {a_expected:.7e}\|) at {a_max_idx}"
651-
)
652-
break
650+
if not mismatched_elements:
651+
continue
652+
653+
mismatched_ppm = mismatched_elements / expected.size * 1e6
654+
abs_diff[~mismatched] = 0 # ignore non-mismatched elements
655+
656+
r_max_idx = (r_diff := (abs_diff / (abs(expected) + 1e-6))).argmax()
657+
r_max = r_diff[r_max_idx].item()
658+
r_actual = actual[r_max_idx].item()
659+
r_expected = expected[r_max_idx].item()
660+
661+
# Calculate the max absolute difference with the relative tolerance subtracted
662+
abs_diff_wo_rtol: xr.DataArray = xr.ufuncs.maximum(
663+
(abs_diff - rtol_value).data, 0
664+
)
665+
a_max_idx = {
666+
AxisId(k): int(v) for k, v in abs_diff_wo_rtol.argmax().items()
667+
}
668+
669+
a_max = abs_diff[a_max_idx].item()
670+
a_actual = actual[a_max_idx].item()
671+
a_expected = expected[a_max_idx].item()
672+
673+
msg = (
674+
f"Output '{m}' disagrees with {mismatched_elements} of"
675+
+ f" {expected.size} expected values"
676+
+ f" ({mismatched_ppm:.1f} ppm)."
677+
+ f"\n Max relative difference: {r_max:.2e}"
678+
+ rf" (= \|{r_actual:.2e} - {r_expected:.2e}\|/\|{r_expected:.2e} + 1e-6\|)"
679+
+ f" at {r_max_idx}"
680+
+ f"\n Max absolute difference not accounted for by relative tolerance: {a_max:.2e}"
681+
+ rf" (= \|{a_actual:.7e} - {a_expected:.7e}\|) at {a_max_idx}"
682+
)
683+
if mismatched_ppm > mismatched_tol:
684+
add_error_entry(msg)
685+
if stop_early:
686+
break
687+
else:
688+
add_warning_entry(msg)
689+
653690
except Exception as e:
654691
if get_validation_context().raise_errors:
655692
raise e
@@ -660,9 +697,10 @@ def add_error_entry(msg: str, with_traceback: bool = False):
660697
ValidationDetail(
661698
name=test_name,
662699
loc=("weights", weight_format),
663-
status="failed" if errors else "passed",
700+
status="failed" if error_entries else "passed",
664701
recommended_env=get_conda_env(entry=dict(model.weights)[weight_format]),
665-
errors=errors,
702+
errors=error_entries,
703+
warnings=warning_entries,
666704
)
667705
)
668706

bioimageio/core/tensor.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,15 @@ def __array__(self, dtype: DTypeLike = None):
6666
return np.asarray(self._data, dtype=dtype)
6767

6868
def __getitem__(
69-
self, key: Union[SliceInfo, slice, int, PerAxis[Union[SliceInfo, slice, int]]]
69+
self,
70+
key: Union[
71+
SliceInfo,
72+
slice,
73+
int,
74+
PerAxis[Union[SliceInfo, slice, int]],
75+
Tensor,
76+
xr.DataArray,
77+
],
7078
) -> Self:
7179
if isinstance(key, SliceInfo):
7280
key = slice(*key)
@@ -75,11 +83,27 @@ def __getitem__(
7583
a: s if isinstance(s, int) else s if isinstance(s, slice) else slice(*s)
7684
for a, s in key.items()
7785
}
86+
elif isinstance(key, Tensor):
87+
key = key._data
88+
7889
return self.__class__.from_xarray(self._data[key])
7990

80-
def __setitem__(self, key: PerAxis[Union[SliceInfo, slice]], value: Tensor) -> None:
81-
key = {a: s if isinstance(s, slice) else slice(*s) for a, s in key.items()}
82-
self._data[key] = value._data
91+
def __setitem__(
92+
self,
93+
key: Union[PerAxis[Union[SliceInfo, slice]], Tensor, xr.DataArray],
94+
value: Union[Tensor, xr.DataArray, float, int],
95+
) -> None:
96+
if isinstance(key, Tensor):
97+
key = key._data
98+
elif isinstance(key, xr.DataArray):
99+
pass
100+
else:
101+
key = {a: s if isinstance(s, slice) else slice(*s) for a, s in key.items()}
102+
103+
if isinstance(value, Tensor):
104+
value = value._data
105+
106+
self._data[key] = value
83107

84108
def __len__(self) -> int:
85109
return len(self.data)

0 commit comments

Comments
 (0)