|
66 | 66 | from typing_extensions import NotRequired, TypedDict, Unpack, assert_never, get_args |
67 | 67 |
|
68 | 68 | from bioimageio.core import __version__ |
| 69 | +from bioimageio.core.io import save_tensor |
69 | 70 |
|
70 | 71 | from ._prediction_pipeline import create_prediction_pipeline |
71 | 72 | from .axis import AxisId, BatchSize |
@@ -789,58 +790,74 @@ def add_warning_entry(msg: str): |
789 | 790 | else: |
790 | 791 | continue |
791 | 792 |
|
792 | | - expected_np = expected.data.to_numpy().astype(np.float32) |
793 | | - del expected |
794 | | - actual_np: NDArray[Any] = actual.data.to_numpy().astype(np.float32) |
795 | | - del actual |
| 793 | + try: |
| 794 | + expected_np = expected.data.to_numpy().astype(np.float32) |
| 795 | + del expected |
| 796 | + actual_np: NDArray[Any] = actual.data.to_numpy().astype(np.float32) |
796 | 797 |
|
797 | | - rtol, atol, mismatched_tol = _get_tolerance( |
798 | | - model, wf=weight_format, m=m, **deprecated |
799 | | - ) |
800 | | - rtol_value = rtol * abs(expected_np) |
801 | | - abs_diff = abs(actual_np - expected_np) |
802 | | - mismatched = abs_diff > atol + rtol_value |
803 | | - mismatched_elements = mismatched.sum().item() |
804 | | - if not mismatched_elements: |
805 | | - continue |
806 | | - |
807 | | - mismatched_ppm = mismatched_elements / expected_np.size * 1e6 |
808 | | - abs_diff[~mismatched] = 0 # ignore non-mismatched elements |
809 | | - |
810 | | - r_max_idx_flat = ( |
811 | | - r_diff := (abs_diff / (abs(expected_np) + 1e-6)) |
812 | | - ).argmax() |
813 | | - r_max_idx = np.unravel_index(r_max_idx_flat, r_diff.shape) |
814 | | - r_max = r_diff[r_max_idx].item() |
815 | | - r_actual = actual_np[r_max_idx].item() |
816 | | - r_expected = expected_np[r_max_idx].item() |
817 | | - |
818 | | - # Calculate the max absolute difference with the relative tolerance subtracted |
819 | | - abs_diff_wo_rtol: NDArray[np.float32] = abs_diff - rtol_value |
820 | | - a_max_idx = np.unravel_index( |
821 | | - abs_diff_wo_rtol.argmax(), abs_diff_wo_rtol.shape |
822 | | - ) |
| 798 | + rtol, atol, mismatched_tol = _get_tolerance( |
| 799 | + model, wf=weight_format, m=m, **deprecated |
| 800 | + ) |
| 801 | + rtol_value = rtol * abs(expected_np) |
| 802 | + abs_diff = abs(actual_np - expected_np) |
| 803 | + mismatched = abs_diff > atol + rtol_value |
| 804 | + mismatched_elements = mismatched.sum().item() |
| 805 | + if not mismatched_elements: |
| 806 | + continue |
823 | 807 |
|
824 | | - a_max = abs_diff[a_max_idx].item() |
825 | | - a_actual = actual_np[a_max_idx].item() |
826 | | - a_expected = expected_np[a_max_idx].item() |
827 | | - |
828 | | - msg = ( |
829 | | - f"Output '{m}' disagrees with {mismatched_elements} of" |
830 | | - + f" {expected_np.size} expected values" |
831 | | - + f" ({mismatched_ppm:.1f} ppm)." |
832 | | - + f"\n Max relative difference: {r_max:.2e}" |
833 | | - + rf" (= \|{r_actual:.2e} - {r_expected:.2e}\|/\|{r_expected:.2e} + 1e-6\|)" |
834 | | - + f" at {dict(zip(dims, r_max_idx))}" |
835 | | - + f"\n Max absolute difference not accounted for by relative tolerance: {a_max:.2e}" |
836 | | - + rf" (= \|{a_actual:.7e} - {a_expected:.7e}\|) at {dict(zip(dims, a_max_idx))}" |
837 | | - ) |
838 | | - if mismatched_ppm > mismatched_tol: |
| 808 | + actual_output_path = Path(f"actual_output_{m}_{weight_format}.npy") |
| 809 | + try: |
| 810 | + save_tensor(actual_output_path, actual) |
| 811 | + except Exception as e: |
| 812 | + logger.error( |
| 813 | + "Failed to save actual output tensor to {}: {}", |
| 814 | + actual_output_path, |
| 815 | + e, |
| 816 | + ) |
| 817 | + |
| 818 | + mismatched_ppm = mismatched_elements / expected_np.size * 1e6 |
| 819 | + abs_diff[~mismatched] = 0 # ignore non-mismatched elements |
| 820 | + |
| 821 | + r_max_idx_flat = ( |
| 822 | + r_diff := (abs_diff / (abs(expected_np) + 1e-6)) |
| 823 | + ).argmax() |
| 824 | + r_max_idx = np.unravel_index(r_max_idx_flat, r_diff.shape) |
| 825 | + r_max = r_diff[r_max_idx].item() |
| 826 | + r_actual = actual_np[r_max_idx].item() |
| 827 | + r_expected = expected_np[r_max_idx].item() |
| 828 | + |
| 829 | + # Calculate the max absolute difference with the relative tolerance subtracted |
| 830 | + abs_diff_wo_rtol: NDArray[np.float32] = abs_diff - rtol_value |
| 831 | + a_max_idx = np.unravel_index( |
| 832 | + abs_diff_wo_rtol.argmax(), abs_diff_wo_rtol.shape |
| 833 | + ) |
| 834 | + |
| 835 | + a_max = abs_diff[a_max_idx].item() |
| 836 | + a_actual = actual_np[a_max_idx].item() |
| 837 | + a_expected = expected_np[a_max_idx].item() |
| 838 | + except Exception as e: |
| 839 | + msg = f"Output '{m}' disagrees with expected values." |
839 | 840 | add_error_entry(msg) |
840 | 841 | if stop_early: |
841 | 842 | break |
842 | 843 | else: |
843 | | - add_warning_entry(msg) |
| 844 | + msg = ( |
| 845 | + f"Output '{m}' disagrees with {mismatched_elements} of" |
| 846 | + + f" {expected_np.size} expected values" |
| 847 | + + f" ({mismatched_ppm:.1f} ppm)." |
| 848 | + + f"\n Max relative difference: {r_max:.2e}" |
| 849 | + + rf" (= \|{r_actual:.2e} - {r_expected:.2e}\|/\|{r_expected:.2e} + 1e-6\|)" |
| 850 | + + f" at {dict(zip(dims, r_max_idx))}" |
| 851 | + + f"\n Max absolute difference not accounted for by relative tolerance: {a_max:.2e}" |
| 852 | + + rf" (= \|{a_actual:.7e} - {a_expected:.7e}\|) at {dict(zip(dims, a_max_idx))}" |
| 853 | + + f"\n Saved actual output to {actual_output_path}." |
| 854 | + ) |
| 855 | + if mismatched_ppm > mismatched_tol: |
| 856 | + add_error_entry(msg) |
| 857 | + if stop_early: |
| 858 | + break |
| 859 | + else: |
| 860 | + add_warning_entry(msg) |
844 | 861 |
|
845 | 862 | except Exception as e: |
846 | 863 | if get_validation_context().raise_errors: |
|
0 commit comments