diff --git a/bayesflow/diagnostics/plots/calibration_ecdf.py b/bayesflow/diagnostics/plots/calibration_ecdf.py index d39d225a8..915687818 100644 --- a/bayesflow/diagnostics/plots/calibration_ecdf.py +++ b/bayesflow/diagnostics/plots/calibration_ecdf.py @@ -1,6 +1,7 @@ -from collections.abc import Mapping, Sequence +from collections.abc import Callable, Mapping, Sequence import numpy as np +import keras import matplotlib.pyplot as plt from ...utils.plot_utils import prepare_plot_data, add_titles_and_labels, prettify_subplots @@ -13,6 +14,7 @@ def calibration_ecdf( targets: Mapping[str, np.ndarray] | np.ndarray, variable_keys: Sequence[str] = None, variable_names: Sequence[str] = None, + test_quantities: dict[str, Callable] = None, difference: bool = False, stacked: bool = False, rank_type: str | np.ndarray = "fractional", @@ -78,6 +80,18 @@ def calibration_ecdf( variable_names : list or None, optional, default: None The parameter names for nice plot titles. Inferred if None. Only relevant if `stacked=False`. + test_quantities : dict or None, optional, default: None + A dict that maps plot titles to functions that compute + test quantities based on estimate/target draws. + + The dict keys are automatically added to ``variable_keys`` + and ``variable_names``. + Test quantity functions are expected to accept a dict of draws with + shape ``(batch_size, ...)`` as the first (typically only) + positional argument and return an NumPy array of shape + ``(batch_size,)``. + The functions do not have to deal with an additional + sample dimension, as appropriate reshaping is done internally. figsize : tuple or None, optional, default: None The figure size passed to the matplotlib constructor. Inferred if None. @@ -120,6 +134,36 @@ def calibration_ecdf( If an unknown `rank_type` is passed. """ + # Optionally, compute and prepend test quantities from draws + if test_quantities is not None: + test_quantities_estimates = {} + test_quantities_targets = {} + + for key, test_quantity_fn in test_quantities.items(): + # Apply test_quantity_func to ground-truths + tq_targets = test_quantity_fn(data=targets) + test_quantities_targets[key] = np.expand_dims(tq_targets, axis=1) + + # # Flatten estimates for batch processing in test_quantity_fn, apply function, and restore shape + num_conditions, num_samples = next(iter(estimates.values())).shape[:2] + flattened_estimates = keras.tree.map_structure(lambda t: np.reshape(t, (-1, *t.shape[2:])), estimates) + flat_tq_estimates = test_quantity_fn(data=flattened_estimates) + test_quantities_estimates[key] = np.reshape(flat_tq_estimates, (num_conditions, num_samples, 1)) + + # Add custom test quantities to variable keys and names for plotting + # keys and names are set to the test_quantities dict keys + test_quantities_names = list(test_quantities.keys()) + + if variable_keys is None: + variable_keys = list(estimates.keys()) + + if isinstance(variable_names, list): + variable_names = test_quantities_names + variable_names + + variable_keys = test_quantities_names + variable_keys + estimates = test_quantities_estimates | estimates + targets = test_quantities_targets | targets + plot_data = prepare_plot_data( estimates=estimates, targets=targets, diff --git a/bayesflow/utils/dict_utils.py b/bayesflow/utils/dict_utils.py index 09fa4b853..5d085f0ec 100644 --- a/bayesflow/utils/dict_utils.py +++ b/bayesflow/utils/dict_utils.py @@ -282,6 +282,10 @@ def dicts_to_arrays( Ground-truth values corresponding to the estimates. Must match the structure and dimensionality of `estimates` in terms of first and last axis. + priors : dict[str, ndarray] or ndarray, optional (default = None) + Prior draws. Must match the structure and dimensionality + of `estimates` in terms of first and last axis. + dataset_ids : Sequence of integers indexing the datasets to select (default = None). By default, use all datasets. diff --git a/bayesflow/utils/plot_utils.py b/bayesflow/utils/plot_utils.py index aaaca44fe..8de2e4b77 100644 --- a/bayesflow/utils/plot_utils.py +++ b/bayesflow/utils/plot_utils.py @@ -23,7 +23,7 @@ def prepare_plot_data( figsize: tuple = None, stacked: bool = False, default_name: str = "v", -) -> Mapping[str, Any]: +) -> dict[str, Any]: """ Procedural wrapper that encompasses all preprocessing steps, including shape-checking, parameter name generation, layout configuration, figure initialization, and collapsing of axes. @@ -56,6 +56,12 @@ def prepare_plot_data( Whether the plots are stacked horizontally default_name : str, optional (default = "v") The default name to use for estimates if None provided + + Returns + ------- + plot_data : dict[str, Any] + A dictionary containing all preprocessed data and plotting objects required for visualization, + including estimates, targets, variable names, figure, axes, and layout configuration. """ plot_data = dicts_to_arrays( diff --git a/tests/test_diagnostics/test_diagnostics_plots.py b/tests/test_diagnostics/test_diagnostics_plots.py index f9f29492f..8d4b7883b 100644 --- a/tests/test_diagnostics/test_diagnostics_plots.py +++ b/tests/test_diagnostics/test_diagnostics_plots.py @@ -1,4 +1,5 @@ import bayesflow as bf +import numpy as np import pytest @@ -16,6 +17,8 @@ def test_backend(): def test_calibration_ecdf(random_estimates, random_targets, var_names): + print(random_estimates, random_targets, var_names) + # basic functionality: automatic variable names out = bf.diagnostics.plots.calibration_ecdf(random_estimates, random_targets) assert len(out.axes) == num_variables(random_estimates) @@ -46,6 +49,22 @@ def test_calibration_ecdf(random_estimates, random_targets, var_names): # cannot infer the variable names from an array so default names are used assert out.axes[1].title._text == "v_1" + # test quantities plots are shown + test_quantities = { + r"$\beta_1 + \beta_2$": lambda data: np.sum(data["beta"], axis=-1), + r"$\beta_1 \cdot \beta_2$": lambda data: np.prod(data["beta"], axis=-1), + } + out = bf.diagnostics.plots.calibration_ecdf(random_estimates, random_targets, test_quantities=test_quantities) + assert len(out.axes) == len(test_quantities) + num_variables(random_estimates) + assert out.axes[1].title._text == r"$\beta_1 \cdot \beta_2$" + assert out.axes[-1].title._text == r"sigma" + + # test plot titles changed to variable_names in case test quantities exist + out = bf.diagnostics.plots.calibration_ecdf( + random_estimates, random_targets, test_quantities=test_quantities, variable_names=var_names + ) + assert out.axes[-1].title._text == r"$\sigma$" + def test_calibration_histogram(random_estimates, random_targets): # basic functionality: automatic variable names