Skip to content

Commit ecd7b09

Browse files
committed
Custom test quantity support for calibration_ecdf
1 parent 55d51df commit ecd7b09

File tree

1 file changed

+41
-1
lines changed

1 file changed

+41
-1
lines changed

bayesflow/diagnostics/plots/calibration_ecdf.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
from collections.abc import Mapping, Sequence
1+
from collections.abc import Callable, Mapping, Sequence
22

33
import numpy as np
4+
import keras
45
import matplotlib.pyplot as plt
56

67
from ...utils.plot_utils import prepare_plot_data, add_titles_and_labels, prettify_subplots
@@ -13,6 +14,7 @@ def calibration_ecdf(
1314
targets: Mapping[str, np.ndarray] | np.ndarray,
1415
variable_keys: Sequence[str] = None,
1516
variable_names: Sequence[str] = None,
17+
test_quantities: dict[str, Callable] = None,
1618
difference: bool = False,
1719
stacked: bool = False,
1820
rank_type: str | np.ndarray = "fractional",
@@ -78,6 +80,18 @@ def calibration_ecdf(
7880
variable_names : list or None, optional, default: None
7981
The parameter names for nice plot titles.
8082
Inferred if None. Only relevant if `stacked=False`.
83+
test_quantities : dict or None, optional, default: None
84+
A dict that maps plot titles to functions that compute
85+
test quantities based on estimate/target draws.
86+
87+
The dict keys are automatically added to ``variable_keys``
88+
and ``variable_names``.
89+
Test quantity functions are expected to accept a dict of draws with
90+
shape ``(batch_size, ...)`` as the first (typically only)
91+
positional argument and return an NumPy array of shape
92+
``(batch_size,)``.
93+
The functions do not have to deal with an additional
94+
sample dimension, as appropriate reshaping is done internally.
8195
figsize : tuple or None, optional, default: None
8296
The figure size passed to the matplotlib constructor.
8397
Inferred if None.
@@ -120,6 +134,32 @@ def calibration_ecdf(
120134
If an unknown `rank_type` is passed.
121135
"""
122136

137+
# Optionally, compute and prepend test quantities from draws
138+
if test_quantities is not None:
139+
# Prepare empty mapping to hold test quantity values
140+
test_quantities_estimates = {}
141+
test_quantities_targets = {}
142+
143+
for key, test_quantity_func in test_quantities.items():
144+
# Apply test_quantity_func to draws
145+
tq_targets = test_quantity_func(data=targets)
146+
test_quantities_targets[key] = np.expand_dims(tq_targets, axis=1)
147+
148+
# We assume test_quantity_func can only handle a 1D batch_size, so estimates
149+
# which have shape (num_conditions, num_post_samples, ...) must be flattend first.
150+
num_conditions, num_post_samples = next(iter(estimates.values())).shape[:2]
151+
flattened_estimates = keras.tree.map_structure(lambda t: np.reshape(t, (-1, *t.shape[2:])), estimates)
152+
flat_tq_estimates = test_quantity_func(data=flattened_estimates)
153+
test_quantities_estimates[key] = np.reshape(flat_tq_estimates, (num_conditions, num_post_samples, 1))
154+
155+
# Add custom test quantities to variable keys and names for plotting
156+
variable_keys = list(test_quantities.keys()) + variable_keys
157+
variable_names = list(test_quantities.keys()) + variable_names
158+
159+
# Prepend test quantities to draws
160+
estimates = test_quantities_estimates | estimates
161+
targets = test_quantities_targets | targets
162+
123163
plot_data = prepare_plot_data(
124164
estimates=estimates,
125165
targets=targets,

0 commit comments

Comments
 (0)