Skip to content

Commit 8b95883

Browse files
committed
Simplify 3-way nested if
1 parent c175ed4 commit 8b95883

File tree

1 file changed

+28
-27
lines changed

1 file changed

+28
-27
lines changed

bayesflow/diagnostics/plots/loss.py

Lines changed: 28 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -114,42 +114,43 @@ def loss(
114114
label="Training (Moving Average)",
115115
)
116116
else:
117-
# plot unsmoothed train loss
117+
# Plot unsmoothed train loss
118118
ax.plot(
119119
train_step_index, train_losses.iloc[:, 0], color=train_color, lw=lw_train, alpha=0.8, label="Training"
120120
)
121121

122-
# Plot optional val curve
123-
if val_losses is not None:
124-
if val_color is not None:
125-
if smoothing_factor > 0:
126-
# plot unsmoothed val loss
127-
ax.plot(
128-
val_step_index, val_losses.iloc[:, 0], color=val_color, lw=lw_val, alpha=0.3, label="Validation"
129-
)
130-
131-
# plot smoothed val loss
132-
smoothed_val_loss = val_losses.iloc[:, 0].ewm(alpha=1.0 - smoothing_factor, adjust=True).mean()
133-
ax.plot(
134-
val_step_index,
135-
smoothed_val_loss,
136-
color=val_color,
137-
lw=lw_val,
138-
alpha=0.8,
139-
label="Validation (Moving Average)",
140-
)
141-
else:
142-
# plot unsmoothed val loss
143-
ax.plot(
144-
val_step_index, val_losses.iloc[:, 0], color=val_color, lw=lw_val, alpha=0.8, label="Validation"
145-
)
122+
# Only plot if we actually have validation losses and a color assigned
123+
if val_losses is not None and val_color is not None:
124+
alpha_unsmoothed = 0.3 if smoothing_factor > 0 else 0.8
146125

126+
# Plot unsmoothed val loss
127+
ax.plot(
128+
val_step_index,
129+
val_losses.iloc[:, 0],
130+
color=val_color,
131+
lw=lw_val,
132+
alpha=alpha_unsmoothed,
133+
label="Validation",
134+
)
135+
136+
# if requested, plot a second, smoothed curve
137+
if smoothing_factor > 0:
138+
smoothed_val_loss = val_losses.iloc[:, 0].ewm(alpha=1.0 - smoothing_factor, adjust=True).mean()
139+
ax.plot(
140+
val_step_index,
141+
smoothed_val_loss,
142+
color=val_color,
143+
lw=lw_val,
144+
alpha=0.8,
145+
label="Validation (Moving Average)",
146+
)
147+
148+
# rest of the styling
147149
sns.despine(ax=ax)
148150
ax.grid(alpha=grid_alpha)
149-
150151
ax.set_xlim(train_step_index[0], train_step_index[-1])
151152

152-
# Only add the legend if there are multiple curves
153+
# legend only if there's at least one validation curve or smoothing was on
153154
if val_losses is not None or smoothing_factor > 0:
154155
ax.legend(fontsize=legend_fontsize)
155156

0 commit comments

Comments
 (0)