diff --git a/bayesflow/diagnostics/plots/loss.py b/bayesflow/diagnostics/plots/loss.py
index 4f1f90de4..baca73327 100644
--- a/bayesflow/diagnostics/plots/loss.py
+++ b/bayesflow/diagnostics/plots/loss.py
@@ -14,14 +14,13 @@ def loss(
     history: keras.callbacks.History,
     train_key: str = "loss",
     val_key: str = "val_loss",
-    moving_average: bool = False,
-    per_training_step: bool = False,
-    ma_window_fraction: float = 0.01,
+    smoothing_factor: float = 0.8,
     figsize: Sequence[float] = None,
     train_color: str = "#132a70",
     val_color: str = "black",
     lw_train: float = 2.0,
-    lw_val: float = 3.0,
+    lw_val: float = 2.0,
+    grid_alpha: float = 0.2,
     legend_fontsize: int = 14,
     label_fontsize: int = 14,
     title_fontsize: int = 16,
@@ -38,24 +37,21 @@ def loss(
         The training loss key to look for in the history
     val_key : str, optional, default: "val_loss"
         The validation loss key to look for in the history
-    moving_average : bool, optional, default: False
-        A flag for adding a moving average line of the train_losses.
-    per_training_step : bool, optional, default: False
-        A flag for making loss trajectory detailed (to training steps) rather than per epoch.
-    ma_window_fraction : int, optional, default: 0.01
-        Window size for the moving average as a fraction of total
-        training steps.
+    smoothing_factor : float, optional, default: 0.8
+        If greater than zero, smooth the loss curves with an exponential moving average; larger values smooth more.
     figsize : tuple or None, optional, default: None
         The figure size passed to the ``matplotlib`` constructor. Inferred if ``None``
-    train_color : str, optional, default: '#8f2727'
+    train_color : str, optional, default: '#132a70'
         The color for the train loss trajectory
-    val_color : str, optional, default: black
-        The color for the optional validation loss trajectory
+    val_color : str or None, optional, default: "black"
+        The color for the optional validation loss trajectory; if ``None``, the validation loss is not drawn
-    lw_train : int, optional, default: 2
+    lw_train : float, optional, default: 2.0
         The linewidth for the training loss curve
-    lw_val : int, optional, default: 3
+    lw_val : float, optional, default: 2.0
         The linewidth for the validation loss curve
+    grid_alpha : float, optional, default: 0.2
+        The transparency of the background grid
     legend_fontsize : int, optional, default: 14
         The font size of the legend text
     label_fontsize : int, optional, default: 14
@@ -98,31 +94,60 @@ def loss(
 
     # Loop through loss entries and populate plot
     for i, ax in enumerate(axes.flat):
-        # Plot train curve
-        ax.plot(train_step_index, train_losses.iloc[:, i], color=train_color, lw=lw_train, alpha=0.9, label="Training")
-        if moving_average and train_losses.columns[i] == "Loss":
-            moving_average_window = int(train_losses.shape[0] * ma_window_fraction)
-            smoothed_loss = train_losses.iloc[:, i].rolling(window=moving_average_window).mean()
-            ax.plot(train_step_index, smoothed_loss, color="grey", lw=lw_train, label="Training (Moving Average)")
+        if smoothing_factor > 0:
+            # plot unsmoothed train loss
+            ax.plot(
+                train_step_index, train_losses.iloc[:, 0], color=train_color, lw=lw_train, alpha=0.3, label="Training"
+            )
+
+            # plot smoothed train loss
+            smoothed_train_loss = train_losses.iloc[:, 0].ewm(alpha=1.0 - smoothing_factor, adjust=True).mean()
+            ax.plot(
+                train_step_index,
+                smoothed_train_loss,
+                color=train_color,
+                lw=lw_train,
+                alpha=0.8,
+                label="Training (Moving Average)",
+            )
+        else:
+            # plot unsmoothed train loss
+            ax.plot(
+                train_step_index, train_losses.iloc[:, 0], color=train_color, lw=lw_train, alpha=0.8, label="Training"
+            )
 
         # Plot optional val curve
         if val_losses is not None:
-            if i < val_losses.shape[1]:
-                ax.plot(
-                    val_step_index,
-                    val_losses.iloc[:, i],
-                    linestyle="--",
-                    marker="o",
-                    color=val_color,
-                    lw=lw_val,
-                    label="Validation",
-                )
+            if val_color is not None:
+                if smoothing_factor > 0:
+                    # plot unsmoothed val loss
+                    ax.plot(
+                        val_step_index, val_losses.iloc[:, 0], color=val_color, lw=lw_val, alpha=0.3, label="Validation"
+                    )
+
+                    # plot smoothed val loss
+                    smoothed_val_loss = val_losses.iloc[:, 0].ewm(alpha=1.0 - smoothing_factor, adjust=True).mean()
+                    ax.plot(
+                        val_step_index,
+                        smoothed_val_loss,
+                        color=val_color,
+                        lw=lw_val,
+                        alpha=0.8,
+                        label="Validation (Moving Average)",
+                    )
+                else:
+                    # plot unsmoothed val loss
+                    ax.plot(
+                        val_step_index, val_losses.iloc[:, 0], color=val_color, lw=lw_val, alpha=0.8, label="Validation"
+                    )
 
         sns.despine(ax=ax)
-        ax.grid(alpha=0.5)
+        ax.grid(alpha=grid_alpha)
 
-        # Only add legend if there is a validation curve
-        if val_losses is not None or moving_average:
+        ax.set_xlim(train_step_index[0], train_step_index[-1])
+
+        # Only add the legend if there are multiple curves
+        if val_losses is not None or smoothing_factor > 0:
             ax.legend(fontsize=legend_fontsize)
 
     # Add labels, titles, and set font sizes
@@ -131,7 +156,7 @@ def loss(
         num_row=num_row,
         num_col=1,
         title=["Loss Trajectory"],
-        xlabel="Training step #" if per_training_step else "Training epoch #",
+        xlabel="Training epoch #",
         ylabel="Value",
         title_fontsize=title_fontsize,
         label_fontsize=label_fontsize,
diff --git a/bayesflow/utils/plot_utils.py b/bayesflow/utils/plot_utils.py
index 857ca0b23..c963b4dfb 100644
--- a/bayesflow/utils/plot_utils.py
+++ b/bayesflow/utils/plot_utils.py
@@ -4,6 +4,11 @@
 import matplotlib.pyplot as plt
 import seaborn as sns
 
+from matplotlib.collections import LineCollection
+from matplotlib.colors import Normalize
+from matplotlib.patches import Rectangle
+from matplotlib.legend_handler import HandlerPatch
+
 from .validators import check_estimates_prior_shapes
 from .dict_utils import dicts_to_arrays
 
@@ -260,3 +265,102 @@ def make_quadratic(ax: plt.Axes, x_data: np.ndarray, y_data: np.ndarray):
         alpha=0.9,
         linestyle="dashed",
     )
+
+
+def gradient_line(x, y, c=None, cmap: str = "viridis", lw: float = 2.0, alpha: float = 1, ax=None):
+    """
+    Plot a 1D line with a color gradient determined by `c` (same shape as x and y).
+    """
+    if ax is None:
+        ax = plt.gca()
+
+    # Default color value = y
+    if c is None:
+        c = y
+
+    # Create segments for LineCollection
+    points = np.array([x, y]).T.reshape(-1, 1, 2)
+    segments = np.concatenate([points[:-1], points[1:]], axis=1)
+
+    norm = Normalize(np.min(c), np.max(c))
+    lc = LineCollection(segments, array=c, cmap=cmap, norm=norm, linewidth=lw, alpha=alpha)
+
+    ax.add_collection(lc)
+    ax.set_xlim(np.min(x), np.max(x))
+    ax.set_ylim(np.min(y), np.max(y))
+    return lc
+
+
+def gradient_legend(ax, label, cmap, norm, loc="upper right"):
+    """
+    Adds a single gradient swatch to the legend of the given Axes.
+
+    Parameters
+    ----------
+    ax : matplotlib.axes.Axes
+        The Axes whose legend the swatch is added to
+    label : str
+        The label to display in the legend
+    cmap : matplotlib colormap
+        The colormap shown in the gradient swatch
+    norm : matplotlib.colors.Normalize
+        The normalization used to draw the gradient
+    loc : str, optional, default: 'upper right'
+        The legend location
+    """
+
+    # Custom dummy handle to represent the gradient
+    class _GradientSwatch(Rectangle):
+        pass
+
+    # Custom legend handler that draws a horizontal gradient
+    class _HandlerGradient(HandlerPatch):
+        def create_artists(self, legend, orig_handle, xdescent, ydescent, width, height, fontsize, trans):
+            gradient = np.linspace(0, 1, 256).reshape(1, -1)
+            im = ax.imshow(
+                gradient,
+                aspect="auto",
+                extent=[xdescent, xdescent + width, ydescent, ydescent + height],
+                transform=trans,
+                cmap=cmap,
+                norm=norm,
+            )
+            return [im]
+
+    # Add to existing legend entries
+    handles, labels = ax.get_legend_handles_labels()
+    handles.append(_GradientSwatch((0, 0), 1, 1))
+    labels.append(label)
+
+    ax.legend(handles=handles, labels=labels, loc=loc, handler_map={_GradientSwatch: _HandlerGradient()})
+
+
+def add_gradient_plot(
+    x,
+    y,
+    ax,
+    cmap: str = "viridis",
+    lw: float = 3.0,
+    marker: bool = True,
+    marker_type: str = "o",
+    marker_size: int = 34,
+    alpha: float = 1,
+    label: str = "Validation",
+):
+    """Plot a line with a color gradient along x on the given Axes, with optional markers."""
+    gradient_line(x, y, c=x, cmap=cmap, lw=lw, alpha=alpha, ax=ax)
+
+    # Optionally add markers
+    if marker:
+        ax.scatter(
+            x,
+            y,
+            c=x,
+            cmap=cmap,
+            marker=marker_type,
+            s=marker_size,
+            zorder=10,
+            edgecolors="none",
+            label=label,
+            alpha=0.01,
+        )
diff --git a/bayesflow/workflows/basic_workflow.py b/bayesflow/workflows/basic_workflow.py
index fb9aa594d..2533bcd2e 100644
--- a/bayesflow/workflows/basic_workflow.py
+++ b/bayesflow/workflows/basic_workflow.py
@@ -348,7 +348,7 @@ def plot_default_diagnostics(
 ) -> dict[str, plt.Figure]:
     """
     Generates default diagnostic plots to evaluate the quality of inference. The function produces several
-    diagnostic plots, including:
+    diagnostic plots, including
     - Loss history (if training history is available).
     - Parameter recovery plots.
     - Calibration ECDF plots.
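For context, the smoothing introduced in loss.py above is pandas' exponential moving average: the plot passes alpha = 1.0 - smoothing_factor to Series.ewm, so a larger smoothing_factor keeps more weight on older losses and yields a smoother curve. A minimal, self-contained sketch of that mapping (the loss values here are made up purely for illustration):

    import pandas as pd

    # hypothetical loss values, for illustration only
    losses = pd.Series([1.0, 0.8, 0.9, 0.5, 0.4])

    # loss.py uses alpha = 1.0 - smoothing_factor, so smoothing_factor=0.8
    # becomes alpha=0.2: each new loss contributes roughly 20% and the
    # running history roughly 80% (adjust=True normalizes the weights)
    smoothing_factor = 0.8
    smoothed = losses.ewm(alpha=1.0 - smoothing_factor, adjust=True).mean()
    print(smoothed.tolist())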
diff --git a/tests/conftest.py b/tests/conftest.py
index 315efd531..90eb16bf9 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,6 +1,6 @@
-import logging
-
 import keras
+import logging
+import matplotlib
 import pytest
 
 BACKENDS = ["jax", "numpy", "tensorflow", "torch"]
@@ -17,6 +17,16 @@ def pytest_runtest_setup(item):
     if test_backends and backend not in test_backends:
         pytest.skip(f"Skipping backend '{backend}' for test {item}, which is registered for backends {test_backends}.")
 
+    # use a non-GUI plotting backend for tests
+    matplotlib.use("Agg")
+
+
+def pytest_runtest_teardown(item, nextitem):
+    import matplotlib.pyplot as plt
+
+    # close all plots at the end of each test
+    plt.close("all")
+
 
 def pytest_make_parametrize_id(config, val, argname):
     return f"{argname}={repr(val)}"
diff --git a/tests/test_diagnostics/conftest.py b/tests/test_diagnostics/conftest.py
index 5103ea455..8e77d6729 100644
--- a/tests/test_diagnostics/conftest.py
+++ b/tests/test_diagnostics/conftest.py
@@ -1,3 +1,4 @@
+import keras
 import numpy as np
 import pytest
 from bayesflow.utils.numpy_utils import softmax
@@ -61,3 +62,19 @@ def pred_models(true_models):
     pred_models = np.random.normal(loc=true_models)
     pred_models = softmax(pred_models, axis=-1)
     return pred_models
+
+
+@pytest.fixture()
+def history():
+    h = keras.callbacks.History()
+
+    step = np.linspace(0, 1, 10_000)
+    train_loss = (1.0 - step) ** 2 + np.random.normal(loc=0, scale=0.02, size=step.shape)
+    validation_loss = 0.1 + (0.75 - step) ** 2 + np.random.normal(loc=0, scale=0.02, size=step.shape)
+
+    h.history = {
+        "loss": train_loss.tolist(),
+        "val_loss": validation_loss.tolist(),
+    }
+
+    return h
diff --git a/tests/test_diagnostics/test_diagnostics_plots.py b/tests/test_diagnostics/test_diagnostics_plots.py
index f832535d1..f9f29492f 100644
--- a/tests/test_diagnostics/test_diagnostics_plots.py
+++ b/tests/test_diagnostics/test_diagnostics_plots.py
@@ -6,6 +6,15 @@ def num_variables(x: dict):
     return sum(arr.shape[-1] for arr in x.values())
 
 
+def test_backend():
+    import matplotlib.pyplot as plt
+
+    # the tests must run under the non-GUI Agg backend; GitHub CI
+    # uses Agg, so a different local backend can let failures slip
+    # through that only surface in the workflow tests there
+    assert plt.get_backend() == "Agg"
+
+
 def test_calibration_ecdf(random_estimates, random_targets, var_names):
     # basic functionality: automatic variable names
     out = bf.diagnostics.plots.calibration_ecdf(random_estimates, random_targets)
@@ -45,6 +54,12 @@ def test_calibration_histogram(random_estimates, random_targets):
     assert out.axes[0].title._text == "beta_0"
 
 
+def test_loss(history):
+    out = bf.diagnostics.loss(history)
+    assert len(out.axes) == 1
+    assert out.axes[0].title._text == "Loss Trajectory"
+
+
 def test_recovery(random_estimates, random_targets):
     # basic functionality: automatic variable names
     out = bf.diagnostics.plots.recovery(random_estimates, random_targets)
diff --git a/tests/test_networks/test_inference_networks.py b/tests/test_networks/test_inference_networks.py
index 4c3d0c5da..66015730d 100644
--- a/tests/test_networks/test_inference_networks.py
+++ b/tests/test_networks/test_inference_networks.py
@@ -98,7 +98,7 @@ def f(x):
     numerical_output, numerical_jacobian = jacobian(f, random_samples, return_output=True)
 
     # output should be identical, otherwise this test does not work (e.g. for stochastic networks)
-    assert keras.ops.all(keras.ops.isclose(output, numerical_output))
+    assert_allclose(output, numerical_output)
 
     log_prob = generative_inference_network.base_distribution.log_prob(output)
diff --git a/tests/test_networks/test_summary_networks.py b/tests/test_networks/test_summary_networks.py
index a8770f3d7..44ff67ff5 100644
--- a/tests/test_networks/test_summary_networks.py
+++ b/tests/test_networks/test_summary_networks.py
@@ -49,9 +49,8 @@ def test_variable_set_size(summary_network, random_set):
     summary_network.build(keras.ops.shape(random_set))
 
     # run with another set size
-    for _ in range(10):
+    for s in [3, 4, 5]:
         b = keras.ops.shape(random_set)[0]
-        s = np.random.randint(1, 10)
         new_input = keras.ops.zeros((b, s, keras.ops.shape(random_set)[2]))
         summary_network(new_input)
 
diff --git a/tests/test_workflows/test_basic_workflow.py b/tests/test_workflows/test_basic_workflow.py
index 6cab3be8e..9a1c7815f 100644
--- a/tests/test_workflows/test_basic_workflow.py
+++ b/tests/test_workflows/test_basic_workflow.py
@@ -1,7 +1,7 @@
 import bayesflow as bf
 
 
-def test_classifier_two_sample_test(inference_network, summary_network):
+def test_basic_workflow(inference_network, summary_network):
     workflow = bf.BasicWorkflow(
         inference_network=inference_network,
         summary_network=summary_network,
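Taken together, the new options can be exercised the same way the history fixture and test_loss above do. A minimal end-to-end sketch, assuming a bayesflow install with this patch applied; the synthetic losses mirror the fixture and the output filename is arbitrary:

    import keras
    import numpy as np
    import bayesflow as bf

    # synthetic training history, mirroring the `history` fixture above
    h = keras.callbacks.History()
    step = np.linspace(0, 1, 10_000)
    h.history = {
        "loss": ((1.0 - step) ** 2 + np.random.normal(0, 0.02, step.shape)).tolist(),
        "val_loss": (0.1 + (0.75 - step) ** 2 + np.random.normal(0, 0.02, step.shape)).tolist(),
    }

    # smoothing_factor and grid_alpha are the new knobs introduced in loss.py
    fig = bf.diagnostics.plots.loss(h, smoothing_factor=0.8, grid_alpha=0.2)
    fig.savefig("loss_trajectory.png")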