Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 57 additions & 32 deletions bayesflow/diagnostics/plots/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,13 @@
history: keras.callbacks.History,
train_key: str = "loss",
val_key: str = "val_loss",
moving_average: bool = False,
per_training_step: bool = False,
ma_window_fraction: float = 0.01,
smoothing_factor: float = 0.8,
figsize: Sequence[float] = None,
train_color: str = "#132a70",
val_color: str = "black",
lw_train: float = 2.0,
lw_val: float = 3.0,
lw_val: float = 2.0,
grid_alpha: float = 0.2,
legend_fontsize: int = 14,
label_fontsize: int = 14,
title_fontsize: int = 16,
Expand All @@ -38,24 +37,21 @@
The training loss key to look for in the history
val_key : str, optional, default: "val_loss"
The validation loss key to look for in the history
moving_average : bool, optional, default: False
A flag for adding a moving average line of the train_losses.
per_training_step : bool, optional, default: False
A flag for making loss trajectory detailed (to training steps) rather than per epoch.
ma_window_fraction : int, optional, default: 0.01
Window size for the moving average as a fraction of total
training steps.
smoothing_factor : float, optional, default: 0.8
If greater than zero, smooth the loss curves by applying an exponential moving average.
figsize : tuple or None, optional, default: None
The figure size passed to the ``matplotlib`` constructor.
Inferred if ``None``
train_color : str, optional, default: '#132a70'
The color for the train loss trajectory
val_color : str, optional, default: black
val_color : str, optional, default: None
The color for the optional validation loss trajectory
lw_train : float, optional, default: 2
The linewidth for the training loss curve
lw_val : float, optional, default: 2
The linewidth for the validation loss curve
grid_alpha : float, optional, default: 0.2
The transparency of the background grid
legend_fontsize : int, optional, default: 14
The font size of the legend text
label_fontsize : int, optional, default: 14
Expand Down Expand Up @@ -98,31 +94,60 @@

# Loop through loss entries and populate plot
for i, ax in enumerate(axes.flat):
# Plot train curve
ax.plot(train_step_index, train_losses.iloc[:, i], color=train_color, lw=lw_train, alpha=0.9, label="Training")
if moving_average and train_losses.columns[i] == "Loss":
moving_average_window = int(train_losses.shape[0] * ma_window_fraction)
smoothed_loss = train_losses.iloc[:, i].rolling(window=moving_average_window).mean()
ax.plot(train_step_index, smoothed_loss, color="grey", lw=lw_train, label="Training (Moving Average)")
if smoothing_factor > 0:
# plot unsmoothed train loss
ax.plot(
train_step_index, train_losses.iloc[:, 0], color=train_color, lw=lw_train, alpha=0.3, label="Training"
)

# plot smoothed train loss
smoothed_train_loss = train_losses.iloc[:, 0].ewm(alpha=1.0 - smoothing_factor, adjust=True).mean()
ax.plot(
train_step_index,
smoothed_train_loss,
color=train_color,
lw=lw_train,
alpha=0.8,
label="Training (Moving Average)",
)
else:
# plot unsmoothed train loss
ax.plot(

Check warning on line 115 in bayesflow/diagnostics/plots/loss.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/diagnostics/plots/loss.py#L115

Added line #L115 was not covered by tests
train_step_index, train_losses.iloc[:, 0], color=train_color, lw=lw_train, alpha=0.8, label="Training"
)

# Plot optional val curve
if val_losses is not None:
if i < val_losses.shape[1]:
ax.plot(
val_step_index,
val_losses.iloc[:, i],
linestyle="--",
marker="o",
color=val_color,
lw=lw_val,
label="Validation",
)
if val_color is not None:
if smoothing_factor > 0:
# plot unsmoothed val loss
ax.plot(
val_step_index, val_losses.iloc[:, 0], color=val_color, lw=lw_val, alpha=0.3, label="Validation"
)

# plot smoothed val loss
smoothed_val_loss = val_losses.iloc[:, 0].ewm(alpha=1.0 - smoothing_factor, adjust=True).mean()
ax.plot(
val_step_index,
smoothed_val_loss,
color=val_color,
lw=lw_val,
alpha=0.8,
label="Validation (Moving Average)",
)
else:
# plot unsmoothed val loss
ax.plot(

Check warning on line 140 in bayesflow/diagnostics/plots/loss.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/diagnostics/plots/loss.py#L140

Added line #L140 was not covered by tests
val_step_index, val_losses.iloc[:, 0], color=val_color, lw=lw_val, alpha=0.8, label="Validation"
)

