Skip to content

Commit e46b962

Browse files
introduce VariableArray class (#321)
* introduce VariableArray class
* enable also plotting the prior via pairs_posterior
* further fixes to pairs_posterior
* fix issue #324
* Small refactor, squeezing still unclear
* Sneak in small change in tutorial name
* Cleanup function

---------

Co-authored-by: stefanradev93 <[email protected]>
1 parent 3cd8087 commit e46b962

File tree

9 files changed

+171
-87
lines changed

9 files changed

+171
-87
lines changed

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,9 +94,9 @@ conda env create --file environment.yaml --name bayesflow
9494
Check out some of our walk-through notebooks below. We are actively working on porting all notebooks to the new interface so more will be available soon!
9595

9696
1. [Linear regression starter example](examples/Linear_Regression_Starter.ipynb)
97-
2. [Two moons starter example](examples/Two_Moons_Starter.ipynb)
98-
3. [SIR model with custom summary network](examples/SIR_Posterior_Estimation.ipynb)
99-
4. [SBML model using an external simulator](examples/From_ABC_to_BayesFlow.ipynb)
97+
2. [From ABC to BayesFlow](examples/From_ABC_to_BayesFlow.ipynb)
98+
3. [Two moons starter example](examples/Two_Moons_Starter.ipynb)
99+
4. [SIR model with custom summary network](examples/SIR_Posterior_Estimation.ipynb)
100100
5. [Hyperparameter optimization](examples/Hyperparameter_Optimization.ipynb)
101101
6. [Bayesian experimental design](examples/Bayesian_Experimental_Design.ipynb)
102102
7. [Simple model comparison example (One-Sample T-Test)](examples/One_Sample_TTest.ipynb)

bayesflow/diagnostics/metrics/calibration_error.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,4 +88,5 @@ def calibration_error(
8888
# Aggregate errors across alpha
8989
error = aggregation(absolute_errors, axis=0)
9090

91+
variable_names = samples["estimates"].variable_names
9192
return {"values": error, "metric_name": "Calibration Error", "variable_names": variable_names}

bayesflow/diagnostics/metrics/posterior_contraction.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,4 +58,5 @@ def posterior_contraction(
5858
prior_vars = samples["targets"].var(axis=0, keepdims=True, ddof=1)
5959
contraction = 1 - (post_vars / prior_vars)
6060
contraction = aggregation(contraction, axis=0)
61-
return {"values": contraction, "metric_name": "Posterior Contraction", "variable_names": samples["variable_names"]}
61+
variable_names = samples["estimates"].variable_names
62+
return {"values": contraction, "metric_name": "Posterior Contraction", "variable_names": variable_names}

bayesflow/diagnostics/metrics/root_mean_squared_error.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,4 +65,5 @@ def root_mean_squared_error(
6565
metric_name = "RMSE"
6666

6767
rmse = aggregation(rmse, axis=0)
68-
return {"values": rmse, "metric_name": metric_name, "variable_names": samples["variable_names"]}
68+
variable_names = samples["estimates"].variable_names
69+
return {"values": rmse, "metric_name": metric_name, "variable_names": variable_names}

bayesflow/diagnostics/plots/calibration_ecdf.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def calibration_ecdf(
143143

144144
# Plot individual ecdf of parameters
145145
for j in range(ranks.shape[-1]):
146-
ecdf_single = np.sort(ranks[:, j])
146+
ecdf_single = np.pad(np.sort(ranks[:, j]), (1, 1), constant_values=(0, 1))
147147
xx = ecdf_single
148148
yy = np.arange(1, xx.shape[-1] + 1) / float(xx.shape[-1])
149149

bayesflow/diagnostics/plots/pairs_posterior.py

Lines changed: 31 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,12 @@ def pairs_posterior(
2020
variable_keys: Sequence[str] = None,
2121
variable_names: Sequence[str] = None,
2222
height: int = 3,
23+
post_color: str | tuple = "#132a70",
24+
prior_color: str | tuple = "gray",
25+
alpha=0.9,
2326
label_fontsize: int = 14,
2427
tick_fontsize: int = 12,
25-
# arguments related to priors which is currently unused
26-
# legend_fontsize: int = 16,
27-
# post_color: str | tuple = "#132a70",
28-
# prior_color: str | tuple = "gray",
29-
# post_alpha: float = 0.9,
30-
# prior_alpha: float = 0.7,
28+
legend_fontsize: int = 14,
3129
**kwargs,
3230
) -> sns.PairGrid:
3331
"""Generates a bivariate pair plot given posterior draws and optional prior or prior draws.
@@ -57,10 +55,12 @@ def pairs_posterior(
5755
The color for the posterior histograms and KDEs
5856
prior_color : str, optional, default: gray
5957
The color for the optional prior histograms and KDEs
60-
post_alpha : float in [0, 1], optonal, default: 0.9
58+
post_alpha : float in [0, 1], optional, default: 0.9
6159
The opacity of the posterior plots
62-
prior_alpha : float in [0, 1], optonal, default: 0.7
60+
prior_alpha : float in [0, 1], optional, default: 0.7
6361
The opacity of the prior plots
62+
**kwargs : dict, optional, default: {}
63+
Further optional keyword arguments propagated to `_pairs_samples`
6464
6565
Returns
6666
-------
@@ -75,6 +75,7 @@ def pairs_posterior(
7575
plot_data = dicts_to_arrays(
7676
estimates=estimates,
7777
targets=targets,
78+
priors=priors,
7879
dataset_ids=dataset_id,
7980
variable_keys=variable_keys,
8081
variable_names=variable_names,
@@ -90,52 +91,33 @@ def pairs_posterior(
9091
g = _pairs_samples(
9192
plot_data=plot_data,
9293
height=height,
94+
color=post_color,
95+
color2=prior_color,
96+
alpha=alpha,
9397
label_fontsize=label_fontsize,
9498
tick_fontsize=tick_fontsize,
99+
legend_fontsize=legend_fontsize,
95100
**kwargs,
96101
)
97102

98-
# add priors
99-
if priors is not None:
100-
# TODO: integrate priors into plot_data and then use
101-
# proper coloring of posterior vs. prior using the hue argument in PairGrid
102-
raise ValueError("Plotting prior samples is not yet implemented.")
103-
104-
"""
105-
# this is currently not working as expected as it doesn't show the off diagonal plots
106-
prior_samples_df = pd.DataFrame(priors, columns=plot_data["variable_names"])
107-
g.data = prior_samples_df
108-
g.map_diag(sns.histplot, fill=True, color=prior_color, alpha=prior_alpha, kde=True, zorder=-1)
109-
g.map_lower(sns.kdeplot, fill=True, color=prior_color, alpha=prior_alpha, zorder=-1)
110-
111-
# Add legend to differentiate between prior and posterior
112-
handles = [
113-
Line2D(xdata=[], ydata=[], color=post_color, lw=3, alpha=post_alpha),
114-
Line2D(xdata=[], ydata=[], color=prior_color, lw=3, alpha=prior_alpha),
115-
]
116-
handles_names = ["Posterior", "Prior"]
117-
if targets is not None:
118-
handles.append(Line2D(xdata=[], ydata=[], color="black", lw=3, linestyle="--"))
119-
handles_names.append("True Parameter")
120-
plt.legend(handles=handles, labels=handles_names, fontsize=legend_fontsize, loc="center right")
121-
"""
122-
123-
# add true parameters
124-
if plot_data["targets"] is not None:
125-
# TODO: also add true parameters to the off diagonal plots?
126-
127-
# drop dataset axis if it is still present but of length 1
128-
targets_shape = plot_data["targets"].shape
129-
if len(targets_shape) == 2 and targets_shape[0] == 1:
130-
plot_data["targets"] = np.squeeze(plot_data["targets"], axis=0)
131-
132-
# Custom function to plot true parameters on the diagonal
133-
def plot_true_params(x, **kwargs):
134-
param = x.iloc[0] # Get the single true value for the diagonal
135-
plt.axvline(param, color="black", linestyle="--") # Add vertical line
136-
137-
# only plot on the diagonal a vertical line for the true parameter
138-
g.data = pd.DataFrame(plot_data["targets"][np.newaxis], columns=plot_data["variable_names"])
103+
targets = plot_data.get("targets")
104+
if targets is not None:
105+
# Ensure targets is at least 2D
106+
if targets.ndim == 1:
107+
targets = np.atleast_2d(targets)
108+
109+
# Create DataFrame with variable names as columns
110+
g.data = pd.DataFrame(targets, columns=targets.variable_names)
111+
g.data["_source"] = "True Parameter"
139112
g.map_diag(plot_true_params)
140113

141114
return g
115+
116+
117+
def plot_true_params(x, **kwargs):
118+
"""Custom function to plot true parameters on the diagonal."""
119+
120+
# hue needs to be added to handle the case of plotting both posterior and prior
121+
param = x.iloc[0] # Get the single true value for the diagonal
122+
# only plot on the diagonal a vertical line for the true parameter
123+
plt.axvline(param, color="black", linestyle="--")

bayesflow/diagnostics/plots/pairs_samples.py

Lines changed: 62 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,11 @@ def _pairs_samples(
6868
plot_data: dict,
6969
height: float = 2.5,
7070
color: str | tuple = "#132a70",
71+
color2: str | tuple = "gray",
7172
alpha: float = 0.9,
7273
label_fontsize: int = 14,
7374
tick_fontsize: int = 12,
75+
legend_fontsize: int = 14,
7476
**kwargs,
7577
) -> sns.PairGrid:
7678
# internal version of pairs_samples creating the seaborn plot
@@ -87,45 +89,83 @@ def _pairs_samples(
8789
f"your samples array has a shape of {estimates_shape}."
8890
)
8991

92+
variable_names = plot_data["estimates"].variable_names
93+
9094
# Convert samples to pd.DataFrame
91-
data_to_plot = pd.DataFrame(plot_data["estimates"], columns=plot_data["variable_names"])
95+
if plot_data["priors"] is not None:
96+
# differentiate posterior from prior draws
97+
# row bind posterior and prior draws
98+
samples = np.vstack((plot_data["priors"], plot_data["estimates"]))
99+
data_to_plot = pd.DataFrame(samples, columns=variable_names)
100+
101+
# ensure that the source of the samples is stored
102+
source_prior = np.repeat("Prior", plot_data["priors"].shape[0])
103+
source_post = np.repeat("Posterior", plot_data["estimates"].shape[0])
104+
data_to_plot["_source"] = np.concatenate((source_prior, source_post))
105+
data_to_plot["_source"] = pd.Categorical(data_to_plot["_source"], categories=["Prior", "Posterior"])
106+
107+
# initialize plot
108+
g = sns.PairGrid(
109+
data_to_plot,
110+
height=height,
111+
hue="_source",
112+
palette=[color2, color],
113+
**kwargs,
114+
)
92115

93-
# initialize plot
94-
artist = sns.PairGrid(data_to_plot, height=height, **kwargs)
116+
else:
117+
# plot just the one set of distributions
118+
data_to_plot = pd.DataFrame(plot_data["estimates"], columns=variable_names)
95119

96-
# Generate grids
97-
# in the off diagonal plots, the grids appears in front of the points/densities
98-
# TODO: can we put the grid in the background somehow?
99-
dim = artist.axes.shape[0]
100-
for i in range(dim):
101-
for j in range(dim):
102-
artist.axes[i, j].grid(alpha=0.5)
120+
# initialize plot
121+
g = sns.PairGrid(data_to_plot, height=height, **kwargs)
103122

104123
# add histograms + KDEs to the diagonal
105-
artist.map_diag(sns.histplot, fill=True, color=color, alpha=alpha, kde=True)
124+
g.map_diag(
125+
sns.histplot,
126+
fill=True,
127+
kde=True,
128+
color=color,
129+
alpha=alpha,
130+
stat="density",
131+
common_norm=False,
132+
)
133+
134+
# add scatterplots to the upper diagonal
135+
g.map_upper(sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color, lw=0)
106136

107-
# Incorporate exceptions for generating KDE plots
137+
# add KDEs to the lower diagonal
108138
try:
109-
artist.map_lower(sns.kdeplot, fill=True, color=color, alpha=alpha)
139+
g.map_lower(sns.kdeplot, fill=True, color=color, alpha=alpha)
110140
except Exception as e:
111141
logging.exception("KDE failed due to the following exception:\n" + repr(e) + "\nSubstituting scatter plot.")
112-
artist.map_lower(sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color, lw=0)
142+
g.map_lower(sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color, lw=0)
113143

114-
artist.map_upper(sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color, lw=0)
144+
# the legend needs to be added here so that the colors are recognized
145+
if plot_data["priors"] is not None:
146+
g.add_legend(fontsize=legend_fontsize, loc="center right")
147+
g._legend.set_title(None)
115148

116-
dim = artist.axes.shape[0]
149+
# Generate grids
150+
dim = g.axes.shape[0]
151+
for i in range(dim):
152+
for j in range(dim):
153+
g.axes[i, j].grid(alpha=0.5)
154+
g.axes[i, j].set_axisbelow(True)
155+
156+
dim = g.axes.shape[0]
117157
for i in range(dim):
118158
# Modify tick sizes
119159
for j in range(i + 1):
120-
artist.axes[i, j].tick_params(axis="both", which="major", labelsize=tick_fontsize)
121-
artist.axes[i, j].tick_params(axis="both", which="minor", labelsize=tick_fontsize)
160+
g.axes[i, j].tick_params(axis="both", which="major", labelsize=tick_fontsize)
161+
g.axes[i, j].tick_params(axis="both", which="minor", labelsize=tick_fontsize)
122162

123163
# adjust font size of labels
124164
# the labels themselves remain the same as before, i.e., variable_names
125-
artist.axes[i, 0].set_ylabel(plot_data["variable_names"][i], fontsize=label_fontsize)
126-
artist.axes[dim - 1, i].set_xlabel(plot_data["variable_names"][i], fontsize=label_fontsize)
165+
g.axes[i, 0].set_ylabel(variable_names[i], fontsize=label_fontsize)
166+
g.axes[dim - 1, i].set_xlabel(variable_names[i], fontsize=label_fontsize)
127167

128168
# Return figure
129-
artist.tight_layout()
169+
g.tight_layout()
130170

131-
return artist
171+
return g

0 commit comments

Comments
 (0)