@@ -72,8 +72,8 @@ def __init__(self, inputs: List[nodes.InputTensor], outputs: List[nodes.OutputTe
7272
7373 self ._req_input_stats = {s : self ._collect_required_stats (self ._prep , s ) for s in SCOPES }
7474 self ._req_output_stats = {s : self ._collect_required_stats (self ._post , s ) for s in SCOPES }
75- if any ( self ._req_output_stats [s ] for s in SCOPES ) :
76- raise NotImplementedError ("computing statistics for output tensors not yet implemented" )
75+ if self ._req_output_stats [DATASET ] :
76+ raise NotImplementedError ("computing statistics for output tensors per dataset is not yet implemented" )
7777
7878 self ._computed_dataset_stats : Optional [Dict [str , Dict [Measure , Any ]]] = None
7979
@@ -111,8 +111,10 @@ def apply_postprocessing(
111111 ) -> Tuple [List [xr .DataArray ], Dict [str , Dict [Measure , Any ]]]:
112112 assert len (output_tensors ) == len (self .output_tensor_names )
113113 tensors = dict (zip (self .output_tensor_names , output_tensors ))
114- sample_stats = input_sample_statistics
115- sample_stats .update (self .compute_sample_statistics (tensors , self ._req_output_stats [SAMPLE ]))
114+ sample_stats = {
115+ ** input_sample_statistics ,
116+ ** self .compute_sample_statistics (tensors , self ._req_output_stats [SAMPLE ]),
117+ }
116118 for proc in self ._post :
117119 proc .set_computed_sample_statistics (sample_stats )
118120 tensors [proc .tensor_name ] = proc .apply (tensors [proc .tensor_name ])
0 commit comments