Skip to content

Commit 0142135

Browse files
committed
catch any errors in error message creation
1 parent c629bd5 commit 0142135

File tree

1 file changed

+63
-46
lines changed

1 file changed

+63
-46
lines changed

src/bioimageio/core/_resource_tests.py

Lines changed: 63 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
from typing_extensions import NotRequired, TypedDict, Unpack, assert_never, get_args
6767

6868
from bioimageio.core import __version__
69+
from bioimageio.core.io import save_tensor
6970

7071
from ._prediction_pipeline import create_prediction_pipeline
7172
from .axis import AxisId, BatchSize
@@ -789,58 +790,74 @@ def add_warning_entry(msg: str):
789790
else:
790791
continue
791792

792-
expected_np = expected.data.to_numpy().astype(np.float32)
793-
del expected
794-
actual_np: NDArray[Any] = actual.data.to_numpy().astype(np.float32)
795-
del actual
793+
try:
794+
expected_np = expected.data.to_numpy().astype(np.float32)
795+
del expected
796+
actual_np: NDArray[Any] = actual.data.to_numpy().astype(np.float32)
796797

797-
rtol, atol, mismatched_tol = _get_tolerance(
798-
model, wf=weight_format, m=m, **deprecated
799-
)
800-
rtol_value = rtol * abs(expected_np)
801-
abs_diff = abs(actual_np - expected_np)
802-
mismatched = abs_diff > atol + rtol_value
803-
mismatched_elements = mismatched.sum().item()
804-
if not mismatched_elements:
805-
continue
806-
807-
mismatched_ppm = mismatched_elements / expected_np.size * 1e6
808-
abs_diff[~mismatched] = 0 # ignore non-mismatched elements
809-
810-
r_max_idx_flat = (
811-
r_diff := (abs_diff / (abs(expected_np) + 1e-6))
812-
).argmax()
813-
r_max_idx = np.unravel_index(r_max_idx_flat, r_diff.shape)
814-
r_max = r_diff[r_max_idx].item()
815-
r_actual = actual_np[r_max_idx].item()
816-
r_expected = expected_np[r_max_idx].item()
817-
818-
# Calculate the max absolute difference with the relative tolerance subtracted
819-
abs_diff_wo_rtol: NDArray[np.float32] = abs_diff - rtol_value
820-
a_max_idx = np.unravel_index(
821-
abs_diff_wo_rtol.argmax(), abs_diff_wo_rtol.shape
822-
)
798+
rtol, atol, mismatched_tol = _get_tolerance(
799+
model, wf=weight_format, m=m, **deprecated
800+
)
801+
rtol_value = rtol * abs(expected_np)
802+
abs_diff = abs(actual_np - expected_np)
803+
mismatched = abs_diff > atol + rtol_value
804+
mismatched_elements = mismatched.sum().item()
805+
if not mismatched_elements:
806+
continue
823807

824-
a_max = abs_diff[a_max_idx].item()
825-
a_actual = actual_np[a_max_idx].item()
826-
a_expected = expected_np[a_max_idx].item()
827-
828-
msg = (
829-
f"Output '{m}' disagrees with {mismatched_elements} of"
830-
+ f" {expected_np.size} expected values"
831-
+ f" ({mismatched_ppm:.1f} ppm)."
832-
+ f"\n Max relative difference: {r_max:.2e}"
833-
+ rf" (= \|{r_actual:.2e} - {r_expected:.2e}\|/\|{r_expected:.2e} + 1e-6\|)"
834-
+ f" at {dict(zip(dims, r_max_idx))}"
835-
+ f"\n Max absolute difference not accounted for by relative tolerance: {a_max:.2e}"
836-
+ rf" (= \|{a_actual:.7e} - {a_expected:.7e}\|) at {dict(zip(dims, a_max_idx))}"
837-
)
838-
if mismatched_ppm > mismatched_tol:
808+
actual_output_path = Path(f"actual_output_{m}_{weight_format}.npy")
809+
try:
810+
save_tensor(actual_output_path, actual)
811+
except Exception as e:
812+
logger.error(
813+
"Failed to save actual output tensor to {}: {}",
814+
actual_output_path,
815+
e,
816+
)
817+
818+
mismatched_ppm = mismatched_elements / expected_np.size * 1e6
819+
abs_diff[~mismatched] = 0 # ignore non-mismatched elements
820+
821+
r_max_idx_flat = (
822+
r_diff := (abs_diff / (abs(expected_np) + 1e-6))
823+
).argmax()
824+
r_max_idx = np.unravel_index(r_max_idx_flat, r_diff.shape)
825+
r_max = r_diff[r_max_idx].item()
826+
r_actual = actual_np[r_max_idx].item()
827+
r_expected = expected_np[r_max_idx].item()
828+
829+
# Calculate the max absolute difference with the relative tolerance subtracted
830+
abs_diff_wo_rtol: NDArray[np.float32] = abs_diff - rtol_value
831+
a_max_idx = np.unravel_index(
832+
abs_diff_wo_rtol.argmax(), abs_diff_wo_rtol.shape
833+
)
834+
835+
a_max = abs_diff[a_max_idx].item()
836+
a_actual = actual_np[a_max_idx].item()
837+
a_expected = expected_np[a_max_idx].item()
838+
except Exception as e:
839+
msg = f"Output '{m}' disagrees with expected values."
839840
add_error_entry(msg)
840841
if stop_early:
841842
break
842843
else:
843-
add_warning_entry(msg)
844+
msg = (
845+
f"Output '{m}' disagrees with {mismatched_elements} of"
846+
+ f" {expected_np.size} expected values"
847+
+ f" ({mismatched_ppm:.1f} ppm)."
848+
+ f"\n Max relative difference: {r_max:.2e}"
849+
+ rf" (= \|{r_actual:.2e} - {r_expected:.2e}\|/\|{r_expected:.2e} + 1e-6\|)"
850+
+ f" at {dict(zip(dims, r_max_idx))}"
851+
+ f"\n Max absolute difference not accounted for by relative tolerance: {a_max:.2e}"
852+
+ rf" (= \|{a_actual:.7e} - {a_expected:.7e}\|) at {dict(zip(dims, a_max_idx))}"
853+
+ f"\n Saved actual output to {actual_output_path}."
854+
)
855+
if mismatched_ppm > mismatched_tol:
856+
add_error_entry(msg)
857+
if stop_early:
858+
break
859+
else:
860+
add_warning_entry(msg)
844861

845862
except Exception as e:
846863
if get_validation_context().raise_errors:

0 commit comments

Comments
 (0)