Commit bc90d05
Merge pull request #422 from bayesflow-org/enhanced-loss
Better Loss Plot
2 parents: 25e632c + 72d9271

File tree: 9 files changed (+203, -39 lines)

bayesflow/diagnostics/plots/loss.py

Lines changed: 57 additions & 32 deletions
@@ -14,14 +14,13 @@ def loss(
     history: keras.callbacks.History,
     train_key: str = "loss",
     val_key: str = "val_loss",
-    moving_average: bool = False,
-    per_training_step: bool = False,
-    ma_window_fraction: float = 0.01,
+    smoothing_factor: float = 0.8,
     figsize: Sequence[float] = None,
     train_color: str = "#132a70",
     val_color: str = "black",
     lw_train: float = 2.0,
-    lw_val: float = 3.0,
+    lw_val: float = 2.0,
+    grid_alpha: float = 0.2,
     legend_fontsize: int = 14,
     label_fontsize: int = 14,
     title_fontsize: int = 16,
@@ -38,24 +37,21 @@ def loss(
         The training loss key to look for in the history
     val_key : str, optional, default: "val_loss"
         The validation loss key to look for in the history
-    moving_average : bool, optional, default: False
-        A flag for adding a moving average line of the train_losses.
-    per_training_step : bool, optional, default: False
-        A flag for making loss trajectory detailed (to training steps) rather than per epoch.
-    ma_window_fraction : int, optional, default: 0.01
-        Window size for the moving average as a fraction of total
-        training steps.
+    smoothing_factor : float, optional, default: 0.8
+        If greater than zero, smooth the loss curves by applying an exponential moving average.
     figsize : tuple or None, optional, default: None
         The figure size passed to the ``matplotlib`` constructor.
         Inferred if ``None``
     train_color : str, optional, default: '#8f2727'
         The color for the train loss trajectory
-    val_color : str, optional, default: black
+    val_color : str, optional, default: None
         The color for the optional validation loss trajectory
     lw_train : int, optional, default: 2
         The linewidth for the training loss curve
     lw_val : int, optional, default: 3
         The linewidth for the validation loss curve
+    grid_alpha : float, optional, default: 0.2
+        The transparency of the background grid
     legend_fontsize : int, optional, default: 14
         The font size of the legend text
     label_fontsize : int, optional, default: 14
@@ -98,31 +94,60 @@ def loss(
 
     # Loop through loss entries and populate plot
     for i, ax in enumerate(axes.flat):
-        # Plot train curve
-        ax.plot(train_step_index, train_losses.iloc[:, i], color=train_color, lw=lw_train, alpha=0.9, label="Training")
-        if moving_average and train_losses.columns[i] == "Loss":
-            moving_average_window = int(train_losses.shape[0] * ma_window_fraction)
-            smoothed_loss = train_losses.iloc[:, i].rolling(window=moving_average_window).mean()
-            ax.plot(train_step_index, smoothed_loss, color="grey", lw=lw_train, label="Training (Moving Average)")
+        if smoothing_factor > 0:
+            # plot unsmoothed train loss
+            ax.plot(
+                train_step_index, train_losses.iloc[:, 0], color=train_color, lw=lw_train, alpha=0.3, label="Training"
+            )
+
+            # plot smoothed train loss
+            smoothed_train_loss = train_losses.iloc[:, 0].ewm(alpha=1.0 - smoothing_factor, adjust=True).mean()
+            ax.plot(
+                train_step_index,
+                smoothed_train_loss,
+                color=train_color,
+                lw=lw_train,
+                alpha=0.8,
+                label="Training (Moving Average)",
+            )
+        else:
+            # plot unsmoothed train loss
+            ax.plot(
+                train_step_index, train_losses.iloc[:, 0], color=train_color, lw=lw_train, alpha=0.8, label="Training"
+            )
 
         # Plot optional val curve
         if val_losses is not None:
-            if i < val_losses.shape[1]:
-                ax.plot(
-                    val_step_index,
-                    val_losses.iloc[:, i],
-                    linestyle="--",
-                    marker="o",
-                    color=val_color,
-                    lw=lw_val,
-                    label="Validation",
-                )
+            if val_color is not None:
+                if smoothing_factor > 0:
+                    # plot unsmoothed val loss
+                    ax.plot(
+                        val_step_index, val_losses.iloc[:, 0], color=val_color, lw=lw_val, alpha=0.3, label="Validation"
+                    )
+
+                    # plot smoothed val loss
+                    smoothed_val_loss = val_losses.iloc[:, 0].ewm(alpha=1.0 - smoothing_factor, adjust=True).mean()
+                    ax.plot(
+                        val_step_index,
+                        smoothed_val_loss,
+                        color=val_color,
+                        lw=lw_val,
+                        alpha=0.8,
+                        label="Validation (Moving Average)",
+                    )
+                else:
+                    # plot unsmoothed val loss
+                    ax.plot(
+                        val_step_index, val_losses.iloc[:, 0], color=val_color, lw=lw_val, alpha=0.8, label="Validation"
+                    )
 
         sns.despine(ax=ax)
-        ax.grid(alpha=0.5)
+        ax.grid(alpha=grid_alpha)
 
-        # Only add legend if there is a validation curve
-        if val_losses is not None or moving_average:
+        ax.set_xlim(train_step_index[0], train_step_index[-1])
+
+        # Only add the legend if there are multiple curves
+        if val_losses is not None or smoothing_factor > 0:
             ax.legend(fontsize=legend_fontsize)
 
     # Add labels, titles, and set font sizes
@@ -131,7 +156,7 @@ def loss(
         num_row=num_row,
         num_col=1,
         title=["Loss Trajectory"],
-        xlabel="Training step #" if per_training_step else "Training epoch #",
+        xlabel="Training epoch #",
         ylabel="Value",
         title_fontsize=title_fontsize,
         label_fontsize=label_fontsize,
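The old windowed rolling(...).mean() gives way to an exponential moving average. As a reference for what ewm(alpha=1.0 - smoothing_factor, adjust=True).mean() computes, here is a minimal numpy sketch; the ewm_mean helper name is ours, for illustration only:

import numpy as np

def ewm_mean(x, smoothing_factor: float = 0.8):
    # mirrors pd.Series(x).ewm(alpha=1.0 - smoothing_factor, adjust=True).mean()
    y = np.empty(len(x), dtype=float)
    num = den = 0.0
    for t, xt in enumerate(x):
        num = smoothing_factor * num + xt   # weighted sum of all points seen so far
        den = smoothing_factor * den + 1.0  # matching sum of weights (bias correction)
        y[t] = num / den
    return y

With the default smoothing_factor of 0.8, each point's weight decays by a factor of 0.8 per step, so the most recent ten points or so carry roughly 90% of the total weight.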

bayesflow/utils/plot_utils.py

Lines changed: 98 additions & 0 deletions
@@ -4,6 +4,11 @@
 import matplotlib.pyplot as plt
 import seaborn as sns
 
+from matplotlib.collections import LineCollection
+from matplotlib.colors import Normalize
+from matplotlib.patches import Rectangle
+from matplotlib.legend_handler import HandlerPatch
+
 from .validators import check_estimates_prior_shapes
 from .dict_utils import dicts_to_arrays
 

@@ -260,3 +265,96 @@ def make_quadratic(ax: plt.Axes, x_data: np.ndarray, y_data: np.ndarray):
         alpha=0.9,
         linestyle="dashed",
     )
+
+
+def gradient_line(x, y, c=None, cmap: str = "viridis", lw: float = 2.0, alpha: float = 1, ax=None):
+    """
+    Plot a 1D line with a color gradient determined by `c` (same shape as x and y).
+    """
+    if ax is None:
+        ax = plt.gca()
+
+    # Default color value = y
+    if c is None:
+        c = y
+
+    # Create segments for LineCollection
+    points = np.array([x, y]).T.reshape(-1, 1, 2)
+    segments = np.concatenate([points[:-1], points[1:]], axis=1)
+
+    norm = Normalize(np.min(c), np.max(c))
+    lc = LineCollection(segments, array=c, cmap=cmap, norm=norm, linewidth=lw, alpha=alpha)
+
+    ax.add_collection(lc)
+    ax.set_xlim(np.min(x), np.max(x))
+    ax.set_ylim(np.min(y), np.max(y))
+    return lc
+
+
+def gradient_legend(ax, label, cmap, norm, loc="upper right"):
+    """
+    Adds a single gradient swatch to the legend of the given Axes.
+
+    Parameters
+    ----------
+    ax    : matplotlib Axes to attach the legend to
+    label : str, label to display in the legend
+    cmap  : matplotlib colormap
+    norm  : matplotlib Normalize object
+    loc   : legend location (default 'upper right')
+    """
+
+    # Custom dummy handle to represent the gradient
+    class _GradientSwatch(Rectangle):
+        pass
+
+    # Custom legend handler that draws a horizontal gradient
+    class _HandlerGradient(HandlerPatch):
+        def create_artists(self, legend, orig_handle, xdescent, ydescent, width, height, fontsize, trans):
+            gradient = np.linspace(0, 1, 256).reshape(1, -1)
+            im = ax.imshow(
+                gradient,
+                aspect="auto",
+                extent=[xdescent, xdescent + width, ydescent, ydescent + height],
+                transform=trans,
+                cmap=cmap,
+                norm=norm,
+            )
+            return [im]
+
+    # Add to existing legend entries
+    handles, labels = ax.get_legend_handles_labels()
+    handles.append(_GradientSwatch((0, 0), 1, 1))
+    labels.append(label)
+
+    ax.legend(handles=handles, labels=labels, loc=loc, handler_map={_GradientSwatch: _HandlerGradient()})
+
+
+def add_gradient_plot(
+    x,
+    y,
+    ax,
+    cmap: str = "viridis",
+    lw: float = 3.0,
+    marker: bool = True,
+    marker_type: str = "o",
+    marker_size: int = 34,
+    alpha: float = 1,
+    label: str = "Validation",
+):
+    gradient_line(x, y, c=x, cmap=cmap, lw=lw, alpha=alpha, ax=ax)
+
+    # Optionally add markers
+    if marker:
+        ax.scatter(
+            x,
+            y,
+            c=x,
+            cmap=cmap,
+            marker=marker_type,
+            s=marker_size,
+            zorder=10,
+            edgecolors="none",
+            label=label,
+            alpha=0.01,  # nearly transparent; the gradient line carries the color signal
+        )
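A minimal usage sketch for the two new helpers, assuming they stay importable from bayesflow.utils.plot_utils as added here; the toy curve is made up for illustration:

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import Normalize

from bayesflow.utils.plot_utils import gradient_line, gradient_legend

x = np.linspace(0.0, 1.0, 200)
y = (1.0 - x) ** 2  # toy loss-like curve

fig, ax = plt.subplots()
gradient_line(x, y, c=x, cmap="viridis", lw=2.0, ax=ax)  # color encodes progress along x
gradient_legend(ax, label="Validation", cmap="viridis", norm=Normalize(0.0, 1.0))
plt.show()

gradient_line splits the curve into one segment per consecutive point pair and hands them to a LineCollection, which is what makes the per-segment coloring possible.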

bayesflow/workflows/basic_workflow.py

Lines changed: 1 addition & 1 deletion
@@ -348,7 +348,7 @@ def plot_default_diagnostics(
     ) -> dict[str, plt.Figure]:
         """
         Generates default diagnostic plots to evaluate the quality of inference. The function produces several
-        diagnostic plots, including:
+        diagnostic plots, including
         - Loss history (if training history is available).
         - Parameter recovery plots.
         - Calibration ECDF plots.

tests/conftest.py

Lines changed: 12 additions & 2 deletions
@@ -1,6 +1,6 @@
-import logging
-
 import keras
+import logging
+import matplotlib
 import pytest
 
 BACKENDS = ["jax", "numpy", "tensorflow", "torch"]
@@ -17,6 +17,16 @@ def pytest_runtest_setup(item):
     if test_backends and backend not in test_backends:
         pytest.skip(f"Skipping backend '{backend}' for test {item}, which is registered for backends {test_backends}.")
 
+    # use a non-GUI plotting backend for tests
+    matplotlib.use("Agg")
+
+
+def pytest_runtest_teardown(item, nextitem):
+    import matplotlib.pyplot as plt
+
+    # close all plots at the end of each test
+    plt.close("all")
+
 
 def pytest_make_parametrize_id(config, val, argname):
     return f"{argname}={repr(val)}"

tests/test_diagnostics/conftest.py

Lines changed: 17 additions & 0 deletions
@@ -1,3 +1,4 @@
+import keras
 import numpy as np
 import pytest
 from bayesflow.utils.numpy_utils import softmax
@@ -61,3 +62,19 @@ def pred_models(true_models):
     pred_models = np.random.normal(loc=true_models)
     pred_models = softmax(pred_models, axis=-1)
     return pred_models
+
+
+@pytest.fixture()
+def history():
+    h = keras.callbacks.History()
+
+    step = np.linspace(0, 1, 10_000)
+    train_loss = (1.0 - step) ** 2 + np.random.normal(loc=0, scale=0.02, size=step.shape)
+    validation_loss = 0.1 + (0.75 - step) ** 2 + np.random.normal(loc=0, scale=0.02, size=step.shape)
+
+    h.history = {
+        "loss": train_loss.tolist(),
+        "val_loss": validation_loss.tolist(),
+    }
+
+    return h
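The fixture fabricates a recognizable overfitting run: the training loss decays quadratically to zero over the 10,000 steps, while the validation loss sits 0.1 higher, bottoms out about three quarters of the way through, and climbs again, so the smoothed loss plot has something realistic to render in the test below.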

tests/test_diagnostics/test_diagnostics_plots.py

Lines changed: 15 additions & 0 deletions
@@ -6,6 +6,15 @@ def num_variables(x: dict):
     return sum(arr.shape[-1] for arr in x.values())
 
 
+def test_backend():
+    import matplotlib.pyplot as plt
+
+    # if the local plotting backend is not Agg, you may
+    # run into issues once the workflow tests run on GitHub,
+    # since the CI runners use the Agg backend
+    assert plt.get_backend() == "Agg"
+
+
 def test_calibration_ecdf(random_estimates, random_targets, var_names):
     # basic functionality: automatic variable names
     out = bf.diagnostics.plots.calibration_ecdf(random_estimates, random_targets)
@@ -45,6 +54,12 @@ def test_calibration_histogram(random_estimates, random_targets):
     assert out.axes[0].title._text == "beta_0"
 
 
+def test_loss(history):
+    out = bf.diagnostics.loss(history)
+    assert len(out.axes) == 1
+    assert out.axes[0].title._text == "Loss Trajectory"
+
+
 def test_recovery(random_estimates, random_targets):
     # basic functionality: automatic variable names
     out = bf.diagnostics.plots.recovery(random_estimates, random_targets)

tests/test_networks/test_inference_networks.py

Lines changed: 1 addition & 1 deletion
@@ -98,7 +98,7 @@ def f(x):
     numerical_output, numerical_jacobian = jacobian(f, random_samples, return_output=True)
 
     # output should be identical, otherwise this test does not work (e.g. for stochastic networks)
-    assert keras.ops.all(keras.ops.isclose(output, numerical_output))
+    assert_allclose(output, numerical_output)
 
     log_prob = generative_inference_network.base_distribution.log_prob(output)

tests/test_networks/test_summary_networks.py

Lines changed: 1 addition & 2 deletions
@@ -49,9 +49,8 @@ def test_variable_set_size(summary_network, random_set):
     summary_network.build(keras.ops.shape(random_set))
 
     # run with another set size
-    for _ in range(10):
+    for s in [3, 4, 5]:
         b = keras.ops.shape(random_set)[0]
-        s = np.random.randint(1, 10)
         new_input = keras.ops.zeros((b, s, keras.ops.shape(random_set)[2]))
         summary_network(new_input)

tests/test_workflows/test_basic_workflow.py

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 import bayesflow as bf
 
 
-def test_classifier_two_sample_test(inference_network, summary_network):
+def test_basic_workflow(inference_network, summary_network):
     workflow = bf.BasicWorkflow(
         inference_network=inference_network,
         summary_network=summary_network,
