
Commit 8b6769d

Merge branch 'dev' of https://github.com/stefanradev93/BayesFlow into dev

2 parents c6a16e3 + bca8225

File tree

9 files changed: +205 −28 lines changed

bayesflow/diagnostics/plots/loss.py

Lines changed: 59 additions & 21 deletions

@@ -15,11 +15,13 @@ def loss(
     train_key: str = "loss",
     val_key: str = "val_loss",
     per_training_step: bool = False,
+    smoothing_factor: float = 0.8,
     figsize: Sequence[float] = None,
     train_color: str = "#132a70",
     val_color: str = "black",
-    lw_train: float = 2.5,
-    lw_val: float = 2.5,
+    lw_train: float = 2.0,
+    lw_val: float = 2.0,
+    grid_alpha: float = 0.2,
     legend_fontsize: int = 14,
     label_fontsize: int = 14,
     title_fontsize: int = 16,

@@ -38,17 +40,21 @@ def loss(
         The validation loss key to look for in the history
     per_training_step : bool, optional, default: False
         A flag for making loss trajectory detailed (to training steps) rather than per epoch.
+    smoothing_factor : float, optional, default: 0.8
+        If greater than zero, smooth the loss curves by applying an exponential moving average.
     figsize : tuple or None, optional, default: None
         The figure size passed to the ``matplotlib`` constructor.
         Inferred if ``None``
-    train_color : str, optional, default: '#8f2727'
+    train_color : str, optional, default: '#132a70'
         The color for the train loss trajectory
-    val_color : str, optional, default: black
+    val_color : str, optional, default: None
         The color for the optional validation loss trajectory
-    lw_train : int, optional, default: 1
+    lw_train : int, optional, default: 2
         The linewidth for the training loss curve
     lw_val : int, optional, default: 2
         The linewidth for the validation loss curve
+    grid_alpha : float, optional, default: 0.2
+        The transparency of the background grid
     legend_fontsize : int, optional, default: 14
         The font size of the legend text
     label_fontsize : int, optional, default: 14

@@ -91,28 +97,60 @@ def loss(

     # Loop through loss entries and populate plot
     for i, ax in enumerate(axes.flat):
-        # Plot train curve
-        ax.plot(train_step_index, train_losses.iloc[:, i], color=train_color, lw=lw_train, alpha=0.9, label="Training")
+        if smoothing_factor > 0:
+            # plot unsmoothed train loss
+            ax.plot(
+                train_step_index, train_losses.iloc[:, 0], color=train_color, lw=lw_train, alpha=0.3, label="Training"
+            )
+
+            # plot smoothed train loss
+            smoothed_train_loss = train_losses.iloc[:, 0].ewm(alpha=1.0 - smoothing_factor, adjust=True).mean()
+            ax.plot(
+                train_step_index,
+                smoothed_train_loss,
+                color=train_color,
+                lw=lw_train,
+                alpha=0.8,
+                label="Training (Moving Average)",
+            )
+        else:
+            # plot unsmoothed train loss
+            ax.plot(
+                train_step_index, train_losses.iloc[:, 0], color=train_color, lw=lw_train, alpha=0.8, label="Training"
+            )

         # Plot optional val curve
         if val_losses is not None:
-            if i < val_losses.shape[1]:
-                ax.plot(
-                    val_step_index,
-                    val_losses.iloc[:, i],
-                    linestyle="--",
-                    marker="o",
-                    markersize=5,
-                    color=val_color,
-                    lw=lw_val,
-                    label="Validation",
-                )
+            if val_color is not None:
+                if smoothing_factor > 0:
+                    # plot unsmoothed val loss
+                    ax.plot(
+                        val_step_index, val_losses.iloc[:, 0], color=val_color, lw=lw_val, alpha=0.3, label="Validation"
+                    )
+
+                    # plot smoothed val loss
+                    smoothed_val_loss = val_losses.iloc[:, 0].ewm(alpha=1.0 - smoothing_factor, adjust=True).mean()
+                    ax.plot(
+                        val_step_index,
+                        smoothed_val_loss,
+                        color=val_color,
+                        lw=lw_val,
+                        alpha=0.8,
+                        label="Validation (Moving Average)",
+                    )
+                else:
+                    # plot unsmoothed val loss
+                    ax.plot(
+                        val_step_index, val_losses.iloc[:, 0], color=val_color, lw=lw_val, alpha=0.8, label="Validation"
+                    )

         sns.despine(ax=ax)
-        ax.grid(alpha=0.5)
+        ax.grid(alpha=grid_alpha)

-        # Only add legend if there is a validation curve
-        if val_losses is not None:
+        ax.set_xlim(train_step_index[0], train_step_index[-1])
+
+        # Only add the legend if there are multiple curves
+        if val_losses is not None or smoothing_factor > 0:
             ax.legend(fontsize=legend_fontsize)

     # Add labels, titles, and set font sizes
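
The smoothing introduced here is a plain pandas exponential moving average: with `alpha = 1 - smoothing_factor`, a smoothing_factor close to 1 weights history more heavily. A minimal standalone sketch of the same call, on toy data with hypothetical names:

    import numpy as np
    import pandas as pd

    # noisy toy loss curve, standing in for one column of `train_losses`
    rng = np.random.default_rng(0)
    steps = np.arange(500)
    loss_values = np.exp(-steps / 200) + rng.normal(scale=0.05, size=steps.size)

    # same call as in the diff: pandas' `alpha` is the weight of the newest
    # observation, so 1.0 - smoothing_factor means heavier smoothing for
    # smoothing_factor near 1
    smoothing_factor = 0.8
    smoothed = pd.Series(loss_values).ewm(alpha=1.0 - smoothing_factor, adjust=True).mean()
    print(smoothed.tail())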

bayesflow/utils/plot_utils.py

Lines changed: 98 additions & 0 deletions

@@ -4,6 +4,11 @@
 import matplotlib.pyplot as plt
 import seaborn as sns

+from matplotlib.collections import LineCollection
+from matplotlib.colors import Normalize
+from matplotlib.patches import Rectangle
+from matplotlib.legend_handler import HandlerPatch
+
 from .validators import check_estimates_prior_shapes
 from .dict_utils import dicts_to_arrays

@@ -260,3 +265,96 @@ def make_quadratic(ax: plt.Axes, x_data: np.ndarray, y_data: np.ndarray):
         alpha=0.9,
         linestyle="dashed",
     )
+
+
+def gradient_line(x, y, c=None, cmap: str = "viridis", lw: float = 2.0, alpha: float = 1, ax=None):
+    """
+    Plot a 1D line with color gradient determined by `c` (same shape as x and y).
+    """
+    if ax is None:
+        ax = plt.gca()
+
+    # Default color value = y
+    if c is None:
+        c = y
+
+    # Create segments for LineCollection
+    points = np.array([x, y]).T.reshape(-1, 1, 2)
+    segments = np.concatenate([points[:-1], points[1:]], axis=1)
+
+    norm = Normalize(np.min(c), np.max(c))
+    lc = LineCollection(segments, array=c, cmap=cmap, norm=norm, linewidth=lw, alpha=alpha)
+
+    ax.add_collection(lc)
+    ax.set_xlim(np.min(x), np.max(x))
+    ax.set_ylim(np.min(y), np.max(y))
+    return lc
+
+
+def gradient_legend(ax, label, cmap, norm, loc="upper right"):
+    """
+    Adds a single gradient swatch to the legend of the given Axes.
+
+    Parameters
+    ----------
+    - ax: matplotlib Axes
+    - label: str, label to display in the legend
+    - cmap: matplotlib colormap
+    - norm: matplotlib Normalize object
+    - loc: legend location (default 'upper right')
+    """
+
+    # Custom dummy handle to represent the gradient
+    class _GradientSwatch(Rectangle):
+        pass
+
+    # Custom legend handler that draws a horizontal gradient
+    class _HandlerGradient(HandlerPatch):
+        def create_artists(self, legend, orig_handle, xdescent, ydescent, width, height, fontsize, trans):
+            gradient = np.linspace(0, 1, 256).reshape(1, -1)
+            im = ax.imshow(
+                gradient,
+                aspect="auto",
+                extent=[xdescent, xdescent + width, ydescent, ydescent + height],
+                transform=trans,
+                cmap=cmap,
+                norm=norm,
+            )
+            return [im]
+
+    # Add to existing legend entries
+    handles, labels = ax.get_legend_handles_labels()
+    handles.append(_GradientSwatch((0, 0), 1, 1))
+    labels.append(label)
+
+    ax.legend(handles=handles, labels=labels, loc=loc, handler_map={_GradientSwatch: _HandlerGradient()})
+
+
+def add_gradient_plot(
+    x,
+    y,
+    ax,
+    cmap: str = "viridis",
+    lw: float = 3.0,
+    marker: bool = True,
+    marker_type: str = "o",
+    marker_size: int = 34,
+    alpha: float = 1,
+    label: str = "Validation",
+):
+    gradient_line(x, y, c=x, cmap=cmap, lw=lw, alpha=alpha, ax=ax)
+
+    # Optionally add markers
+    if marker:
+        ax.scatter(
+            x,
+            y,
+            c=x,
+            cmap=cmap,
+            marker=marker_type,
+            s=marker_size,
+            zorder=10,
+            edgecolors="none",
+            label=label,
+            alpha=0.01,
+        )
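
For orientation, a hypothetical usage sketch of the two new helpers. It assumes they can be imported from bayesflow.utils.plot_utils (the module this diff touches); the output filename is made up:

    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib.colors import Normalize

    from bayesflow.utils.plot_utils import gradient_line, gradient_legend

    x = np.linspace(0, 1, 200)
    y = (1.0 - x) ** 2  # toy, monotonically decreasing "loss"

    fig, ax = plt.subplots()
    # color encodes progress along x, from the start to the end of training
    gradient_line(x, y, c=x, cmap="viridis", lw=2.5, ax=ax)
    gradient_legend(ax, label="Validation", cmap="viridis", norm=Normalize(0, 1))
    fig.savefig("gradient_demo.png")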

bayesflow/workflows/basic_workflow.py

Lines changed: 1 addition & 1 deletion

@@ -348,7 +348,7 @@ def plot_default_diagnostics(
 ) -> dict[str, plt.Figure]:
     """
     Generates default diagnostic plots to evaluate the quality of inference. The function produces several
-    diagnostic plots, including:
+    diagnostic plots, including
     - Loss history (if training history is available).
     - Parameter recovery plots.
     - Calibration ECDF plots.

tests/conftest.py

Lines changed: 12 additions & 2 deletions

@@ -1,6 +1,6 @@
-import logging
-
 import keras
+import logging
+import matplotlib
 import pytest

 BACKENDS = ["jax", "numpy", "tensorflow", "torch"]

@@ -25,6 +25,16 @@ def pytest_runtest_setup(item):

     jax.config.update("jax_traceback_filtering", "off")

+    # use a non-GUI plotting backend for tests
+    matplotlib.use("Agg")
+
+
+def pytest_runtest_teardown(item, nextitem):
+    import matplotlib.pyplot as plt
+
+    # close all plots at the end of each test
+    plt.close("all")
+

 def pytest_make_parametrize_id(config, val, argname):
     return f"{argname}={repr(val)}"
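
The Agg backend renders figures to memory without a display, which is what headless CI runners need. A minimal smoke test of the idea (output filename hypothetical):

    import matplotlib

    # must be selected before any figure is created
    matplotlib.use("Agg")

    import matplotlib.pyplot as plt

    fig, ax = plt.subplots()
    ax.plot([0, 1], [0, 1])
    fig.savefig("smoke.png")  # rendering works headlessly; no window opens
    plt.close(fig)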

tests/test_diagnostics/conftest.py

Lines changed: 17 additions & 0 deletions

@@ -1,3 +1,4 @@
+import keras
 import numpy as np
 import pytest
 from bayesflow.utils.numpy_utils import softmax

@@ -61,3 +62,19 @@ def pred_models(true_models):
     pred_models = np.random.normal(loc=true_models)
     pred_models = softmax(pred_models, axis=-1)
     return pred_models
+
+
+@pytest.fixture()
+def history():
+    h = keras.callbacks.History()
+
+    step = np.linspace(0, 1, 10_000)
+    train_loss = (1.0 - step) ** 2 + np.random.normal(loc=0, scale=0.02, size=step.shape)
+    validation_loss = 0.1 + (0.75 - step) ** 2 + np.random.normal(loc=0, scale=0.02, size=step.shape)
+
+    h.history = {
+        "loss": train_loss.tolist(),
+        "val_loss": validation_loss.tolist(),
+    }
+
+    return h
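
The fixture fabricates what keras.Model.fit hands back: a History callback whose .history dict maps metric names to lists of values. For comparison, a tiny real fit (toy model and shapes, loss values illustrative):

    import keras
    import numpy as np

    model = keras.Sequential([keras.layers.Dense(1)])
    model.compile(optimizer="sgd", loss="mse")
    h = model.fit(np.zeros((8, 4)), np.zeros((8, 1)), epochs=2, verbose=0)
    print(h.history["loss"])  # per-epoch loss values, e.g. [0.0, 0.0]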

tests/test_diagnostics/test_diagnostics_plots.py

Lines changed: 15 additions & 0 deletions

@@ -6,6 +6,15 @@ def num_variables(x: dict):
     return sum(arr.shape[-1] for arr in x.values())


+def test_backend():
+    import matplotlib.pyplot as plt
+
+    # if the local testing backend is not Agg
+    # then you may run into issues once you run workflow tests
+    # on GitHub, since these use the Agg backend
+    assert plt.get_backend() == "Agg"
+
+
 def test_calibration_ecdf(random_estimates, random_targets, var_names):
     # basic functionality: automatic variable names
     out = bf.diagnostics.plots.calibration_ecdf(random_estimates, random_targets)

@@ -45,6 +54,12 @@ def test_calibration_histogram(random_estimates, random_targets):
     assert out.axes[0].title._text == "beta_0"


+def test_loss(history):
+    out = bf.diagnostics.loss(history)
+    assert len(out.axes) == 1
+    assert out.axes[0].title._text == "Loss Trajectory"
+
+
 def test_recovery(random_estimates, random_targets):
     # basic functionality: automatic variable names
     out = bf.diagnostics.plots.recovery(random_estimates, random_targets)

tests/test_networks/test_inference_networks.py

Lines changed: 1 addition & 1 deletion

@@ -96,7 +96,7 @@ def f(x):
     numerical_output, numerical_jacobian = jacobian(f, random_samples, return_output=True)

     # output should be identical, otherwise this test does not work (e.g. for stochastic networks)
-    assert keras.ops.all(keras.ops.isclose(output, numerical_output))
+    assert_allclose(output, numerical_output)

     log_prob = generative_inference_network.base_distribution.log_prob(output)

tests/test_networks/test_summary_networks.py

Lines changed: 1 addition & 2 deletions

@@ -47,9 +47,8 @@ def test_variable_set_size(summary_network, random_set):
     summary_network.build(keras.ops.shape(random_set))

     # run with another set size
-    for _ in range(10):
+    for s in [3, 4, 5]:
         b = keras.ops.shape(random_set)[0]
-        s = np.random.randint(1, 10)
         new_input = keras.ops.zeros((b, s, keras.ops.shape(random_set)[2]))
         summary_network(new_input)

tests/test_workflows/test_basic_workflow.py

Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
 import bayesflow as bf


-def test_classifier_two_sample_test(inference_network, summary_network):
+def test_basic_workflow(inference_network, summary_network):
     workflow = bf.BasicWorkflow(
         inference_network=inference_network,
         summary_network=summary_network,
