Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 71 additions & 23 deletions bayesflow/diagnostics/plots/loss.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from collections.abc import Sequence
from typing import Sequence

import numpy as np
import pandas as pd
Expand All @@ -7,21 +7,25 @@

import keras.src.callbacks

from ...utils.plot_utils import make_figure, add_titles_and_labels
from ...utils.plot_utils import make_figure, add_titles_and_labels, add_gradient_plot


def loss(
history: keras.callbacks.History,
train_key: str = "loss",
val_key: str = "val_loss",
moving_average: bool = False,
per_training_step: bool = False,
ma_window_fraction: float = 0.01,
moving_average: bool = True,
moving_average_alpha: float = 0.8,
figsize: Sequence[float] = None,
train_color: str = "#132a70",
val_color: str = "black",
val_color: str = None,
val_colormap: str = "viridis",
lw_train: float = 2.0,
lw_val: float = 3.0,
marker: bool = True,
val_marker_type: str = ".",
val_marker_size: int = 34,
grid_alpha: float = 0.2,
legend_fontsize: int = 14,
label_fontsize: int = 14,
title_fontsize: int = 16,
Expand All @@ -39,23 +43,30 @@
val_key : str, optional, default: "val_loss"
The validation loss key to look for in the history
moving_average : bool, optional, default: False
A flag for adding a moving average line of the train_losses.
per_training_step : bool, optional, default: False
A flag for making loss trajectory detailed (to training steps) rather than per epoch.
ma_window_fraction : int, optional, default: 0.01
Window size for the moving average as a fraction of total
training steps.
A flag for adding an exponential moving average line of the train_losses.
moving_average_alpha : int, optional, default: 0.8
Smoothing factor for the moving average.
figsize : tuple or None, optional, default: None
The figure size passed to the ``matplotlib`` constructor.
Inferred if ``None``
train_color : str, optional, default: '#8f2727'
The color for the train loss trajectory
val_color : str, optional, default: black
val_color : str, optional, default: None
The color for the optional validation loss trajectory
val_colormap : str, optional, default: "viridis"
The colormap for the optional validation loss trajectory
lw_train : int, optional, default: 2
The linewidth for the training loss curve
lw_val : int, optional, default: 3
The linewidth for the validation loss curve
marker : bool, optional, default: False
A flag for whether marker should be added in the validation loss trajectory
val_marker_type : str, optional, default: o
The marker type for the validation loss curve
val_marker_size : int, optional, default: 34
The marker size for the validation loss curve
grid_alpha : float, optional, default: 0.2
The transparency of the background grid
legend_fontsize : int, optional, default: 14
The font size of the legend text
label_fontsize : int, optional, default: 14
Expand Down Expand Up @@ -99,27 +110,64 @@
# Loop through loss entries and populate plot
for i, ax in enumerate(axes.flat):
# Plot train curve
ax.plot(train_step_index, train_losses.iloc[:, i], color=train_color, lw=lw_train, alpha=0.9, label="Training")
if moving_average and train_losses.columns[i] == "Loss":
moving_average_window = int(train_losses.shape[0] * ma_window_fraction)
smoothed_loss = train_losses.iloc[:, i].rolling(window=moving_average_window).mean()
ax.plot(train_step_index, smoothed_loss, color="grey", lw=lw_train, label="Training (Moving Average)")
ax.plot(train_step_index, train_losses.iloc[:, 0], color=train_color, lw=lw_train, alpha=0.05, label="Training")
if moving_average:
smoothed_train_loss = train_losses.iloc[:, 0].ewm(alpha=moving_average_alpha, adjust=True).mean()
ax.plot(train_step_index, smoothed_train_loss, color="grey", lw=lw_train, label="Training (Moving Average)")

# Plot optional val curve
if val_losses is not None:
if i < val_losses.shape[1]:
if val_color is not None:

Check warning on line 120 in bayesflow/diagnostics/plots/loss.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/diagnostics/plots/loss.py#L120

Added line #L120 was not covered by tests
ax.plot(
val_step_index,
val_losses.iloc[:, i],
val_losses.iloc[:, 0],
linestyle="--",
marker="o",
marker=val_marker_type if marker else None,
color=val_color,
lw=lw_val,
alpha=0.2,
label="Validation",
)
if moving_average:
smoothed_val_loss = val_losses.iloc[:, 0].ewm(alpha=moving_average_alpha, adjust=True).mean()
ax.plot(

Check warning on line 133 in bayesflow/diagnostics/plots/loss.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/diagnostics/plots/loss.py#L131-L133

Added lines #L131 - L133 were not covered by tests
val_step_index,
smoothed_val_loss,
color=val_color,
lw=lw_val,
label="Validation (Moving Average)",
)
else:
# Make gradient lines
add_gradient_plot(

Check warning on line 142 in bayesflow/diagnostics/plots/loss.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/diagnostics/plots/loss.py#L142

Added line #L142 was not covered by tests
val_step_index,
val_losses.iloc[:, 0],
ax,
val_colormap,
lw_val,
marker,
val_marker_type,
val_marker_size,
alpha=0.05,
label="Validation",
)
if moving_average:
smoothed_val_loss = val_losses.iloc[:, 0].ewm(alpha=moving_average_alpha, adjust=True).mean()
add_gradient_plot(

Check warning on line 156 in bayesflow/diagnostics/plots/loss.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/diagnostics/plots/loss.py#L154-L156

Added lines #L154 - L156 were not covered by tests
val_step_index,
smoothed_val_loss,
ax,
val_colormap,
lw_val,
marker,
val_marker_type,
val_marker_size,
alpha=1,
label="Validation (Moving Average)",
)

sns.despine(ax=ax)
ax.grid(alpha=0.5)
ax.grid(alpha=grid_alpha)

# Only add legend if there is a validation curve
if val_losses is not None or moving_average:
Expand All @@ -131,7 +179,7 @@
num_row=num_row,
num_col=1,
title=["Loss Trajectory"],
xlabel="Training step #" if per_training_step else "Training epoch #",
xlabel="Training epoch #",
ylabel="Value",
title_fontsize=title_fontsize,
label_fontsize=label_fontsize,
Expand Down
98 changes: 98 additions & 0 deletions bayesflow/utils/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
import matplotlib.pyplot as plt
import seaborn as sns

from matplotlib.collections import LineCollection
from matplotlib.colors import Normalize
from matplotlib.patches import Rectangle
from matplotlib.legend_handler import HandlerPatch

from .validators import check_estimates_prior_shapes
from .dict_utils import dicts_to_arrays

Expand Down Expand Up @@ -260,3 +265,96 @@
alpha=0.9,
linestyle="dashed",
)


