Skip to content

Commit b90b3d0

Browse files
Merge pull request #301 from jhennies/main
added model.eval() and debugged resource tests
2 parents 3cf3611 + e361482 commit b90b3d0

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

bioimageio/core/prediction_pipeline/_model_adapters/_pytorch_model_adapter.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def _load(self, *, devices: Optional[List[str]] = None):
3030
state = torch.load(weights.source, map_location=self._devices[0])
3131
self._model.load_state_dict(state)
3232

33+
self._model.eval()
3334
self._internal_output_axes = [tuple(out.axes) for out in self.bioimageio_model.outputs]
3435

3536
def _forward(self, *input_tensors: xr.DataArray) -> List[xr.DataArray]:

bioimageio/core/resource_tests.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import re
33
import traceback
44
import warnings
5+
from copy import deepcopy
56
from pathlib import Path
67
from typing import List, Optional, Tuple, Union
78

@@ -271,7 +272,7 @@ def debug_model(
271272
272273
Returns dict with tensors "inputs", "inputs_processed", "outputs_raw", "outputs", "expected" and "diff".
273274
"""
274-
inputs: Optional = None
275+
inputs_raw: Optional = None
275276
inputs_processed: Optional = None
276277
outputs_raw: Optional = None
277278
outputs: Optional = None
@@ -291,10 +292,20 @@ def debug_model(
291292
xr.DataArray(np.load(str(in_path)), dims=input_spec.axes)
292293
for in_path, input_spec in zip(model.test_inputs, model.inputs)
293294
]
295+
input_dict = {input_spec.name: input for input_spec, input in zip(model.inputs, inputs)}
294296

295-
inputs_processed, stats = prediction_pipeline.preprocess(*inputs)
297+
# keep track of the non-processed inputs
298+
inputs_raw = [deepcopy(input) for input in inputs]
299+
300+
computed_measures = {}
301+
302+
prediction_pipeline.apply_preprocessing(input_dict, computed_measures)
303+
inputs_processed = list(input_dict.values())
296304
outputs_raw = prediction_pipeline.predict(*inputs_processed)
297-
outputs, _ = prediction_pipeline.postprocess(*outputs_raw, input_sample_statistics=stats)
305+
output_dict = {output_spec.name: deepcopy(output) for output_spec, output in zip(model.outputs, outputs_raw)}
306+
prediction_pipeline.apply_postprocessing(output_dict, computed_measures)
307+
outputs = list(output_dict.values())
308+
298309
if isinstance(outputs, (np.ndarray, xr.DataArray)):
299310
outputs = [outputs]
300311

@@ -311,7 +322,7 @@ def debug_model(
311322
diff.append(res - exp)
312323

313324
return {
314-
"inputs": inputs,
325+
"inputs": inputs_raw,
315326
"inputs_processed": inputs_processed,
316327
"outputs_raw": outputs_raw,
317328
"outputs": outputs,

0 commit comments

Comments
 (0)