22import re
33import traceback
44import warnings
5+ from copy import deepcopy
56from pathlib import Path
67from typing import List , Optional , Tuple , Union
78
@@ -271,7 +272,7 @@ def debug_model(
271272
272273 Returns dict with tensors "inputs", "inputs_processed", "outputs_raw", "outputs", "expected" and "diff".
273274 """
274- inputs : Optional = None
275+ inputs_raw : Optional = None
275276 inputs_processed : Optional = None
276277 outputs_raw : Optional = None
277278 outputs : Optional = None
@@ -291,10 +292,20 @@ def debug_model(
291292 xr .DataArray (np .load (str (in_path )), dims = input_spec .axes )
292293 for in_path , input_spec in zip (model .test_inputs , model .inputs )
293294 ]
295+ input_dict = {input_spec .name : input for input_spec , input in zip (model .inputs , inputs )}
294296
295- inputs_processed , stats = prediction_pipeline .preprocess (* inputs )
297+ # keep track of the non-processed inputs
298+ inputs_raw = [deepcopy (input ) for input in inputs ]
299+
300+ computed_measures = {}
301+
302+ prediction_pipeline .apply_preprocessing (input_dict , computed_measures )
303+ inputs_processed = list (input_dict .values ())
296304 outputs_raw = prediction_pipeline .predict (* inputs_processed )
297- outputs , _ = prediction_pipeline .postprocess (* outputs_raw , input_sample_statistics = stats )
305+ output_dict = {output_spec .name : deepcopy (output ) for output_spec , output in zip (model .outputs , outputs_raw )}
306+ prediction_pipeline .apply_postprocessing (output_dict , computed_measures )
307+ outputs = list (output_dict .values ())
308+
298309 if isinstance (outputs , (np .ndarray , xr .DataArray )):
299310 outputs = [outputs ]
300311
@@ -311,7 +322,7 @@ def debug_model(
311322 diff .append (res - exp )
312323
313324 return {
314- "inputs" : inputs ,
325+ "inputs" : inputs_raw ,
315326 "inputs_processed" : inputs_processed ,
316327 "outputs_raw" : outputs_raw ,
317328 "outputs" : outputs ,
0 commit comments