Skip to content

Commit f1f19fd

Browse files
committed
Improve plot_losses y-axis clarity
1 parent 28db526 commit f1f19fd

File tree

1 file changed

+13
-8
lines changed

1 file changed

+13
-8
lines changed

bayesflow/diagnostics.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def plot_recovery(
5353
n_row=None,
5454
xlabel="Ground truth",
5555
ylabel="Estimated",
56-
**kwargs
56+
**kwargs,
5757
):
5858
"""Creates and plots publication-ready recovery plot with true vs. point estimate + uncertainty.
5959
The point estimate can be controlled with the ``point_agg`` argument, and the uncertainty estimate
@@ -110,7 +110,7 @@ def plot_recovery(
110110
**kwargs : optional
111111
Additional keyword arguments passed to ax.errorbar or ax.scatter.
112112
Example: `rasterized=True` to reduce PDF file size with many dots
113-
113+
114114
Returns
115115
-------
116116
f : plt.Figure - the figure instance for optional saving
@@ -240,7 +240,7 @@ def plot_z_score_contraction(
240240
tick_fontsize=12,
241241
color="#8f2727",
242242
n_col=None,
243-
n_row=None
243+
n_row=None,
244244
):
245245
"""Implements a graphical check for global model sensitivity by plotting the posterior
246246
z-score over the posterior contraction for each set of posterior samples in ``post_samples``
@@ -567,7 +567,7 @@ def plot_sbc_histograms(
567567
tick_fontsize=12,
568568
hist_color="#a34f4f",
569569
n_row=None,
570-
n_col=None
570+
n_col=None,
571571
):
572572
"""Creates and plots publication-ready histograms of rank statistics for simulation-based calibration
573573
(SBC) checks according to [1].
@@ -929,7 +929,7 @@ def plot_losses(
929929
)
930930
# Schmuck
931931
ax.set_xlabel("Training step #", fontsize=label_fontsize)
932-
ax.set_ylabel("Loss value", fontsize=label_fontsize)
932+
ax.set_ylabel("Value", fontsize=label_fontsize)
933933
sns.despine(ax=ax)
934934
ax.grid(alpha=grid_alpha)
935935
ax.set_title(train_losses.columns[i], fontsize=title_fontsize)
@@ -1061,7 +1061,7 @@ def plot_calibration_curves(
10611061
fig_size=None,
10621062
color="#8f2727",
10631063
n_row=None,
1064-
n_col=None
1064+
n_col=None,
10651065
):
10661066
"""Plots the calibration curves, the ECEs and the marginal histograms of predicted posterior model probabilities
10671067
for a model comparison problem. The marginal histograms inform about the fraction of predictions in each bin.
@@ -1114,7 +1114,6 @@ def plot_calibration_curves(
11141114
elif n_row is not None and n_col is None:
11151115
n_col = int(np.ceil(num_models / n_row))
11161116

1117-
11181117
# Compute calibration
11191118
cal_errs, probs_true, probs_pred = expected_calibration_error(true_models, pred_models, num_bins)
11201119

@@ -1273,7 +1272,13 @@ def plot_confusion_matrix(
12731272
for i in range(cm.shape[0]):
12741273
for j in range(cm.shape[1]):
12751274
ax.text(
1276-
j, i, format(cm[i, j], fmt), fontsize=value_fontsize, ha="center", va="center", color="white" if cm[i, j] > thresh else "black"
1275+
j,
1276+
i,
1277+
format(cm[i, j], fmt),
1278+
fontsize=value_fontsize,
1279+
ha="center",
1280+
va="center",
1281+
color="white" if cm[i, j] > thresh else "black",
12771282
)
12781283
if title:
12791284
ax.set_title("Confusion Matrix", fontsize=title_fontsize)

0 commit comments

Comments
 (0)