Skip to content

Commit 55ffced

Browse files
committed
implement different input modes for plotting functions
1 parent 8667646 commit 55ffced

File tree

4 files changed

+331
-73
lines changed

4 files changed

+331
-73
lines changed

bayesflow/diagnostics/plots/calibration_ecdf.py

Lines changed: 12 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
from collections.abc import Callable, Mapping, Sequence
22

33
import numpy as np
4-
import keras
54
import matplotlib.pyplot as plt
65

7-
from ...utils.plot_utils import prepare_plot_data, add_titles_and_labels, prettify_subplots
6+
from ...utils.plot_utils import prepare_plot_data, add_titles_and_labels, prettify_subplots, compute_test_quantities
87
from ...utils.ecdf import simultaneous_ecdf_bands
98
from ...utils.ecdf.ranks import fractional_ranks, distance_ranks
109

@@ -136,38 +135,17 @@ def calibration_ecdf(
136135

137136
# Optionally, compute and prepend test quantities from draws
138137
if test_quantities is not None:
139-
test_quantities_estimates = {}
140-
test_quantities_targets = {}
141-
142-
for key, test_quantity_fn in test_quantities.items():
143-
# Apply test_quantity_func to ground-truths
144-
tq_targets = test_quantity_fn(data=targets)
145-
test_quantities_targets[key] = np.expand_dims(tq_targets, axis=1)
146-
147-
# Flatten estimates for batch processing in test_quantity_fn, apply function, and restore shape
148-
num_conditions, num_samples = next(iter(estimates.values())).shape[:2]
149-
flattened_estimates = keras.tree.map_structure(
150-
lambda t: np.reshape(t, (num_conditions * num_samples, *t.shape[2:]))
151-
if isinstance(t, np.ndarray)
152-
else t,
153-
estimates,
154-
)
155-
flat_tq_estimates = test_quantity_fn(data=flattened_estimates)
156-
test_quantities_estimates[key] = np.reshape(flat_tq_estimates, (num_conditions, num_samples, 1))
157-
158-
# Add custom test quantities to variable keys and names for plotting
159-
# keys and names are set to the test_quantities dict keys
160-
test_quantities_names = list(test_quantities.keys())
161-
162-
if variable_keys is None:
163-
variable_keys = list(estimates.keys())
164-
165-
if isinstance(variable_names, list):
166-
variable_names = test_quantities_names + variable_names
167-
168-
variable_keys = test_quantities_names + variable_keys
169-
estimates = test_quantities_estimates | estimates
170-
targets = test_quantities_targets | targets
138+
updated_data = compute_test_quantities(
139+
targets=targets,
140+
estimates=estimates,
141+
variable_keys=variable_keys,
142+
variable_names=variable_names,
143+
test_quantities=test_quantities,
144+
)
145+
variable_names = updated_data["variable_names"]
146+
variable_keys = updated_data["variable_keys"]
147+
estimates = updated_data["estimates"]
148+
targets = updated_data["targets"]
171149

172150
plot_data = prepare_plot_data(
173151
estimates=estimates,

bayesflow/diagnostics/plots/pairs_quantity.py

Lines changed: 84 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from collections.abc import Sequence, Mapping
1+
from collections.abc import Callable, Sequence, Mapping
22

33
import matplotlib
44
import matplotlib.pyplot as plt
@@ -7,18 +7,22 @@
77
import pandas as pd
88
import seaborn as sns
99

10-
from bayesflow.utils.dict_utils import make_variable_array
10+
11+
from .plot_quantity import _prepare_values
1112

1213

1314
def pairs_quantity(
14-
values: Mapping[str, np.ndarray] | np.ndarray,
15+
values: Mapping[str, np.ndarray] | np.ndarray | Callable,
16+
*,
1517
targets: Mapping[str, np.ndarray] | np.ndarray,
1618
variable_keys: Sequence[str] = None,
1719
variable_names: Sequence[str] = None,
20+
estimates: Mapping[str, np.ndarray] | np.ndarray | None = None,
21+
test_quantities: dict[str, Callable] = None,
1822
height: float = 2.5,
1923
cmap: str | matplotlib.colors.Colormap = "viridis",
2024
alpha: float = 0.9,
21-
label: str = "",
25+
label: str = None,
2226
label_fontsize: int = 14,
2327
tick_fontsize: int = 12,
2428
colorbar_label_fontsize: int = 14,
@@ -28,6 +32,7 @@ def pairs_quantity(
2832
colorbar_offset: float = 0.06,
2933
vmin: float = None,
3034
vmax: float = None,
35+
default_name: str = "v",
3136
**kwargs,
3237
) -> sns.PairGrid:
3338
"""
@@ -38,25 +43,59 @@ def pairs_quantity(
3843
each parameter is plotted on the diagonal. Each column displays the
3944
values corresponding to the parameter in the column.
4045
46+
The function supports the following different combinations to pass
47+
or compute the values:
48+
49+
1. pass `values` as an array of shape (num_datasets,) or (num_datasets, num_variables)
50+
2. pass `values` as a dictionary with the keys 'values', 'metric_name' and 'variable_names'
51+
as provided by the metrics functions. Note that the functions have to be called
52+
without aggregation to obtain one value per dataset.
53+
3. pass a function to `values`, as well as `estimates`. The function should have the
54+
signature fn(estimates, targets, [aggregation]) and return an object like the
55+
`values` described in the previous options.
56+
4157
Parameters
4258
----------
43-
values : dict[str, np.ndarray],
44-
The value of the quantity to plot.
45-
targets : dict[str, np.ndarray],
59+
values : dict[str, np.ndarray] | np.ndarray | Callable,
60+
The value of the quantity to plot. One of the following:
61+
62+
1. an array of shape (num_datasets,) or (num_datasets, num_variables)
63+
2. a dictionary with the keys 'values', 'metric_name' and 'variable_names'
64+
as provided by the metrics functions. Note that the functions have to be called
65+
without aggregation to obtain one value per dataset.
66+
3. a callable, requires passing `estimates` as well. The function should have the
67+
signature fn(estimates, targets, [aggregation]) and return an object like the
68+
ones described in the previous options.
69+
targets : dict[str, np.ndarray] | np.ndarray,
4670
The parameter values plotted on the axis.
4771
variable_keys : list or None, optional, default: None
4872
Select keys from the dictionary provided in samples.
4973
By default, select all keys.
5074
variable_names : list or None, optional, default: None
5175
The parameter names for nice plot titles. Inferred if None
76+
estimates : np.ndarray of shape (n_data_sets, n_post_draws, n_params), optional, default: None
77+
The posterior draws obtained from n_data_sets. Can only be supplied if
78+
`values` is of type Callable.
79+
test_quantities : dict or None, optional, default: None
80+
A dict that maps plot titles to functions that compute
81+
test quantities based on estimate/target draws.
82+
83+
The dict keys are automatically added to ``variable_keys``
84+
and ``variable_names``.
85+
Test quantity functions are expected to accept a dict of draws with
86+
shape ``(batch_size, ...)`` as the first (typically only)
87+
positional argument and return a NumPy array of shape
88+
``(batch_size,)``.
89+
The functions do not have to deal with an additional
90+
sample dimension, as appropriate reshaping is done internally.
5291
height : float, optional, default: 2.5
5392
The height of the pair plot
5493
cmap : str or Colormap, default: "viridis"
5594
The colormap for the plot.
5695
alpha : float in [0, 1], optional, default: 0.9
5796
The opacity of the plot
58-
label : str, optional, default: ""
59-
Label for the dataset to plot
97+
label : str, optional, default: None
98+
Label for the dataset to plot.
6099
label_fontsize : int, optional, default: 14
61100
The font size of the x and y-label texts (parameter names)
62101
tick_fontsize : int, optional, default: 12
@@ -77,21 +116,44 @@ def pairs_quantity(
77116
vmax : float, optional, default: None
78117
Maximum value for the colormap. If None, the maximum value is
79118
determined from `values`.
119+
default_name : str, optional (default = "v")
120+
The default name to use for estimates if none are provided
80121
**kwargs : dict, optional
81122
Additional keyword arguments passed to the sns.PairGrid constructor
123+
124+
Returns
125+
-------
126+
sns.PairGrid
126+
The seaborn PairGrid instance holding the plot
128+
129+
Raises
130+
------
131+
ValueError
132+
If a callable is supplied as `values`, but `estimates` is None.
82133
"""
83-
values = make_variable_array(
84-
values,
134+
135+
if isinstance(values, Callable) and estimates is None:
136+
raise ValueError("Supplied a callable as `values`, but not `estimates`.")
137+
138+
d = _prepare_values(
139+
values=values,
140+
targets=targets,
141+
estimates=estimates,
85142
variable_keys=variable_keys,
86143
variable_names=variable_names,
144+
test_quantities=test_quantities,
145+
label=label,
146+
default_name=default_name,
87147
)
88-
variable_names = values.variable_names
89-
variable_keys = values.variable_keys
90-
targets = make_variable_array(
91-
targets,
92-
variable_keys=variable_keys,
93-
variable_names=variable_names,
148+
(values, targets, variable_keys, variable_names, test_quantities, label) = (
149+
d["values"],
150+
d["targets"],
151+
d["variable_keys"],
152+
d["variable_names"],
153+
d["test_quantities"],
154+
d["label"],
94155
)
156+
95157
# Convert samples to pd.DataFrame
96158
data_to_plot = pd.DataFrame(targets, columns=variable_names)
97159

@@ -110,11 +172,12 @@ def pairs_quantity(
110172
dim = g.axes.shape[0]
111173
for i in range(dim):
112174
for j in range(dim):
175+
# if one value for each variable is supplied, use it for the corresponding column
176+
row_values = values[:, j] if values.ndim == 2 else values
177+
113178
if i == j:
114179
ax = g.axes[i, j].twinx()
115-
ax.scatter(
116-
targets[:, i], values[:, i], c=values[:, i], cmap=cmap, s=4, vmin=vmin, vmax=vmax, alpha=alpha
117-
)
180+
ax.scatter(targets[:, i], values[:, i], c=row_values, cmap=cmap, s=4, vmin=vmin, vmax=vmax, alpha=alpha)
118181
ax.spines["left"].set_visible(False)
119182
ax.spines["top"].set_visible(False)
120183
ax.tick_params(axis="both", which="major", labelsize=tick_fontsize)
@@ -132,7 +195,7 @@ def pairs_quantity(
132195
g.axes[i, j].scatter(
133196
targets[:, j],
134197
targets[:, i],
135-
c=values[:, j],
198+
c=row_values,
136199
cmap=cmap,
137200
s=4,
138201
vmin=vmin,

0 commit comments

Comments
 (0)