sns.despine(ax=ax)
ax.grid(alpha=0.5)
ax.grid(alpha=grid_alpha)

# Only add legend if there is a validation curve
if val_losses is not None or moving_average:
ax.set_xlim(train_step_index[0], train_step_index[-1])

# Only add the legend if there are multiple curves
if val_losses is not None or smoothing_factor > 0:
ax.legend(fontsize=legend_fontsize)

# Add labels, titles, and set font sizes
Expand All @@ -131,7 +156,7 @@
num_row=num_row,
num_col=1,
title=["Loss Trajectory"],
xlabel="Training step #" if per_training_step else "Training epoch #",
xlabel="Training epoch #",
ylabel="Value",
title_fontsize=title_fontsize,
label_fontsize=label_fontsize,
Expand Down
98 changes: 98 additions & 0 deletions bayesflow/utils/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
import matplotlib.pyplot as plt
import seaborn as sns

from matplotlib.collections import LineCollection
from matplotlib.colors import Normalize
from matplotlib.patches import Rectangle
from matplotlib.legend_handler import HandlerPatch

from .validators import check_estimates_prior_shapes
from .dict_utils import dicts_to_arrays

Expand Down Expand Up @@ -260,3 +265,96 @@
alpha=0.9,
linestyle="dashed",
)


def gradient_line(x, y, c=None, cmap: str = "viridis", lw: float = 2.0, alpha: float = 1, ax=None):
    """
    Plot a 1D line with color gradient determined by `c` (same shape as x and y).

    Parameters
    ----------
    x : array-like
        The x-coordinates of the line vertices.
    y : array-like
        The y-coordinates of the line vertices.
    c : array-like or None, optional, default: None
        The values that are mapped to colors. Defaults to ``y`` if ``None``.
    cmap : str, optional, default: "viridis"
        The colormap passed to the ``LineCollection``.
    lw : float, optional, default: 2.0
        The linewidth of the line segments.
    alpha : float, optional, default: 1
        The opacity of the line segments.
    ax : matplotlib Axes or None, optional, default: None
        The axes to draw on. Falls back to ``plt.gca()`` if ``None``.

    Returns
    -------
    lc : matplotlib.collections.LineCollection
        The collection that was added to ``ax``.

    Notes
    -----
    The x- and y-limits of ``ax`` are overwritten with the finite range of
    ``x`` and ``y``; any limits set earlier on ``ax`` are replaced.
    """
    if ax is None:
        ax = plt.gca()

    x = np.asarray(x)
    y = np.asarray(y)

    # Default color value = y
    c = y if c is None else np.asarray(c)

    # Create segments for LineCollection: each segment joins two consecutive vertices
    points = np.array([x, y]).T.reshape(-1, 1, 2)
    segments = np.concatenate([points[:-1], points[1:]], axis=1)

    # Non-finite entries (e.g. a diverged loss) must not define the color range;
    # without any finite value, the norm is left to autoscale.
    finite_c = c[np.isfinite(c)]
    norm = Normalize(finite_c.min(), finite_c.max()) if finite_c.size else Normalize()
    lc = LineCollection(segments, array=c, cmap=cmap, norm=norm, linewidth=lw, alpha=alpha)

    ax.add_collection(lc)

    # Axis limits cannot be NaN or Inf, so only the finite part of the data
    # is used, and an axis without finite data keeps its current limits.
    for values, set_limits in ((x, ax.set_xlim), (y, ax.set_ylim)):
        finite = values[np.isfinite(values)]
        if finite.size:
            set_limits(finite.min(), finite.max())
    return lc

Check warning on line 291 in bayesflow/utils/plot_utils.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/utils/plot_utils.py#L288-L291

Added lines #L288 - L291 were not covered by tests


