Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 45 additions & 1 deletion bayesflow/diagnostics/plots/calibration_ecdf.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections.abc import Mapping, Sequence
from collections.abc import Callable, Mapping, Sequence

import numpy as np
import keras
import matplotlib.pyplot as plt

from ...utils.plot_utils import prepare_plot_data, add_titles_and_labels, prettify_subplots
Expand All @@ -13,6 +14,7 @@ def calibration_ecdf(
targets: Mapping[str, np.ndarray] | np.ndarray,
variable_keys: Sequence[str] = None,
variable_names: Sequence[str] = None,
test_quantities: dict[str, Callable] = None,
difference: bool = False,
stacked: bool = False,
rank_type: str | np.ndarray = "fractional",
Expand Down Expand Up @@ -78,6 +80,18 @@ def calibration_ecdf(
variable_names : list or None, optional, default: None
The parameter names for nice plot titles.
Inferred if None. Only relevant if `stacked=False`.
test_quantities : dict or None, optional, default: None
A dict that maps plot titles to functions that compute
test quantities based on estimate/target draws.

The dict keys are automatically added to ``variable_keys``
and ``variable_names``.
Test quantity functions are expected to accept a dict of draws with
shape ``(batch_size, ...)`` as the first (typically only)
positional argument and return a NumPy array of shape
``(batch_size,)``.
The functions do not have to deal with an additional
sample dimension, as appropriate reshaping is done internally.
figsize : tuple or None, optional, default: None
The figure size passed to the matplotlib constructor.
Inferred if None.
Expand Down Expand Up @@ -120,6 +134,36 @@ def calibration_ecdf(
If an unknown `rank_type` is passed.
"""

# Optionally, compute and prepend test quantities from draws
if test_quantities is not None:
test_quantities_estimates = {}
test_quantities_targets = {}

for key, test_quantity_fn in test_quantities.items():
# Apply test_quantity_fn to ground-truths
tq_targets = test_quantity_fn(data=targets)
test_quantities_targets[key] = np.expand_dims(tq_targets, axis=1)

# Flatten estimates for batch processing in test_quantity_fn, apply function, and restore shape
num_conditions, num_samples = next(iter(estimates.values())).shape[:2]
flattened_estimates = keras.tree.map_structure(lambda t: np.reshape(t, (-1, *t.shape[2:])), estimates)
flat_tq_estimates = test_quantity_fn(data=flattened_estimates)
test_quantities_estimates[key] = np.reshape(flat_tq_estimates, (num_conditions, num_samples, 1))

# Add custom test quantities to variable keys and names for plotting
# keys and names are set to the test_quantities dict keys
test_quantities_names = list(test_quantities.keys())

if variable_keys is None:
variable_keys = list(estimates.keys())

if isinstance(variable_names, list):
variable_names = test_quantities_names + variable_names

variable_keys = test_quantities_names + variable_keys
estimates = test_quantities_estimates | estimates
targets = test_quantities_targets | targets

plot_data = prepare_plot_data(
estimates=estimates,
targets=targets,
Expand Down
4 changes: 4 additions & 0 deletions bayesflow/utils/dict_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,10 @@ def dicts_to_arrays(
Ground-truth values corresponding to the estimates. Must match the structure and dimensionality
of `estimates` in terms of first and last axis.

priors : dict[str, ndarray] or ndarray, optional (default = None)
Prior draws. Must match the structure and dimensionality
of `estimates` in terms of first and last axis.

dataset_ids : Sequence of integers indexing the datasets to select (default = None).
By default, use all datasets.

Expand Down
8 changes: 7 additions & 1 deletion bayesflow/utils/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def prepare_plot_data(
figsize: tuple = None,
stacked: bool = False,
default_name: str = "v",
) -> Mapping[str, Any]:
) -> dict[str, Any]:
"""
Procedural wrapper that encompasses all preprocessing steps, including shape-checking, parameter name
generation, layout configuration, figure initialization, and collapsing of axes.
Expand Down Expand Up @@ -56,6 +56,12 @@ def prepare_plot_data(
Whether the plots are stacked horizontally
default_name : str, optional (default = "v")
The default name to use for estimates if None provided

Returns
-------
plot_data : dict[str, Any]
A dictionary containing all preprocessed data and plotting objects required for visualization,
including estimates, targets, variable names, figure, axes, and layout configuration.
"""

plot_data = dicts_to_arrays(
Expand Down
19 changes: 19 additions & 0 deletions tests/test_diagnostics/test_diagnostics_plots.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import bayesflow as bf
import numpy as np
import pytest


Expand All @@ -16,6 +17,8 @@ def test_backend():


def test_calibration_ecdf(random_estimates, random_targets, var_names):
print(random_estimates, random_targets, var_names)

# basic functionality: automatic variable names
out = bf.diagnostics.plots.calibration_ecdf(random_estimates, random_targets)
assert len(out.axes) == num_variables(random_estimates)
Expand Down Expand Up @@ -46,6 +49,22 @@ def test_calibration_ecdf(random_estimates, random_targets, var_names):
# cannot infer the variable names from an array so default names are used
assert out.axes[1].title._text == "v_1"

# test quantities plots are shown
test_quantities = {
r"$\beta_1 + \beta_2$": lambda data: np.sum(data["beta"], axis=-1),
r"$\beta_1 \cdot \beta_2$": lambda data: np.prod(data["beta"], axis=-1),
}
out = bf.diagnostics.plots.calibration_ecdf(random_estimates, random_targets, test_quantities=test_quantities)
assert len(out.axes) == len(test_quantities) + num_variables(random_estimates)
assert out.axes[1].title._text == r"$\beta_1 \cdot \beta_2$"
assert out.axes[-1].title._text == r"sigma"

# test plot titles changed to variable_names in case test quantities exist
out = bf.diagnostics.plots.calibration_ecdf(
random_estimates, random_targets, test_quantities=test_quantities, variable_names=var_names
)
assert out.axes[-1].title._text == r"$\sigma$"


def test_calibration_histogram(random_estimates, random_targets):
# basic functionality: automatic variable names
Expand Down