Skip to content

Commit c0a4f25

Browse files
committed
Refactor loss plotting logic and remove unused parameters.
Simplified the loss plotting code by consolidating duplication and aligning the handling of smoothing logic. Removed unused arguments like markers and colormap, reducing potential confusion in the API. Updated comments and improved code readability for maintainability.
1 parent f9f807f commit c0a4f25

File tree

1 file changed

+43
-65
lines changed

1 file changed

+43
-65
lines changed

bayesflow/diagnostics/plots/loss.py

Lines changed: 43 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import keras.src.callbacks
99

10-
from ...utils.plot_utils import make_figure, add_titles_and_labels, add_gradient_plot
10+
from ...utils.plot_utils import make_figure, add_titles_and_labels
1111

1212

1313
def loss(
@@ -17,13 +17,9 @@ def loss(
1717
smoothing_factor: float = 0.8,
1818
figsize: Sequence[float] = None,
1919
train_color: str = "#132a70",
20-
val_color: str = None,
21-
val_colormap: str = "viridis",
20+
val_color: str = "black",
2221
lw_train: float = 2.0,
23-
lw_val: float = 3.0,
24-
marker: bool = True,
25-
val_marker_type: str = ".",
26-
val_marker_size: int = 34,
22+
lw_val: float = 2.0,
2723
grid_alpha: float = 0.2,
2824
legend_fontsize: int = 14,
2925
label_fontsize: int = 14,
@@ -41,29 +37,19 @@ def loss(
4137
The training loss key to look for in the history
4238
val_key : str, optional, default: "val_loss"
4339
The validation loss key to look for in the history
44-
moving_average : bool, optional, default: False
45-
A flag for adding an exponential moving average line of the train_losses.
46-
moving_average_alpha : int, optional, default: 0.8
47-
Smoothing factor for the moving average.
40+
smoothing_factor : float, optional, default: 0.8
41+
If greater than zero, smooth the loss curves by applying an exponential moving average.
4842
figsize : tuple or None, optional, default: None
4943
The figure size passed to the ``matplotlib`` constructor.
5044
Inferred if ``None``
5145
train_color : str, optional, default: '#8f2727'
5246
The color for the train loss trajectory
5347
val_color : str, optional, default: None
5448
The color for the optional validation loss trajectory
55-
val_colormap : str, optional, default: "viridis"
56-
The colormap for the optional validation loss trajectory
5749
lw_train : int, optional, default: 2
5850
The linewidth for the training loss curve
5951
lw_val : int, optional, default: 3
6052
The linewidth for the validation loss curve
61-
marker : bool, optional, default: False
62-
A flag for whether marker should be added in the validation loss trajectory
63-
val_marker_type : str, optional, default: o
64-
The marker type for the validation loss curve
65-
val_marker_size : int, optional, default: 34
66-
The marker size for the validation loss curve
6753
grid_alpha : float, optional, default: 0.2
6854
The transparency of the background grid
6955
legend_fontsize : int, optional, default: 14
@@ -108,68 +94,60 @@ def loss(
10894

10995
# Loop through loss entries and populate plot
11096
for i, ax in enumerate(axes.flat):
111-
# Plot train curve
112-
ax.plot(train_step_index, train_losses.iloc[:, 0], color=train_color, lw=lw_train, alpha=0.05, label="Training")
113-
if moving_average:
114-
smoothed_train_loss = train_losses.iloc[:, 0].ewm(alpha=moving_average_alpha, adjust=True).mean()
115-
ax.plot(train_step_index, smoothed_train_loss, color="grey", lw=lw_train, label="Training (Moving Average)")
97+
if smoothing_factor > 0:
98+
# plot unsmoothed train loss
99+
ax.plot(
100+
train_step_index, train_losses.iloc[:, 0], color=train_color, lw=lw_train, alpha=0.3, label="Training"
101+
)
102+
103+
# plot smoothed train loss
104+
smoothed_train_loss = train_losses.iloc[:, 0].ewm(alpha=1.0 - smoothing_factor, adjust=True).mean()
105+
ax.plot(
106+
train_step_index,
107+
smoothed_train_loss,
108+
color=train_color,
109+
lw=lw_train,
110+
alpha=0.8,
111+
label="Training (Moving Average)",
112+
)
113+
else:
114+
# plot unsmoothed train loss
115+
ax.plot(
116+
train_step_index, train_losses.iloc[:, 0], color=train_color, lw=lw_train, alpha=0.8, label="Training"
117+
)
116118

117119
# Plot optional val curve
118120
if val_losses is not None:
119121
if val_color is not None:
120-
ax.plot(
121-
val_step_index,
122-
val_losses.iloc[:, 0],
123-
linestyle="--",
124-
marker=val_marker_type if marker else None,
125-
color=val_color,
126-
lw=lw_val,
127-
alpha=0.2,
128-
label="Validation",
129-
)
130-
if moving_average:
131-
smoothed_val_loss = val_losses.iloc[:, 0].ewm(alpha=moving_average_alpha, adjust=True).mean()
122+
if smoothing_factor > 0:
123+
# plot unsmoothed val loss
124+
ax.plot(
125+
val_step_index, val_losses.iloc[:, 0], color=val_color, lw=lw_val, alpha=0.3, label="Validation"
126+
)
127+
128+
# plot smoothed val loss
129+
smoothed_val_loss = val_losses.iloc[:, 0].ewm(alpha=1.0 - smoothing_factor, adjust=True).mean()
132130
ax.plot(
133131
val_step_index,
134132
smoothed_val_loss,
135133
color=val_color,
136134
lw=lw_val,
135+
alpha=0.8,
137136
label="Validation (Moving Average)",
138137
)
139-
else:
140-
# Make gradient lines
141-
add_gradient_plot(
142-
val_step_index,
143-
val_losses.iloc[:, 0],
144-
ax,
145-
val_colormap,
146-
lw_val,
147-
marker,
148-
val_marker_type,
149-
val_marker_size,
150-
alpha=0.05,
151-
label="Validation",
152-
)
153-
if moving_average:
154-
smoothed_val_loss = val_losses.iloc[:, 0].ewm(alpha=moving_average_alpha, adjust=True).mean()
155-
add_gradient_plot(
156-
val_step_index,
157-
smoothed_val_loss,
158-
ax,
159-
val_colormap,
160-
lw_val,
161-
marker,
162-
val_marker_type,
163-
val_marker_size,
164-
alpha=1,
165-
label="Validation (Moving Average)",
138+
else:
139+
# plot unsmoothed val loss
140+
ax.plot(
141+
val_step_index, val_losses.iloc[:, 0], color=val_color, lw=lw_val, alpha=0.8, label="Validation"
166142
)
167143

168144
sns.despine(ax=ax)
169145
ax.grid(alpha=grid_alpha)
170146

171-
# Only add legend if there is a validation curve
172-
if val_losses is not None or moving_average:
147+
ax.set_xlim(train_step_index[0], train_step_index[-1])
148+
149+
# Only add the legend if there are multiple curves
150+
if val_losses is not None or smoothing_factor > 0:
173151
ax.legend(fontsize=legend_fontsize)
174152

175153
# Add labels, titles, and set font sizes

0 commit comments

Comments
 (0)