def gradient_legend(ax, label, cmap, norm, loc="upper right"):
    """
    Adds a single gradient swatch to the legend of the given Axes.

    Parameters
    ----------
    ax : matplotlib Axes
        The axes whose legend is rebuilt with the extra entry.
    label : str
        The label to display next to the gradient swatch.
    cmap : matplotlib colormap
        The colormap used to draw the swatch.
    norm : matplotlib Normalize object
        The normalization passed on to the swatch image.
    loc : str, optional, default: "upper right"
        The legend location.
    """

    class _GradientSwatch(Rectangle):
        """Dummy legend handle; its type is the key that selects the custom handler."""

    class _HandlerGradient(HandlerPatch):
        """Legend handler that draws a horizontal gradient in place of the swatch."""

        def create_artists(self, legend, orig_handle, xdescent, ydescent, width, height, fontsize, trans):
            ramp = np.linspace(0, 1, 256).reshape(1, -1)
            box = [xdescent, xdescent + width, ydescent, ydescent + height]
            # NOTE(review): ``ax.imshow`` presumably also registers the image with
            # ``ax`` itself - confirm this has no unwanted effect on the data view.
            swatch = ax.imshow(ramp, aspect="auto", extent=box, transform=trans, cmap=cmap, norm=norm)
            return [swatch]

    # Extend the existing legend entries with the gradient entry
    current_handles, current_labels = ax.get_legend_handles_labels()
    all_handles = current_handles + [_GradientSwatch((0, 0), 1, 1)]
    all_labels = current_labels + [label]

    ax.legend(handles=all_handles, labels=all_labels, loc=loc, handler_map={_GradientSwatch: _HandlerGradient()})

Check warning on line 330 in bayesflow/utils/plot_utils.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/utils/plot_utils.py#L330

Added line #L330 was not covered by tests


def add_gradient_plot(
    x,
    y,
    ax,
    cmap: str = "viridis",
    lw: float = 3.0,
    marker: bool = True,
    marker_type: str = "o",
    marker_size: int = 34,
    alpha: float = 1,
    label: str = "Validation",
):
    """
    Draw a gradient-colored line on ``ax``, optionally with labelled markers.

    The line is drawn with ``gradient_line`` and colored by ``x``. If ``marker``
    is true, a scatter plot of the same points is added on top and carries
    ``label`` as its legend entry.

    Parameters
    ----------
    x, y : array-like
        The coordinates of the curve.
    ax : matplotlib Axes
        The axes to draw on.
    cmap : str, optional, default: "viridis"
        The colormap used for both the line and the markers.
    lw : float, optional, default: 3.0
        The linewidth of the gradient line.
    marker : bool, optional, default: True
        Whether to add the scatter markers.
    marker_type : str, optional, default: "o"
        The marker style passed to ``ax.scatter``.
    marker_size : int, optional, default: 34
        The marker size passed to ``ax.scatter``.
    alpha : float, optional, default: 1
        The opacity of the gradient line.
    label : str, optional, default: "Validation"
        The label attached to the scatter markers.
    """
    gradient_line(x, y, c=x, cmap=cmap, lw=lw, alpha=alpha, ax=ax)

    if not marker:
        return

    # NOTE(review): the markers use a fixed alpha of 0.01 rather than ``alpha``;
    # presumably they exist mainly to carry the legend label - confirm.
    marker_style = dict(
        c=x,
        cmap=cmap,
        marker=marker_type,
        s=marker_size,
        zorder=10,
        edgecolors="none",
        label=label,
        alpha=0.01,
    )
    ax.scatter(x, y, **marker_style)
