Skip to content

Commit c76304a

Browse files
Merge pull request #102 from levolz/master
set label names in plot_recovery()
2 parents a350ec1 + 876f09d commit c76304a

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

bayesflow/diagnostics.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ def plot_recovery(
5151
color="#8f2727",
5252
n_col=None,
5353
n_row=None,
54+
xlabel="Ground truth",
55+
ylabel="Estimated",
5456
):
5557
"""Creates and plots publication-ready recovery plot with true vs. point estimate + uncertainty.
5658
The point estimate can be controlled with the ``point_agg`` argument, and the uncertainty estimate
@@ -96,7 +98,11 @@ def plot_recovery(
9698
A flag for adding R^2 between true and estimates to the plot
9799
color : str, optional, default: '#8f2727'
98100
The color for the true vs. estimated scatter points and error bars
99-
101+
xlabel : str, optional, default: 'Ground truth'
102+
The label on the x-axis of the plot
103+
ylabel : str, optional, default: 'Estimated'
104+
The label on the y-axis of the plot
105+
100106
Returns
101107
-------
102108
f : plt.Figure - the figure instance for optional saving
@@ -198,15 +204,15 @@ def plot_recovery(
198204
# Only add x-labels to the bottom row
199205
bottom_row = axarr if n_row == 1 else axarr[0] if n_col == 1 else axarr[n_row - 1, :]
200206
for _ax in bottom_row:
201-
_ax.set_xlabel("Ground truth", fontsize=label_fontsize)
207+
_ax.set_xlabel(xlabel, fontsize=label_fontsize)
202208

203209
# Only add y-labels to right left-most row
204210
if n_row == 1: # if there is only one row, the ax array is 1D
205-
axarr[0].set_ylabel("Estimated", fontsize=label_fontsize)
211+
axarr[0].set_ylabel(ylabel, fontsize=label_fontsize)
206212
# If there is more than one row, the ax array is 2D
207213
else:
208214
for _ax in axarr[:, 0]:
209-
_ax.set_ylabel("Estimated", fontsize=label_fontsize)
215+
_ax.set_ylabel(ylabel, fontsize=label_fontsize)
210216

211217
# Remove unused axes entirely
212218
for _ax in axarr_it[n_params:]:

0 commit comments

Comments
 (0)