1- from collections .abc import Mapping , Sequence
1+ from collections .abc import Callable , Mapping , Sequence
22
33import numpy as np
4+ import keras
45import matplotlib .pyplot as plt
56
67from ...utils .plot_utils import prepare_plot_data , add_titles_and_labels , prettify_subplots
@@ -13,6 +14,7 @@ def calibration_ecdf(
1314 targets : Mapping [str , np .ndarray ] | np .ndarray ,
1415 variable_keys : Sequence [str ] = None ,
1516 variable_names : Sequence [str ] = None ,
17+ test_quantities : dict [str , Callable ] = None ,
1618 difference : bool = False ,
1719 stacked : bool = False ,
1820 rank_type : str | np .ndarray = "fractional" ,
@@ -78,6 +80,18 @@ def calibration_ecdf(
7880 variable_names : list or None, optional, default: None
7981 The parameter names for nice plot titles.
8082 Inferred if None. Only relevant if `stacked=False`.
83+ test_quantities : dict or None, optional, default: None
84+ A dict that maps plot titles to functions that compute
85+ test quantities based on estimate/target draws.
86+
87+ The dict keys are automatically added to ``variable_keys``
88+ and ``variable_names``.
89+ Test quantity functions are expected to accept a dict of draws with
90+ shape ``(batch_size, ...)`` as the first (typically only)
91+ positional argument and return an NumPy array of shape
92+ ``(batch_size,)``.
93+ The functions do not have to deal with an additional
94+ sample dimension, as appropriate reshaping is done internally.
8195 figsize : tuple or None, optional, default: None
8296 The figure size passed to the matplotlib constructor.
8397 Inferred if None.
@@ -120,6 +134,32 @@ def calibration_ecdf(
120134 If an unknown `rank_type` is passed.
121135 """
122136
137+ # Optionally, compute and prepend test quantities from draws
138+ if test_quantities is not None :
139+ # Prepare empty mapping to hold test quantity values
140+ test_quantities_estimates = {}
141+ test_quantities_targets = {}
142+
143+ for key , test_quantity_func in test_quantities .items ():
144+ # Apply test_quantity_func to draws
145+ tq_targets = test_quantity_func (data = targets )
146+ test_quantities_targets [key ] = np .expand_dims (tq_targets , axis = 1 )
147+
148+ # We assume test_quantity_func can only handle a 1D batch_size, so estimates
149+ # which have shape (num_conditions, num_post_samples, ...) must be flattend first.
150+ num_conditions , num_post_samples = next (iter (estimates .values ())).shape [:2 ]
151+ flattened_estimates = keras .tree .map_structure (lambda t : np .reshape (t , (- 1 , * t .shape [2 :])), estimates )
152+ flat_tq_estimates = test_quantity_func (data = flattened_estimates )
153+ test_quantities_estimates [key ] = np .reshape (flat_tq_estimates , (num_conditions , num_post_samples , 1 ))
154+
155+ # Add custom test quantities to variable keys and names for plotting
156+ variable_keys = list (test_quantities .keys ()) + variable_keys
157+ variable_names = list (test_quantities .keys ()) + variable_names
158+
159+ # Prepend test quantities to draws
160+ estimates = test_quantities_estimates | estimates
161+ targets = test_quantities_targets | targets
162+
123163 plot_data = prepare_plot_data (
124164 estimates = estimates ,
125165 targets = targets ,
0 commit comments