2 changes: 1 addition & 1 deletion bayesflow/workflows/basic_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ def plot_default_diagnostics(
) -> dict[str, plt.Figure]:
"""
Generates default diagnostic plots to evaluate the quality of inference. The function produces several
diagnostic plots, including:
diagnostic plots, including
- Loss history (if training history is available).
- Parameter recovery plots.
- Calibration ECDF plots.
Expand Down
14 changes: 12 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging

import keras
import logging
import matplotlib
import pytest

BACKENDS = ["jax", "numpy", "tensorflow", "torch"]
Expand All @@ -17,6 +17,16 @@ def pytest_runtest_setup(item):
if test_backends and backend not in test_backends:
pytest.skip(f"Skipping backend '{backend}' for test {item}, which is registered for backends {test_backends}.")

# use a non-GUI plotting backend for tests
matplotlib.use("Agg")


def pytest_runtest_teardown(item, nextitem):
    """Close all matplotlib figures at the end of each test."""
    from matplotlib import pyplot

    pyplot.close("all")


def pytest_make_parametrize_id(config, val, argname):
    """Build a parametrization id of the form ``<argname>=<repr of val>``."""
    return f"{argname}={val!r}"
Expand Down
17 changes: 17 additions & 0 deletions tests/test_diagnostics/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import keras
import numpy as np
import pytest
from bayesflow.utils.numpy_utils import softmax
Expand Down Expand Up @@ -61,3 +62,19 @@ def pred_models(true_models):
pred_models = np.random.normal(loc=true_models)
pred_models = softmax(pred_models, axis=-1)
return pred_models


@pytest.fixture()
def history():
    """
    Provide a ``keras.callbacks.History`` holding synthetic loss curves.

    Both curves are noisy quadratics over 10,000 steps, stored as lists under
    the keys ``"loss"`` and ``"val_loss"``.
    """
    progress = np.linspace(0, 1, 10_000)

    # Noise is drawn for the training curve first, then for the validation curve
    train_noise = np.random.normal(loc=0, scale=0.02, size=progress.shape)
    train_curve = (1.0 - progress) ** 2 + train_noise

    val_noise = np.random.normal(loc=0, scale=0.02, size=progress.shape)
    val_curve = 0.1 + (0.75 - progress) ** 2 + val_noise

    record = keras.callbacks.History()
    record.history = {
        "loss": train_curve.tolist(),
        "val_loss": val_curve.tolist(),
    }
    return record
15 changes: 15 additions & 0 deletions tests/test_diagnostics/test_diagnostics_plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,15 @@ def num_variables(x: dict):
return sum(arr.shape[-1] for arr in x.values())


def test_backend():
    """Check that the tests run on matplotlib's non-GUI Agg backend."""
    import matplotlib.pyplot as plt

    # if the local testing backend is not Agg
    # then you may run into issues once you run workflow tests
    # on GitHub, since these use the Agg backend
    #
    # Backend names are case-insensitive and the reported casing may differ
    # between matplotlib releases ("Agg" vs "agg"), so compare lowercased.
    assert plt.get_backend().lower() == "agg"


def test_calibration_ecdf(random_estimates, random_targets, var_names):
# basic functionality: automatic variable names
out = bf.diagnostics.plots.calibration_ecdf(random_estimates, random_targets)
Expand Down Expand Up @@ -45,6 +54,12 @@ def test_calibration_histogram(random_estimates, random_targets):
assert out.axes[0].title._text == "beta_0"


def test_loss(history):
    """Smoke test: the loss plot has a single, correctly titled axes."""
    fig = bf.diagnostics.loss(history)
    axes = fig.axes

    assert len(axes) == 1
    assert axes[0].title._text == "Loss Trajectory"


def test_recovery(random_estimates, random_targets):
# basic functionality: automatic variable names
out = bf.diagnostics.plots.recovery(random_estimates, random_targets)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_networks/test_inference_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def f(x):
numerical_output, numerical_jacobian = jacobian(f, random_samples, return_output=True)

# output should be identical, otherwise this test does not work (e.g. for stochastic networks)
assert keras.ops.all(keras.ops.isclose(output, numerical_output))
assert_allclose(output, numerical_output)

log_prob = generative_inference_network.base_distribution.log_prob(output)

Expand Down
3 changes: 1 addition & 2 deletions tests/test_networks/test_summary_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,8 @@ def test_variable_set_size(summary_network, random_set):
summary_network.build(keras.ops.shape(random_set))

# run with another set size
for _ in range(10):
for s in [3, 4, 5]:
b = keras.ops.shape(random_set)[0]
s = np.random.randint(1, 10)
new_input = keras.ops.zeros((b, s, keras.ops.shape(random_set)[2]))
summary_network(new_input)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_workflows/test_basic_workflow.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import bayesflow as bf


def test_classifier_two_sample_test(inference_network, summary_network):
def test_basic_workflow(inference_network, summary_network):
workflow = bf.BasicWorkflow(
inference_network=inference_network,
summary_network=summary_network,
Expand Down