Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions bayesflow/diagnostics/plots/calibration_ecdf_from_quantiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def calibration_ecdf_from_quantiles(
fill_color: str = "grey",
num_row: int = None,
num_col: int = None,
markersize: float = None,
**kwargs,
) -> plt.Figure:
"""
Expand Down Expand Up @@ -97,6 +98,8 @@ def calibration_ecdf_from_quantiles(
num_col : int, optional, default: None
The number of columns for the subplots.
Dynamically determined if None.
markersize : float, optional, default: None
The marker size in points.
**kwargs : dict, optional, default: {}
Keyword arguments can be passed to control the behavior of
ECDF simultaneous band computation through the ``ecdf_bands_kwargs``
Expand Down Expand Up @@ -142,11 +145,15 @@ def calibration_ecdf_from_quantiles(

if stacked:
if j == 0:
plot_data["axes"][0].plot(xx, yy, marker="o", color=rank_ecdf_color, alpha=0.95, label="Rank ECDFs")
plot_data["axes"][0].plot(
xx, yy, marker="o", color=rank_ecdf_color, markersize=markersize, alpha=0.95, label="Rank ECDFs"
)
else:
plot_data["axes"][0].plot(xx, yy, marker="o", color=rank_ecdf_color, alpha=0.95)
plot_data["axes"][0].plot(xx, yy, marker="o", color=rank_ecdf_color, markersize=markersize, alpha=0.95)
else:
plot_data["axes"].flat[j].plot(xx, yy, marker="o", color=rank_ecdf_color, alpha=0.95, label="Rank ECDF")
plot_data["axes"].flat[j].plot(
xx, yy, marker="o", color=rank_ecdf_color, markersize=markersize, alpha=0.95, label="Rank ECDF"
)

# Compute uniform ECDF and bands
alpha, z, L, U = pointwise_ecdf_bands(estimates.shape[0], **kwargs.pop("ecdf_bands_kwargs", {}))
Expand Down
5 changes: 4 additions & 1 deletion bayesflow/diagnostics/plots/mc_calibration.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def mc_calibration(
color: str = "#132a70",
num_col: int = None,
num_row: int = None,
markersize: float = None,
) -> plt.Figure:
"""Plots the calibration curves, the ECEs and the marginal histograms of predicted posterior model probabilities
for a model comparison problem. The marginal histograms inform about the fraction of predictions in each bin.
Expand Down Expand Up @@ -60,6 +61,8 @@ def mc_calibration(
The number of rows for the subplots. Dynamically determined if None.
num_col : int, optional, default: None
The number of columns for the subplots. Dynamically determined if None.
markersize : float, optional, default: None
The marker size in points.

Returns
-------
Expand Down Expand Up @@ -88,7 +91,7 @@ def mc_calibration(

for j, ax in enumerate(plot_data["axes"].flat):
# Plot calibration curve
ax.plot(ece["probs_pred"][j], ece["probs_true"][j], "o-", color=color)
ax.plot(ece["probs_pred"][j], ece["probs_true"][j], "o-", color=color, markersize=markersize)

# Plot PMP distribution over bins
uniform_bins = np.linspace(0.0, 1.0, num_bins + 1)
Expand Down
12 changes: 11 additions & 1 deletion bayesflow/diagnostics/plots/pairs_posterior.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ def pairs_posterior(
prior_color: str | tuple = "gray",
target_color: str | tuple = "red",
alpha: float = 0.9,
markersize: float = 40,
target_markersize: float = 40,
label_fontsize: int = 14,
tick_fontsize: int = 12,
legend_fontsize: int = 14,
Expand Down Expand Up @@ -62,6 +64,10 @@ def pairs_posterior(
The color for the optional true parameter lines and points
alpha : float in [0, 1], optional, default: 0.9
The opacity of the posterior plots
markersize : float, optional, default: 40
The marker size in points**2 of the scatter plots
target_markersize : float, optional, default: 40
The marker size in points**2 of the target marker

**kwargs : dict, optional, default: {}
Further optional keyword arguments propagated to `_pairs_samples`
Expand Down Expand Up @@ -101,6 +107,9 @@ def pairs_posterior(
label_fontsize=label_fontsize,
tick_fontsize=tick_fontsize,
legend_fontsize=legend_fontsize,
markersize=markersize,
target_markersize=target_markersize,
target_color=target_color,
**kwargs,
)

Expand All @@ -114,7 +123,7 @@ def pairs_posterior(
g.data = pd.DataFrame(targets, columns=targets.variable_names)
g.data["_source"] = "True Parameter"
g.map_diag(plot_true_params_as_lines, color=target_color)
g.map_offdiag(plot_true_params_as_points, color=target_color)
g.map_offdiag(plot_true_params_as_points, color=target_color, s=target_markersize)

create_legends(
g,
Expand All @@ -124,6 +133,7 @@ def pairs_posterior(
legend_fontsize=legend_fontsize,
show_single_legend=False,
target_color=target_color,
target_markersize=target_markersize,
)

return g
Expand Down
29 changes: 27 additions & 2 deletions bayesflow/diagnostics/plots/pairs_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

def pairs_samples(
samples: Mapping[str, np.ndarray] | np.ndarray = None,
dataset_id: int = None,
variable_keys: Sequence[str] = None,
variable_names: Sequence[str] = None,
height: float = 2.5,
Expand All @@ -22,6 +23,7 @@ def pairs_samples(
label_fontsize: int = 14,
tick_fontsize: int = 12,
show_single_legend: bool = False,
markersize: float = 40,
**kwargs,
) -> sns.PairGrid:
"""
Expand All @@ -32,6 +34,8 @@ def pairs_samples(
----------
samples : dict[str, Tensor], default: None
Sample draws from any dataset
dataset_id: Optional ID of the dataset for whose posterior the pair plots shall be generated.
Should only be specified if estimates contain posterior draws from multiple datasets.
variable_keys : list or None, optional, default: None
Select keys from the dictionary provided in samples.
By default, select all keys.
Expand All @@ -52,15 +56,23 @@ def pairs_samples(
show_single_legend : bool, optional, default: False
Optional toggle for the user to choose whether a single dataset
should also display legend
markersize : float, optional, default: 40
Marker size in points**2 of the scatter plot.
**kwargs : dict, optional
Additional keyword arguments passed to the sns.PairGrid constructor
"""

plot_data = dicts_to_arrays(
estimates=samples,
dataset_ids=dataset_id,
variable_keys=variable_keys,
variable_names=variable_names,
)
# dicts_to_arrays will keep the dataset axis even if it is of length 1
# however, pairs plotting requires the dataset axis to be removed
estimates_shape = plot_data["estimates"].shape
if len(estimates_shape) == 3 and estimates_shape[0] == 1:
plot_data["estimates"] = np.squeeze(plot_data["estimates"], axis=0)

g = _pairs_samples(
plot_data=plot_data,
Expand All @@ -71,6 +83,7 @@ def pairs_samples(
label_fontsize=label_fontsize,
tick_fontsize=tick_fontsize,
show_single_legend=show_single_legend,
markersize=markersize,
**kwargs,
)

Expand All @@ -88,6 +101,9 @@ def _pairs_samples(
tick_fontsize: int = 12,
legend_fontsize: int = 14,
show_single_legend: bool = False,
markersize: float = 40,
target_markersize: float = 40,
target_color: str = "red",
**kwargs,
) -> sns.PairGrid:
"""
Expand All @@ -101,6 +117,12 @@ def _pairs_samples(
color2 : str, optional, default: 'gray'
Secondary color for the pair plots.
This is the color used for the prior draws.
markersize : float, optional, default: 40
Marker size in points**2 of the scatter plot.
target_markersize : float, optional, default: 40
Target marker size in points**2 of the scatter plot.
target_color : str, optional, default: "red"
Target marker color for the legend.

Other arguments are documented in pairs_samples
"""
Expand Down Expand Up @@ -159,14 +181,14 @@ def _pairs_samples(
)

# add scatter plots to the upper diagonal
g.map_upper(sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color, lw=0)
g.map_upper(sns.scatterplot, alpha=0.6, s=markersize, edgecolor="k", color=color, lw=0)

# add KDEs to the lower diagonal
try:
g.map_lower(sns.kdeplot, fill=True, color=color, alpha=alpha, common_norm=False)
except Exception as e:
logging.exception("KDE failed due to the following exception:\n" + repr(e) + "\nSubstituting scatter plot.")
g.map_lower(sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color, lw=0)
g.map_lower(sns.scatterplot, alpha=0.6, s=markersize, edgecolor="k", color=color, lw=0)

# Generate grids
dim = g.axes.shape[0]
Expand Down Expand Up @@ -200,6 +222,9 @@ def _pairs_samples(
legend_fontsize=legend_fontsize,
label=label,
show_single_legend=show_single_legend,
markersize=markersize,
target_markersize=target_markersize,
target_color=target_color,
)

# Return figure
Expand Down
17 changes: 14 additions & 3 deletions bayesflow/diagnostics/plots/recovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def recovery(
num_row: int = None,
xlabel: str = "Ground truth",
ylabel: str = "Estimate",
markersize: float = None,
**kwargs,
) -> plt.Figure:
"""
Expand Down Expand Up @@ -76,8 +77,10 @@ def recovery(
The number of rows for the subplots. Dynamically determined if None.
num_col : int, optional, default: None
The number of columns for the subplots. Dynamically determined if None.
xlabel:
ylabel:
xlabel :
ylabel :
markersize : float, optional, default: None
The marker size in points.

Returns
-------
Expand Down Expand Up @@ -122,10 +125,18 @@ def recovery(
fmt="o",
alpha=0.5,
color=color,
markersize=markersize,
**kwargs,
)
else:
_ = ax.scatter(targets[:, i], point_estimate[:, i], alpha=0.5, color=color, **kwargs)
_ = ax.scatter(
targets[:, i],
point_estimate[:, i],
alpha=0.5,
color=color,
s=None if markersize is None else markersize**2,
**kwargs,
)

make_quadratic(ax, targets[:, i], point_estimate[:, i])

Expand Down
8 changes: 6 additions & 2 deletions bayesflow/diagnostics/plots/recovery_from_estimates.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def recovery_from_estimates(
num_row: int = None,
xlabel: str = "Ground truth",
ylabel: str = "Estimate",
markersize: float = None,
**kwargs,
) -> plt.Figure:
"""
Expand Down Expand Up @@ -79,8 +80,10 @@ def recovery_from_estimates(
The number of rows for the subplots. Dynamically determined if None.
num_col : int, optional, default: None
The number of columns for the subplots. Dynamically determined if None.
xlabel:
ylabel:
xlabel :
ylabel :
markersize : float, optional, default: None
The marker size in points.

Returns
-------
Expand Down Expand Up @@ -139,6 +142,7 @@ def recovery_from_estimates(
marker=markers[q_idx],
alpha=0.5,
color=color,
s=None if markersize is None else markersize**2,
**kwargs,
)

Expand Down
5 changes: 4 additions & 1 deletion bayesflow/diagnostics/plots/z_score_contraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def z_score_contraction(
color: str = "#132a70",
num_col: int = None,
num_row: int = None,
markersize: float = None,
) -> plt.Figure:
"""
Implements a graphical check for global model sensitivity by plotting the
Expand Down Expand Up @@ -76,6 +77,8 @@ def z_score_contraction(
The number of rows for the subplots. Dynamically determined if None.
num_col : int, optional, default: None
The number of columns for the subplots. Dynamically determined if None.
markersize : float, optional, default: None
The marker size in points**2 of the scatter plot.

Returns
-------
Expand Down Expand Up @@ -118,7 +121,7 @@ def z_score_contraction(
if i >= plot_data["num_variables"]:
break

ax.scatter(contraction[:, i], z_score[:, i], color=color, alpha=0.5)
ax.scatter(contraction[:, i], z_score[:, i], color=color, alpha=0.5, s=markersize)
ax.set_xlim([-0.05, 1.05])

prettify_subplots(plot_data["axes"], num_subplots=plot_data["num_variables"], tick_fontsize=tick_fontsize)
Expand Down
18 changes: 16 additions & 2 deletions bayesflow/utils/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,9 @@ def create_legends(
label: str = "Posterior",
show_single_legend: bool = False,
legend_fontsize: int = 14,
markersize: float = 40,
target_color: str = "red",
target_markersize: float = 40,
):
"""
Helper function to create legends for pairplots.
Expand All @@ -396,8 +398,12 @@ def create_legends(
should also display legend
legend_fontsize : int, optional, default: 14
fontsize for the legend
target_color : str, optional, default "red"
markersize : float, optional, default: 40
The marker size in points**2
target_color : str, optional, default: "red"
Color for the target label
target_markersize : float, optional, default: 40
Marker size in points**2 of the target marker
"""
handles = []
labels = []
Expand All @@ -414,7 +420,15 @@ def create_legends(
labels.append(posterior_label)

if plot_data.get("targets") is not None:
target_handle = plt.Line2D([0], [0], color=target_color, linestyle="--", marker="x", label="Targets")
target_handle = plt.Line2D(
[0],
[0],
color=target_color,
linestyle="--",
marker="x",
markersize=np.sqrt(target_markersize),
label="Targets",
)
target_label = "Targets"
handles.append(target_handle)
labels.append(target_label)
Expand Down
Loading
Loading