def gradient_line(x, y, c=None, cmap: str = "viridis", lw: float = 2.0, alpha: float = 1, ax=None):
"""
Plot a 1D line with color gradient determined by `c` (same shape as x and y).
"""
if ax is None:
ax = plt.gca()

Check warning on line 275 in bayesflow/utils/plot_utils.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/utils/plot_utils.py#L274-L275

Added lines #L274 - L275 were not covered by tests

# Default color value = y
if c is None:
c = y

Check warning on line 279 in bayesflow/utils/plot_utils.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/utils/plot_utils.py#L278-L279

Added lines #L278 - L279 were not covered by tests

# Create segments for LineCollection
points = np.array([x, y]).T.reshape(-1, 1, 2)
segments = np.concatenate([points[:-1], points[1:]], axis=1)

Check warning on line 283 in bayesflow/utils/plot_utils.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/utils/plot_utils.py#L282-L283

Added lines #L282 - L283 were not covered by tests

norm = Normalize(np.min(c), np.max(c))
lc = LineCollection(segments, array=c, cmap=cmap, norm=norm, linewidth=lw, alpha=alpha)

Check warning on line 286 in bayesflow/utils/plot_utils.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/utils/plot_utils.py#L285-L286

Added lines #L285 - L286 were not covered by tests

ax.add_collection(lc)
ax.set_xlim(np.min(x), np.max(x))
ax.set_ylim(np.min(y), np.max(y))
return lc

Check warning on line 291 in bayesflow/utils/plot_utils.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/utils/plot_utils.py#L288-L291

Added lines #L288 - L291 were not covered by tests


def gradient_legend(ax, label, cmap, norm, loc="upper right"):
"""
Adds a single gradient swatch to the legend of the given Axes.

Parameters
----------
- ax: matplotlib Axes
- label: str, label to display in the legend
- cmap: matplotlib colormap
- norm: matplotlib Normalize object
- loc: legend location (default 'upper right')
"""

# Custom dummy handle to represent the gradient
class _GradientSwatch(Rectangle):
pass

Check warning on line 309 in bayesflow/utils/plot_utils.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/utils/plot_utils.py#L308-L309

Added lines #L308 - L309 were not covered by tests

# Custom legend handler that draws a horizontal gradient
class _HandlerGradient(HandlerPatch):
def create_artists(self, legend, orig_handle, xdescent, ydescent, width, height, fontsize, trans):
gradient = np.linspace(0, 1, 256).reshape(1, -1)
im = ax.imshow(

Check warning on line 315 in bayesflow/utils/plot_utils.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/utils/plot_utils.py#L312-L315

Added lines #L312 - L315 were not covered by tests
gradient,
aspect="auto",
extent=[xdescent, xdescent + width, ydescent, ydescent + height],
transform=trans,
cmap=cmap,
norm=norm,
)
return [im]

Check warning on line 323 in bayesflow/utils/plot_utils.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/utils/plot_utils.py#L323

Added line #L323 was not covered by tests

# Add to existing legend entries
handles, labels = ax.get_legend_handles_labels()
handles.append(_GradientSwatch((0, 0), 1, 1))
labels.append(label)

Check warning on line 328 in bayesflow/utils/plot_utils.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/utils/plot_utils.py#L326-L328

Added lines #L326 - L328 were not covered by tests

ax.legend(handles=handles, labels=labels, loc=loc, handler_map={_GradientSwatch: _HandlerGradient()})

Check warning on line 330 in bayesflow/utils/plot_utils.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/utils/plot_utils.py#L330

Added line #L330 was not covered by tests


def add_gradient_plot(
x,
y,
ax,
cmap: str = "viridis",
lw: float = 3.0,
marker: bool = True,
marker_type: str = "o",
marker_size: int = 34,
alpha: float = 1,
label: str = "Validation",
):
gradient_line(x, y, c=x, cmap=cmap, lw=lw, alpha=alpha, ax=ax)

Check warning on line 345 in bayesflow/utils/plot_utils.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/utils/plot_utils.py#L345

Added line #L345 was not covered by tests

# Optionally add markers
if marker:
ax.scatter(

Check warning on line 349 in bayesflow/utils/plot_utils.py

View check run for this annotation

Codecov / codecov/patch

bayesflow/utils/plot_utils.py#L348-L349

Added lines #L348 - L349 were not covered by tests
x,
y,
c=x,
cmap=cmap,
marker=marker_type,
s=marker_size,
zorder=10,
edgecolors="none",
label=label,
alpha=0.01,
)