diff --git a/bayesflow/diagnostics/plots/calibration_ecdf_from_quantiles.py b/bayesflow/diagnostics/plots/calibration_ecdf_from_quantiles.py index 938ba9d8f..6229dcdf9 100644 --- a/bayesflow/diagnostics/plots/calibration_ecdf_from_quantiles.py +++ b/bayesflow/diagnostics/plots/calibration_ecdf_from_quantiles.py @@ -26,6 +26,7 @@ def calibration_ecdf_from_quantiles( fill_color: str = "grey", num_row: int = None, num_col: int = None, + markersize: float = None, **kwargs, ) -> plt.Figure: """ @@ -97,6 +98,8 @@ def calibration_ecdf_from_quantiles( num_col : int, optional, default: None The number of columns for the subplots. Dynamically determined if None. + markersize : float, optional, default: None + The marker size in points. **kwargs : dict, optional, default: {} Keyword arguments can be passed to control the behavior of ECDF simultaneous band computation through the ``ecdf_bands_kwargs`` @@ -142,11 +145,15 @@ def calibration_ecdf_from_quantiles( if stacked: if j == 0: - plot_data["axes"][0].plot(xx, yy, marker="o", color=rank_ecdf_color, alpha=0.95, label="Rank ECDFs") + plot_data["axes"][0].plot( + xx, yy, marker="o", color=rank_ecdf_color, markersize=markersize, alpha=0.95, label="Rank ECDFs" + ) else: - plot_data["axes"][0].plot(xx, yy, marker="o", color=rank_ecdf_color, alpha=0.95) + plot_data["axes"][0].plot(xx, yy, marker="o", color=rank_ecdf_color, markersize=markersize, alpha=0.95) else: - plot_data["axes"].flat[j].plot(xx, yy, marker="o", color=rank_ecdf_color, alpha=0.95, label="Rank ECDF") + plot_data["axes"].flat[j].plot( + xx, yy, marker="o", color=rank_ecdf_color, markersize=markersize, alpha=0.95, label="Rank ECDF" + ) # Compute uniform ECDF and bands alpha, z, L, U = pointwise_ecdf_bands(estimates.shape[0], **kwargs.pop("ecdf_bands_kwargs", {})) diff --git a/bayesflow/diagnostics/plots/mc_calibration.py b/bayesflow/diagnostics/plots/mc_calibration.py index 0377d9847..f3ab19e6c 100644 --- a/bayesflow/diagnostics/plots/mc_calibration.py +++ b/bayesflow/diagnostics/plots/mc_calibration.py @@ -27,6 +27,7 @@ def mc_calibration( color: str = "#132a70", num_col: int = None, num_row: int = None, + markersize: float = None, ) -> plt.Figure: """Plots the calibration curves, the ECEs and the marginal histograms of predicted posterior model probabilities for a model comparison problem. The marginal histograms inform about the fraction of predictions in each bin. @@ -60,6 +61,8 @@ def mc_calibration( The number of rows for the subplots. Dynamically determined if None. num_col : int, optional, default: None The number of columns for the subplots. Dynamically determined if None. + markersize : float, optional, default: None + The marker size in points. Returns ------- @@ -88,7 +91,7 @@ def mc_calibration( for j, ax in enumerate(plot_data["axes"].flat): # Plot calibration curve - ax.plot(ece["probs_pred"][j], ece["probs_true"][j], "o-", color=color) + ax.plot(ece["probs_pred"][j], ece["probs_true"][j], "o-", color=color, markersize=markersize) # Plot PMP distribution over bins uniform_bins = np.linspace(0.0, 1.0, num_bins + 1) diff --git a/bayesflow/diagnostics/plots/pairs_posterior.py b/bayesflow/diagnostics/plots/pairs_posterior.py index d3408771c..9dd449e15 100644 --- a/bayesflow/diagnostics/plots/pairs_posterior.py +++ b/bayesflow/diagnostics/plots/pairs_posterior.py @@ -24,6 +24,8 @@ def pairs_posterior( prior_color: str | tuple = "gray", target_color: str | tuple = "red", alpha: float = 0.9, + markersize: float = 40, + target_markersize: float = 40, label_fontsize: int = 14, tick_fontsize: int = 12, legend_fontsize: int = 14, @@ -62,6 +64,10 @@ def pairs_posterior( The color for the optional true parameter lines and points alpha : float in [0, 1], optional, default: 0.9 The opacity of the posterior plots + markersize : float, optional, default: 40 + The marker size in points**2 of the scatter plots + target_markersize : float, optional, default: 40 + The marker size in points**2 of the target marker **kwargs : dict, optional, default: {} Further optional keyword arguments propagated to `_pairs_samples` @@ -101,6 +107,9 @@ def pairs_posterior( label_fontsize=label_fontsize, tick_fontsize=tick_fontsize, legend_fontsize=legend_fontsize, + markersize=markersize, + target_markersize=target_markersize, + target_color=target_color, **kwargs, ) @@ -114,7 +123,7 @@ def pairs_posterior( g.data = pd.DataFrame(targets, columns=targets.variable_names) g.data["_source"] = "True Parameter" g.map_diag(plot_true_params_as_lines, color=target_color) - g.map_offdiag(plot_true_params_as_points, color=target_color) + g.map_offdiag(plot_true_params_as_points, color=target_color, s=target_markersize) create_legends( g, @@ -124,6 +133,7 @@ def pairs_posterior( legend_fontsize=legend_fontsize, show_single_legend=False, target_color=target_color, + target_markersize=target_markersize, ) return g diff --git a/bayesflow/diagnostics/plots/pairs_samples.py b/bayesflow/diagnostics/plots/pairs_samples.py index 4979e8d39..30a3f2f95 100644 --- a/bayesflow/diagnostics/plots/pairs_samples.py +++ b/bayesflow/diagnostics/plots/pairs_samples.py @@ -13,6 +13,7 @@ def pairs_samples( samples: Mapping[str, np.ndarray] | np.ndarray = None, + dataset_id: int = None, variable_keys: Sequence[str] = None, variable_names: Sequence[str] = None, height: float = 2.5, @@ -22,6 +23,7 @@ def pairs_samples( label_fontsize: int = 14, tick_fontsize: int = 12, show_single_legend: bool = False, + markersize: float = 40, **kwargs, ) -> sns.PairGrid: """ @@ -32,6 +34,8 @@ def pairs_samples( ---------- samples : dict[str, Tensor], default: None Sample draws from any dataset + dataset_id: Optional ID of the dataset for whose posterior the pair plots shall be generated. + Should only be specified if estimates contain posterior draws from multiple datasets. variable_keys : list or None, optional, default: None Select keys from the dictionary provided in samples. By default, select all keys. @@ -52,15 +56,23 @@ def pairs_samples( show_single_legend : bool, optional, default: False Optional toggle for the user to choose whether a single dataset should also display legend + markersize : float, optional, default: 40 + Marker size in points**2 of the scatter plot. **kwargs : dict, optional Additional keyword arguments passed to the sns.PairGrid constructor """ plot_data = dicts_to_arrays( estimates=samples, + dataset_ids=dataset_id, variable_keys=variable_keys, variable_names=variable_names, ) + # dicts_to_arrays will keep the dataset axis even if it is of length 1 + # however, pairs plotting requires the dataset axis to be removed + estimates_shape = plot_data["estimates"].shape + if len(estimates_shape) == 3 and estimates_shape[0] == 1: + plot_data["estimates"] = np.squeeze(plot_data["estimates"], axis=0) g = _pairs_samples( plot_data=plot_data, @@ -71,6 +83,7 @@ def pairs_samples( label_fontsize=label_fontsize, tick_fontsize=tick_fontsize, show_single_legend=show_single_legend, + markersize=markersize, **kwargs, ) @@ -88,6 +101,9 @@ def _pairs_samples( tick_fontsize: int = 12, legend_fontsize: int = 14, show_single_legend: bool = False, + markersize: float = 40, + target_markersize: float = 40, + target_color: str = "red", **kwargs, ) -> sns.PairGrid: """ @@ -101,6 +117,12 @@ def _pairs_samples( color2 : str, optional, default: 'gray' Secondary color for the pair plots. This is the color used for the prior draws. + markersize : float, optional, default: 40 + Marker size in points**2 of the scatter plot. + target_markersize : float, optional, default: 40 + Target marker size in points**2 of the scatter plot. + target_color : str, optional, default: "red" + Target marker color for the legend. Other arguments are documented in pairs_samples """ @@ -159,14 +181,14 @@ def _pairs_samples( ) # add scatter plots to the upper diagonal - g.map_upper(sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color, lw=0) + g.map_upper(sns.scatterplot, alpha=0.6, s=markersize, edgecolor="k", color=color, lw=0) # add KDEs to the lower diagonal try: g.map_lower(sns.kdeplot, fill=True, color=color, alpha=alpha, common_norm=False) except Exception as e: logging.exception("KDE failed due to the following exception:\n" + repr(e) + "\nSubstituting scatter plot.") - g.map_lower(sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color, lw=0) + g.map_lower(sns.scatterplot, alpha=0.6, s=markersize, edgecolor="k", color=color, lw=0) # Generate grids dim = g.axes.shape[0] @@ -200,6 +222,9 @@ def _pairs_samples( legend_fontsize=legend_fontsize, label=label, show_single_legend=show_single_legend, + markersize=markersize, + target_markersize=target_markersize, + target_color=target_color, ) # Return figure diff --git a/bayesflow/diagnostics/plots/recovery.py b/bayesflow/diagnostics/plots/recovery.py index a29f58abd..462f06546 100644 --- a/bayesflow/diagnostics/plots/recovery.py +++ b/bayesflow/diagnostics/plots/recovery.py @@ -26,6 +26,7 @@ def recovery( num_row: int = None, xlabel: str = "Ground truth", ylabel: str = "Estimate", + markersize: float = None, **kwargs, ) -> plt.Figure: """ @@ -76,8 +77,10 @@ def recovery( The number of rows for the subplots. Dynamically determined if None. num_col : int, optional, default: None The number of columns for the subplots. Dynamically determined if None. - xlabel: - ylabel: + xlabel : + ylabel : + markersize : float, optional, default: None + The marker size in points. Returns ------- @@ -122,10 +125,18 @@ def recovery( fmt="o", alpha=0.5, color=color, + markersize=markersize, **kwargs, ) else: - _ = ax.scatter(targets[:, i], point_estimate[:, i], alpha=0.5, color=color, **kwargs) + _ = ax.scatter( + targets[:, i], + point_estimate[:, i], + alpha=0.5, + color=color, + s=None if markersize is None else markersize**2, + **kwargs, + ) make_quadratic(ax, targets[:, i], point_estimate[:, i]) diff --git a/bayesflow/diagnostics/plots/recovery_from_estimates.py b/bayesflow/diagnostics/plots/recovery_from_estimates.py index 4aacb0350..403206474 100644 --- a/bayesflow/diagnostics/plots/recovery_from_estimates.py +++ b/bayesflow/diagnostics/plots/recovery_from_estimates.py @@ -25,6 +25,7 @@ def recovery_from_estimates( num_row: int = None, xlabel: str = "Ground truth", ylabel: str = "Estimate", + markersize: float = None, **kwargs, ) -> plt.Figure: """ @@ -79,8 +80,10 @@ def recovery_from_estimates( The number of rows for the subplots. Dynamically determined if None. num_col : int, optional, default: None The number of columns for the subplots. Dynamically determined if None. - xlabel: - ylabel: + xlabel : + ylabel : + markersize : float, optional, default: None + The marker size in points. Returns ------- @@ -139,6 +142,7 @@ def recovery_from_estimates( marker=markers[q_idx], alpha=0.5, color=color, + s=None if markersize is None else markersize**2, **kwargs, ) diff --git a/bayesflow/diagnostics/plots/z_score_contraction.py b/bayesflow/diagnostics/plots/z_score_contraction.py index 702fd7d9c..c013889a0 100644 --- a/bayesflow/diagnostics/plots/z_score_contraction.py +++ b/bayesflow/diagnostics/plots/z_score_contraction.py @@ -18,6 +18,7 @@ def z_score_contraction( color: str = "#132a70", num_col: int = None, num_row: int = None, + markersize: float = None, ) -> plt.Figure: """ Implements a graphical check for global model sensitivity by plotting the @@ -76,6 +77,8 @@ def z_score_contraction( The number of rows for the subplots. Dynamically determined if None. num_col : int, optional, default: None The number of columns for the subplots. Dynamically determined if None. + markersize : float, optional, default: None + The marker size in points**2 of the scatter plot. Returns ------- @@ -118,7 +121,7 @@ def z_score_contraction( if i >= plot_data["num_variables"]: break - ax.scatter(contraction[:, i], z_score[:, i], color=color, alpha=0.5) + ax.scatter(contraction[:, i], z_score[:, i], color=color, alpha=0.5, s=markersize) ax.set_xlim([-0.05, 1.05]) prettify_subplots(plot_data["axes"], num_subplots=plot_data["num_variables"], tick_fontsize=tick_fontsize) diff --git a/bayesflow/utils/plot_utils.py b/bayesflow/utils/plot_utils.py index 0389c029d..398d2d970 100644 --- a/bayesflow/utils/plot_utils.py +++ b/bayesflow/utils/plot_utils.py @@ -374,7 +374,9 @@ def create_legends( label: str = "Posterior", show_single_legend: bool = False, legend_fontsize: int = 14, + markersize: float = 40, target_color: str = "red", + target_markersize: float = 40, ): """ Helper function to create legends for pairplots. @@ -396,8 +398,12 @@ def create_legends( should also display legend legend_fontsize : int, optional, default: 14 fontsize for the legend - target_color : str, optional, default "red" + markersize : float, optional, default: 40 + The marker size in points**2 + target_color : str, optional, default: "red" Color for the target label + target_markersize : float, optional, default: 40 + Marker size in points**2 of the target marker """ handles = [] labels = [] @@ -414,7 +420,15 @@ def create_legends( labels.append(posterior_label) if plot_data.get("targets") is not None: - target_handle = plt.Line2D([0], [0], color=target_color, linestyle="--", marker="x", label="Targets") + target_handle = plt.Line2D( + [0], + [0], + color=target_color, + linestyle="--", + marker="x", + markersize=np.sqrt(target_markersize), + label="Targets", + ) target_label = "Targets" handles.append(target_handle) labels.append(target_label) diff --git a/tests/test_diagnostics/test_diagnostics_plots.py b/tests/test_diagnostics/test_diagnostics_plots.py index 8d4b7883b..952fe4002 100644 --- a/tests/test_diagnostics/test_diagnostics_plots.py +++ b/tests/test_diagnostics/test_diagnostics_plots.py @@ -66,6 +66,19 @@ def test_calibration_ecdf(random_estimates, random_targets, var_names): assert out.axes[-1].title._text == r"$\sigma$" +def test_calibration_ecdf_from_quantiles(random_estimates, random_targets, var_names): + quantile_levels = [0.1, 0.5, 0.9] + + estimates = { + variable_name: {"quantiles": np.moveaxis(np.quantile(value, q=quantile_levels, axis=1), 0, 1)} + for variable_name, value in random_estimates.items() + } + + out = bf.diagnostics.calibration_ecdf_from_quantiles(estimates, random_targets, quantile_levels=quantile_levels) + assert len(out.axes) == num_variables(random_estimates) + assert out.axes[1].title._text == "beta_1" + + def test_calibration_histogram(random_estimates, random_targets): # basic functionality: automatic variable names out = bf.diagnostics.plots.calibration_histogram(random_estimates, random_targets) @@ -81,14 +94,25 @@ def test_loss(history): def test_recovery(random_estimates, random_targets): # basic functionality: automatic variable names - out = bf.diagnostics.plots.recovery(random_estimates, random_targets) + out = bf.diagnostics.plots.recovery(random_estimates, random_targets, markersize=4) + assert len(out.axes) == num_variables(random_estimates) + assert out.axes[2].title._text == "sigma" + + +def test_recovery_from_estimates(random_estimates, random_targets): + # basic functionality: automatic variable names + estimates = {variable_name: {"mean": np.mean(value, axis=1)} for variable_name, value in random_estimates.items()} + + out = bf.diagnostics.plots.recovery_from_estimates( + estimates, random_targets, markersize=4, marker_mapping={"mean": "x"} + ) assert len(out.axes) == num_variables(random_estimates) assert out.axes[2].title._text == "sigma" def test_z_score_contraction(random_estimates, random_targets): # basic functionality: automatic variable names - out = bf.diagnostics.plots.z_score_contraction(random_estimates, random_targets) + out = bf.diagnostics.plots.z_score_contraction(random_estimates, random_targets, markersize=4) assert len(out.axes) == num_variables(random_estimates) assert out.axes[1].title._text == "beta_1" @@ -97,6 +121,7 @@ def test_pairs_samples(random_priors): out = bf.diagnostics.plots.pairs_samples( samples=random_priors, variable_keys=["beta", "sigma"], + markersize=4, ) num_vars = random_priors["sigma"].shape[-1] + random_priors["beta"].shape[-1] assert out.axes.shape == (num_vars, num_vars) @@ -107,9 +132,7 @@ def test_pairs_samples(random_priors): def test_pairs_posterior(random_estimates, random_targets, random_priors): # basic functionality: automatic variable names out = bf.diagnostics.plots.pairs_posterior( - random_estimates, - random_targets, - dataset_id=1, + random_estimates, random_targets, dataset_id=1, markersize=4, target_markersize=4 ) num_vars = num_variables(random_estimates) assert out.axes.shape == (num_vars, num_vars) @@ -139,7 +162,7 @@ def test_pairs_posterior(random_estimates, random_targets, random_priors): def test_mc_calibration(pred_models, true_models, model_names): - out = bf.diagnostics.plots.mc_calibration(pred_models, true_models, model_names=model_names) + out = bf.diagnostics.plots.mc_calibration(pred_models, true_models, model_names=model_names, markersize=4) assert len(out.axes) == pred_models.shape[-1] assert out.axes[0].get_ylabel() == "True Probability" assert out.axes[0].get_xlabel() == "Predicted Probability"