Skip to content

Commit e093619

Browse files
Add debug_model function
1 parent 7b52092 commit e093619

File tree

2 files changed

+63
-6
lines changed

2 files changed

+63
-6
lines changed

bioimageio/core/prediction_pipeline/_prediction_pipeline.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import abc
22
import math
33
from dataclasses import dataclass
4-
from typing import List, Optional, Tuple
4+
from typing import List, Optional, Tuple, Dict, Any
55

66
import xarray as xr
77
from marshmallow import missing
88

99
from bioimageio.core.resource_io import nodes
10+
from bioimageio.core.statistical_measures import Measure
1011
from ._combined_processing import CombinedProcessing
1112
from ._model_adapters import ModelAdapter, create_model_adapter
1213
from ..resource_io.nodes import InputTensor, Model, OutputTensor
@@ -94,13 +95,15 @@ def forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]:
9495
prediction = self.predict(*preprocessed)
9596
return self._processing.apply_postprocessing(*prediction, input_sample_statistics=sample_stats)[0]
9697

97-
def preprocess(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]:
98+
def preprocess(self, *input_tensors: xr.DataArray) -> Tuple[List[xr.DataArray], Dict[str, Dict[Measure, Any]]]:
9899
"""Apply preprocessing."""
99-
return self._processing.apply_preprocessing(*input_tensors)[0]
100+
return self._processing.apply_preprocessing(*input_tensors)
100101

101-
def postprocess(self, *input_tensors: xr.DataArray, input_sample_statistics) -> List[xr.DataArray]:
102+
def postprocess(
103+
self, *input_tensors: xr.DataArray, input_sample_statistics
104+
) -> Tuple[List[xr.DataArray], Dict[str, Dict[Measure, Any]]]:
102105
"""Apply postprocessing."""
103-
return self._processing.apply_postprocessing(*input_tensors, input_sample_statistics=input_sample_statistics)[0]
106+
return self._processing.apply_postprocessing(*input_tensors, input_sample_statistics=input_sample_statistics)
104107

105108
def __call__(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]:
106109
return self.forward(*input_tensors)

bioimageio/core/resource_tests.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import traceback
2-
import warnings
32
from pathlib import Path
43
from typing import List, Optional, Union
54

@@ -79,3 +78,58 @@ def test_resource(
7978
# todo: add tests for non-model resources
8079

8180
return {"error": error, "traceback": tb}
81+
82+
83+
def debug_model(
84+
model_rdf: Union[RawResourceDescription, ResourceDescription, URI, Path, str],
85+
*,
86+
weight_format: Optional[WeightsFormat] = None,
87+
devices: Optional[List[str]] = None,
88+
):
89+
"""Run the model test and return dict with inputs, results, expected results and intermediates.
90+
91+
Returns dict with tensors "inputs", "inputs_processed", "outputs_raw", "outputs", "expected" and "diff".
92+
"""
93+
inputs: Optional = None
94+
inputs_processed: Optional = None
95+
outputs_raw: Optional = None
96+
outputs: Optional = None
97+
expected: Optional = None
98+
diff: Optional = None
99+
100+
model = load_resource_description(model_rdf)
101+
if not isinstance(model, Model):
102+
raise ValueError(f"Not a bioimageio.model: {model_rdf}")
103+
104+
prediction_pipeline = create_prediction_pipeline(
105+
bioimageio_model=model, devices=devices, weight_format=weight_format
106+
)
107+
inputs = [xr.DataArray(np.load(str(in_path)), dims=input_spec.axes)
108+
for in_path, input_spec in zip(model.test_inputs, model.inputs)]
109+
110+
inputs_processed, stats = prediction_pipeline.preprocess(*inputs)
111+
outputs_raw = prediction_pipeline.predict(*inputs_processed)
112+
outputs, _ = prediction_pipeline.postprocess(*outputs_raw, input_sample_statistics=stats)
113+
if isinstance(outputs, (np.ndarray, xr.DataArray)):
114+
outputs = [outputs]
115+
116+
expected = [xr.DataArray(np.load(str(out_path)), dims=output_spec.axes)
117+
for out_path, output_spec in zip(model.test_outputs, model.outputs)]
118+
if len(outputs) != len(expected):
119+
error = (
120+
f"Number of outputs and number of expected outputs disagree: {len(outputs)} != {len(expected)}"
121+
)
122+
print(error)
123+
else:
124+
diff = []
125+
for res, exp in zip(outputs, expected):
126+
diff.append(res - exp)
127+
128+
return {
129+
"inputs": inputs,
130+
"inputs_processed": inputs_processed,
131+
"outputs_raw": outputs_raw,
132+
"outputs": outputs,
133+
"expected": expected,
134+
"diff": diff
135+
}

0 commit comments

Comments
 (0)