Skip to content

Commit f806091

Browse files
committed
Update 2d scatterplots to work with recent seaborn
1 parent 126d354 commit f806091

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

bayesflow/diagnostics.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ def plot_recovery(
144144
if i >= n_params:
145145
break
146146

147-
# Add scatter and errorbars
147+
# Add scatter and error bars
148148
if uncertainty_agg is not None:
149149
_ = ax.errorbar(prior_samples[:, i], est[:, i], yerr=u[:, i], fmt="o", alpha=0.5, color=color)
150150
else:
@@ -950,9 +950,9 @@ def plot_prior2d(prior, param_names=None, n_samples=2000, height=2.5, color="#8f
950950
try:
951951
g.map_lower(sns.kdeplot, fill=True, color=color, alpha=0.9)
952952
except Exception as e:
953-
logging.warn("KDE failed due to the following exception:\n" + repr(e) + "\nSubstituting scatter plot.")
954-
g.map_lower(plt.scatter, alpha=0.6, s=40, edgecolor="k", color=color)
955-
g.map_upper(plt.scatter, alpha=0.6, s=40, edgecolor="k", color=color)
953+
logging.warning("KDE failed due to the following exception:\n" + repr(e) + "\nSubstituting scatter plot.")
954+
g.map_lower(sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color)
955+
g.map_upper(sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color)
956956

957957
# Add grids
958958
for i in range(dim):
@@ -963,7 +963,7 @@ def plot_prior2d(prior, param_names=None, n_samples=2000, height=2.5, color="#8f
963963

964964

965965
def plot_latent_space_2d(z_samples, height=2.5, color="#8f2727", **kwargs):
966-
"""Creates pairplots for the latent space learned by the inference network. Enables
966+
"""Creates pair plots for the latent space learned by the inference network. Enables
967967
visual inspection of the latent space and whether its structure corresponds to the
968968
one enforced by the optimization criterion.
969969
@@ -998,7 +998,7 @@ def plot_latent_space_2d(z_samples, height=2.5, color="#8f2727", **kwargs):
998998
g = sns.PairGrid(data_to_plot, height=height, **kwargs)
999999
g.map_diag(sns.histplot, fill=True, color=color, alpha=0.9, kde=True)
10001000
g.map_lower(sns.kdeplot, fill=True, color=color, alpha=0.9)
1001-
g.map_upper(plt.scatter, alpha=0.6, s=40, edgecolor="k", color=color)
1001+
g.map_upper(sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color)
10021002

10031003
# Add grids
10041004
for i in range(z_dim):

0 commit comments

Comments
 (0)