1- from collections .abc import Mapping , Sequence
1+ from collections .abc import Callable , Mapping , Sequence
22
33import numpy as np
44from scipy .stats import binom
55
6- from ...utils .dict_utils import dicts_to_arrays
6+ from ...utils .dict_utils import dicts_to_arrays , compute_test_quantities
77
88
99def calibration_log_gamma (
1010 estimates : Mapping [str , np .ndarray ] | np .ndarray ,
1111 targets : Mapping [str , np .ndarray ] | np .ndarray ,
1212 variable_keys : Sequence [str ] = None ,
1313 variable_names : Sequence [str ] = None ,
14+ test_quantities : dict [str , Callable ] = None ,
1415 num_null_draws : int = 1000 ,
1516 quantile : float = 0.05 ,
1617):
@@ -41,6 +42,18 @@ def calibration_log_gamma(
4142 By default, select all keys.
4243 variable_names : Sequence[str], optional (default = None)
4344 Optional variable names to show in the output.
45+ test_quantities : dict or None, optional, default: None
46+ A dict that maps plot titles to functions that compute
47+ test quantities based on estimate/target draws.
48+
49+ The dict keys are automatically added to ``variable_keys``
50+ and ``variable_names``.
51+ Test quantity functions are expected to accept a dict of draws with
52+ shape ``(batch_size, ...)`` as the first (typically only)
53+ positional argument and return an NumPy array of shape
54+ ``(batch_size,)``.
55+ The functions do not have to deal with an additional
56+ sample dimension, as appropriate reshaping is done internally.
4457 quantile : float in (0, 1), optional, default 0.05
4558 The quantile from the null distribution to be used as a threshold.
4659 A lower quantile increases sensitivity to deviations from uniformity.
@@ -57,6 +70,21 @@ def calibration_log_gamma(
5770 - "variable_names" : str
5871 The (inferred) variable names.
5972 """
73+
74+ # Optionally, compute and prepend test quantities from draws
75+ if test_quantities is not None :
76+ updated_data = compute_test_quantities (
77+ targets = targets ,
78+ estimates = estimates ,
79+ variable_keys = variable_keys ,
80+ variable_names = variable_names ,
81+ test_quantities = test_quantities ,
82+ )
83+ variable_names = updated_data ["variable_names" ]
84+ variable_keys = updated_data ["variable_keys" ]
85+ estimates = updated_data ["estimates" ]
86+ targets = updated_data ["targets" ]
87+
6088 samples = dicts_to_arrays (
6189 estimates = estimates ,
6290 targets = targets ,
0 commit comments