Skip to content

Commit 4d429ae

Browse files
authored
set_label names in plot_recovery()
qol change to allow to change label names if wanting something other than "ground truth" / "estimated"
1 parent 2bd8744 commit 4d429ae

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

bayesflow/diagnostics.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ def plot_recovery(
5151
color="#8f2727",
5252
n_col=None,
5353
n_row=None,
54+
xlabel="Ground truth",
55+
ylabel="Estimated",
5456
):
5557
"""Creates and plots publication-ready recovery plot with true vs. point estimate + uncertainty.
5658
The point estimate can be controlled with the ``point_agg`` argument, and the uncertainty estimate
@@ -198,15 +200,15 @@ def plot_recovery(
198200
# Only add x-labels to the bottom row
199201
bottom_row = axarr if n_row == 1 else axarr[0] if n_col == 1 else axarr[n_row - 1, :]
200202
for _ax in bottom_row:
201-
_ax.set_xlabel("Ground truth", fontsize=label_fontsize)
203+
_ax.set_xlabel(xlabel, fontsize=label_fontsize)
202204

203205
# Only add y-labels to right left-most row
204206
if n_row == 1: # if there is only one row, the ax array is 1D
205-
axarr[0].set_ylabel("Estimated", fontsize=label_fontsize)
207+
axarr[0].set_ylabel(ylabel, fontsize=label_fontsize)
206208
# If there is more than one row, the ax array is 2D
207209
else:
208210
for _ax in axarr[:, 0]:
209-
_ax.set_ylabel("Estimated", fontsize=label_fontsize)
211+
_ax.set_ylabel(ylabel, fontsize=label_fontsize)
210212

211213
# Remove unused axes entirely
212214
for _ax in axarr_it[n_params:]:

0 commit comments

Comments
 (0)