Skip to content

Commit 8567049

Browse files
committed
minor fixes and improvements to the pairs plot functions
- pass target color to legend - do not use common norm, so that prior stays visible in kdeplots - do not share y on the diagonal, so that all marginal distributions stay visible, even if one is very peaked
1 parent 7cabf17 commit 8567049

File tree

3 files changed

+7
-2
lines changed

3 files changed

+7
-2
lines changed

bayesflow/diagnostics/plots/pairs_posterior.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ def pairs_posterior(
123123
color2=prior_color,
124124
legend_fontsize=legend_fontsize,
125125
show_single_legend=False,
126+
target_color=target_color,
126127
)
127128

128129
return g

bayesflow/diagnostics/plots/pairs_samples.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ def _pairs_samples(
133133
height=height,
134134
hue="_source",
135135
palette=[color2, color],
136+
diag_sharey=False,
136137
**kwargs,
137138
)
138139

@@ -162,7 +163,7 @@ def _pairs_samples(
162163

163164
# add KDEs to the lower diagonal
164165
try:
165-
g.map_lower(sns.kdeplot, fill=True, color=color, alpha=alpha)
166+
g.map_lower(sns.kdeplot, fill=True, color=color, alpha=alpha, common_norm=False)
166167
except Exception as e:
167168
logging.exception("KDE failed due to the following exception:\n" + repr(e) + "\nSubstituting scatter plot.")
168169
g.map_lower(sns.scatterplot, alpha=0.6, s=40, edgecolor="k", color=color, lw=0)

bayesflow/utils/plot_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,6 +374,7 @@ def create_legends(
374374
label: str = "Posterior",
375375
show_single_legend: bool = False,
376376
legend_fontsize: int = 14,
377+
target_color: str = "red",
377378
):
378379
"""
379380
Helper function to create legends for pairplots.
@@ -395,6 +396,8 @@ def create_legends(
395396
should also display legend
396397
legend_fontsize : int, optional, default: 14
397398
fontsize for the legend
399+
target_color : str, optional, default "red"
400+
Color for the target label
398401
"""
399402
handles = []
400403
labels = []
@@ -411,7 +414,7 @@ def create_legends(
411414
labels.append(posterior_label)
412415

413416
if plot_data.get("targets") is not None:
414-
target_handle = plt.Line2D([0], [0], color="r", linestyle="--", marker="x", label="Targets")
417+
target_handle = plt.Line2D([0], [0], color=target_color, linestyle="--", marker="x", label="Targets")
415418
target_label = "Targets"
416419
handles.append(target_handle)
417420
labels.append(target_label)

0 commit comments

Comments
 (0)