Skip to content

Commit 4e9d475

Browse files
committed
Fix legend bug in plot_losses without validation loss
1 parent f806091 commit 4e9d475

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

bayesflow/diagnostics.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def plot_recovery(
7171
The posterior draws obtained from n_data_sets
7272
prior_samples : np.ndarray of shape (n_data_sets, n_params)
7373
The prior draws (true parameters) obtained for generating the n_data_sets
74-
point_agg : callable, optional, default: np.median
74+
point_agg : callable, optional, default: ``np.median``
7575
The function to apply to the posterior draws to get a point estimate for each marginal.
7676
The default computes the marginal median for each marginal posterior as a robust
7777
point estimate.
@@ -89,7 +89,7 @@ def plot_recovery(
8989
metric_fontsize : int, optional, default: 16
9090
The font size of the goodness-of-fit metric (if provided)
9191
tick_fontsize : int, optional, default: 12
92-
The font size of the axis ticklabels
92+
The font size of the axis tick labels
9393
add_corr : bool, optional, default: True
9494
A flag for adding correlation between true and estimates to the plot
9595
add_r2 : bool, optional, default: True
@@ -242,7 +242,7 @@ def plot_z_score_contraction(
242242
243243
post_contraction = 1 - (posterior_variance / prior_variance)
244244
245-
In other words, the posterior is a proxy for the reduction in ucnertainty gained by
245+
In other words, the posterior is a proxy for the reduction in uncertainty gained by
246246
replacing the prior with the posterior. The ideal posterior contraction tends to 1.
247247
Contraction near zero indicates that the posterior variance is almost identical to
248248
the prior variance for the particular marginal parameter distribution.
@@ -894,7 +894,7 @@ def plot_losses(
894894
ax.grid(alpha=grid_alpha)
895895
ax.set_title(train_losses.columns[i], fontsize=title_fontsize)
896896
# Only add legend if there is a validation curve
897-
if val_losses is not None:
897+
if val_losses is not None or moving_average:
898898
ax.legend(fontsize=legend_fontsize)
899899
f.tight_layout()
900900
return f

0 commit comments

Comments
 (0)