|
1 | | -from collections.abc import Sequence, Mapping |
| 1 | +from collections.abc import Callable, Mapping, Sequence |
2 | 2 |
|
3 | 3 | import matplotlib.pyplot as plt |
4 | 4 | import numpy as np |
|
8 | 8 |
|
9 | 9 | from bayesflow.utils import logging |
10 | 10 | from bayesflow.utils import prepare_plot_data, add_titles_and_labels, prettify_subplots |
| 11 | +from bayesflow.utils.dict_utils import compute_test_quantities |
11 | 12 |
|
12 | 13 |
|
13 | 14 | def calibration_histogram( |
14 | 15 | estimates: Mapping[str, np.ndarray] | np.ndarray, |
15 | 16 | targets: Mapping[str, np.ndarray] | np.ndarray, |
16 | 17 | variable_keys: Sequence[str] = None, |
17 | 18 | variable_names: Sequence[str] = None, |
| 19 | + test_quantities: dict[str, Callable] = None, |
18 | 20 | figsize: Sequence[float] = None, |
19 | 21 | num_bins: int = 10, |
20 | 22 | binomial_interval: float = 0.99, |
@@ -46,6 +48,18 @@ def calibration_histogram( |
46 | 48 | By default, select all keys. |
47 | 49 | variable_names : list or None, optional, default: None |
48 | 50 | The parameter names for nice plot titles. Inferred if None |
| 51 | + test_quantities : dict or None, optional, default: None |
| 52 | + A dict that maps plot titles to functions that compute |
| 53 | + test quantities based on estimate/target draws. |
| 54 | +
|
| 55 | + The dict keys are automatically added to ``variable_keys`` |
| 56 | + and ``variable_names``. |
| 57 | + Test quantity functions are expected to accept a dict of draws with |
| 58 | + shape ``(batch_size, ...)`` as the first (typically only) |
| 59 | + positional argument and return an NumPy array of shape |
| 60 | + ``(batch_size,)``. |
| 61 | + The functions do not have to deal with an additional |
| 62 | + sample dimension, as appropriate reshaping is done internally. |
49 | 63 | figsize : tuple or None, optional, default : None |
50 | 64 | The figure size passed to the matplotlib constructor. Inferred if None |
51 | 65 | num_bins : int, optional, default: 10 |
@@ -75,6 +89,20 @@ def calibration_histogram( |
75 | 89 | If there is a deviation form the expected shapes of `estimates` and `targets`. |
76 | 90 | """ |
77 | 91 |
|
| 92 | + # Optionally, compute and prepend test quantities from draws |
| 93 | + if test_quantities is not None: |
| 94 | + updated_data = compute_test_quantities( |
| 95 | + targets=targets, |
| 96 | + estimates=estimates, |
| 97 | + variable_keys=variable_keys, |
| 98 | + variable_names=variable_names, |
| 99 | + test_quantities=test_quantities, |
| 100 | + ) |
| 101 | + variable_names = updated_data["variable_names"] |
| 102 | + variable_keys = updated_data["variable_keys"] |
| 103 | + estimates = updated_data["estimates"] |
| 104 | + targets = updated_data["targets"] |
| 105 | + |
78 | 106 | plot_data = prepare_plot_data( |
79 | 107 | estimates=estimates, |
80 | 108 | targets=targets, |
|
0 commit comments