@@ -136,40 +136,31 @@ def calibration_ecdf(
136136
137137 # Optionally, compute and prepend test quantities from draws
138138 if test_quantities is not None :
139- # Prepare empty mapping to hold test quantity values
140139 test_quantities_estimates = {}
141140 test_quantities_targets = {}
142141
143- for key , test_quantity_func in test_quantities .items ():
144- # Apply test_quantity_func to draws
145- tq_targets = test_quantity_func (data = targets )
142+ for key , test_quantity_fn in test_quantities .items ():
143+ # Apply test_quantity_func to ground-truths
144+ tq_targets = test_quantity_fn (data = targets )
146145 test_quantities_targets [key ] = np .expand_dims (tq_targets , axis = 1 )
147146
148- # We assume test_quantity_func can only handle a 1D batch_size, so estimates
149- # which have shape (num_conditions, num_samples, ...) must be flattend first.
147+ # # Flatten estimates for batch processing in test_quantity_fn, apply function, and restore shape
150148 num_conditions , num_samples = next (iter (estimates .values ())).shape [:2 ]
151149 flattened_estimates = keras .tree .map_structure (lambda t : np .reshape (t , (- 1 , * t .shape [2 :])), estimates )
152- flat_tq_estimates = test_quantity_func (data = flattened_estimates )
150+ flat_tq_estimates = test_quantity_fn (data = flattened_estimates )
153151 test_quantities_estimates [key ] = np .reshape (flat_tq_estimates , (num_conditions , num_samples , 1 ))
154152
155153 # Add custom test quantities to variable keys and names for plotting
156154 # keys and names are set to the test_quantities dict keys
157155 test_quantities_names = list (test_quantities .keys ())
158156
159- # By default all keys in estimates are selected
160157 if variable_keys is None :
161158 variable_keys = list (estimates .keys ())
162159
163- # If variable_names are present, concatenate them to the test_quantities_names
164160 if isinstance (variable_names , list ):
165161 variable_names = test_quantities_names + variable_names
166- # If variable_names are None, they will stay None here and are subsequently inferred.
167162
168- # After the defaults are handled, we simply concatenate the keys
169- # for test quantities and regular variables.
170163 variable_keys = test_quantities_names + variable_keys
171-
172- # Prepend test quantities to draws
173164 estimates = test_quantities_estimates | estimates
174165 targets = test_quantities_targets | targets
175166
0 commit comments