Skip to content

Commit 800ed2c

Browse files
committed
Ensure consistent spacing between plot and legends + cleanup
1 parent bcba960 commit 800ed2c

File tree

3 files changed

+12
-63
lines changed

3 files changed

+12
-63
lines changed

bayesflow/diagnostics/plots/pairs_posterior.py

Lines changed: 0 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -118,47 +118,6 @@ def pairs_posterior(
118118
g, plot_data, color=post_color, color2=prior_color, fontsize=legend_fontsize, show_single_legend=False
119119
)
120120

121-
# target_handle = plt.Line2D(
122-
# [0], [0],
123-
# color=target_color,
124-
# linestyle="--",
125-
# marker="x",
126-
# label="Targets",
127-
# )
128-
#
129-
# for ax in g.axes.flat:
130-
# if getattr(ax, "legend_", None) is not None:
131-
# ax.legend_.remove()
132-
#
133-
# diag_ax = g.axes[0, 0]
134-
# # Collect histogram legend handles
135-
# # hist_handles = getattr(diag_ax, '_legend_handles', [])
136-
#
137-
# hist_handles, hist_labels = diag_ax.get_legend_handles_labels()
138-
#
139-
# handles = []
140-
# labels = []
141-
#
142-
# for handle, label in zip(hist_handles, hist_labels):
143-
#
144-
# if label != "Targets":
145-
# handles.append(handle)
146-
# labels.append(label)
147-
#
148-
# handles.append(target_handle)
149-
# labels.append(target_handle.get_label())
150-
#
151-
# # handles = hist_handles + [target_handle]
152-
# # labels = hist_labels + [target_handle.get_label()] # safer to refresh labels
153-
#
154-
# g.figure.legend(
155-
# handles=handles,
156-
# labels=labels,
157-
# loc="center right",
158-
# frameon=False,
159-
# fontsize=legend_fontsize
160-
# )
161-
162121
return g
163122

164123

bayesflow/diagnostics/plots/pairs_samples.py

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -194,33 +194,22 @@ def _pairs_samples(
194194
return g
195195

196196

197-
# create a histogram plot on a twin y axis
198-
# this ensures that the y scaling of the diagonal plots
199-
# in independent of the y scaling of the off-diagonal plots
200197
def histplot_twinx(x, **kwargs):
201-
# Create a twin axis
202-
# ax2 = plt.gca().twinx()
203-
204-
label = kwargs.pop("labels", None)
205-
color = kwargs.get("colors", None)
206-
207-
ax = plt.gca()
198+
"""
199+
# create a histogram plot on a twin y axis
200+
# this ensures that the y scaling of the diagonal plots
201+
# in independent of the y scaling of the off-diagonal plots
208202
203+
Parameters
204+
----------
205+
x : np.ndarray
206+
Data to be plotted.
207+
"""
209208
# create a histogram on the twin axis
210209
sns.histplot(x, legend=False, **kwargs)
211210

212-
# if label is not None:
213-
# legend_artist = Patch(color=color, label=label)
214-
# # Store the artist for later
215-
# if not hasattr(ax, '_legend_handles'):
216-
# ax._legend_handles = []
217-
# ax._legend_handles.append(legend_artist)
218-
219211
# make the twin axis invisible
220212
plt.gca().spines["right"].set_visible(False)
221213
plt.gca().spines["top"].set_visible(False)
222-
# ax2.set_ylabel("")
223-
# ax2.set_yticks([])
224-
# ax2.set_yticklabels([])
225214

226215
return None

bayesflow/utils/plot_utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -383,12 +383,13 @@ def create_legends(
383383
handles.append(target_handle)
384384
labels.append(target_label)
385385

386+
# If there are more than one dataset to plot,
386387
if len(handles) > 1 or show_single_legend:
387388
g.figure.legend(
388389
handles=handles,
389390
labels=labels,
390-
loc="center right",
391-
bbox_to_anchor=(1.2, 0.5),
391+
loc="center left",
392+
bbox_to_anchor=(1, 0.5),
392393
frameon=False,
393394
fontsize=fontsize,
394395
)

0 commit comments

Comments
 (0)