Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 41 additions & 1 deletion bayesflow/diagnostics/plots/calibration_ecdf.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from collections.abc import Mapping, Sequence
from collections.abc import Callable, Mapping, Sequence

import numpy as np
import keras
import matplotlib.pyplot as plt

from ...utils.plot_utils import prepare_plot_data, add_titles_and_labels, prettify_subplots
Expand All @@ -13,6 +14,7 @@ def calibration_ecdf(
targets: Mapping[str, np.ndarray] | np.ndarray,
variable_keys: Sequence[str] = None,
variable_names: Sequence[str] = None,
test_quantities: dict[str, Callable] = None,
difference: bool = False,
stacked: bool = False,
rank_type: str | np.ndarray = "fractional",
Expand Down Expand Up @@ -78,6 +80,18 @@ def calibration_ecdf(
variable_names : list or None, optional, default: None
The parameter names for nice plot titles.
Inferred if None. Only relevant if `stacked=False`.
test_quantities : dict or None, optional, default: None
A dict that maps plot titles to functions that compute
test quantities based on estimate/target draws.

The dict keys are automatically added to ``variable_keys``
and ``variable_names``.
Test quantity functions are expected to accept a dict of draws with
shape ``(batch_size, ...)`` as the first (typically only)
positional argument and return an NumPy array of shape
``(batch_size,)``.
The functions do not have to deal with an additional
sample dimension, as appropriate reshaping is done internally.
figsize : tuple or None, optional, default: None
The figure size passed to the matplotlib constructor.
Inferred if None.
Expand Down Expand Up @@ -120,6 +134,32 @@ def calibration_ecdf(
If an unknown `rank_type` is passed.
"""

# Optionally, compute and prepend test quantities from draws
if test_quantities is not None:
# Prepare empty mapping to hold test quantity values
test_quantities_estimates = {}
test_quantities_targets = {}

for key, test_quantity_func in test_quantities.items():
# Apply test_quantity_func to draws
tq_targets = test_quantity_func(data=targets)
test_quantities_targets[key] = np.expand_dims(tq_targets, axis=1)

# We assume test_quantity_func can only handle a 1D batch_size, so estimates
# which have shape (num_conditions, num_samples, ...) must be flattend first.
num_conditions, num_samples = next(iter(estimates.values())).shape[:2]
flattened_estimates = keras.tree.map_structure(lambda t: np.reshape(t, (-1, *t.shape[2:])), estimates)
flat_tq_estimates = test_quantity_func(data=flattened_estimates)
test_quantities_estimates[key] = np.reshape(flat_tq_estimates, (num_conditions, num_samples, 1))

# Add custom test quantities to variable keys and names for plotting
variable_keys = list(test_quantities.keys()) + variable_keys
variable_names = list(test_quantities.keys()) + variable_names

# Prepend test quantities to draws
estimates = test_quantities_estimates | estimates
targets = test_quantities_targets | targets

plot_data = prepare_plot_data(
estimates=estimates,
targets=targets,
Expand Down