From 4e130cdf40eeaf6ca5770ace3d3738b1b5e56006 Mon Sep 17 00:00:00 2001 From: Jerry Date: Fri, 18 Apr 2025 15:46:06 -0400 Subject: [PATCH 01/14] Enhance loss plot, including viridisifying the val loss --- bayesflow/diagnostics/plots/loss.py | 69 ++++++++++++++++++++-------- bayesflow/utils/plot_utils.py | 71 +++++++++++++++++++++++++++++ 2 files changed, 120 insertions(+), 20 deletions(-) diff --git a/bayesflow/diagnostics/plots/loss.py b/bayesflow/diagnostics/plots/loss.py index 4f1f90de4..d800e2693 100644 --- a/bayesflow/diagnostics/plots/loss.py +++ b/bayesflow/diagnostics/plots/loss.py @@ -7,21 +7,26 @@ import keras.src.callbacks -from ...utils.plot_utils import make_figure, add_titles_and_labels +from matplotlib.colors import Normalize +from ...utils.plot_utils import make_figure, add_titles_and_labels, gradient_line, gradient_legend def loss( history: keras.callbacks.History, train_key: str = "loss", val_key: str = "val_loss", - moving_average: bool = False, + moving_average: bool = True, per_training_step: bool = False, - ma_window_fraction: float = 0.01, + moving_average_span: int = 10, figsize: Sequence[float] = None, train_color: str = "#132a70", - val_color: str = "black", + val_color: str = None, + val_colormap: str = 'viridis', lw_train: float = 2.0, lw_val: float = 3.0, + val_marker_type: str = "o", + val_marker_size: int = 34, + grid_alpha: float = 0.2, legend_fontsize: int = 14, label_fontsize: int = 14, title_fontsize: int = 16, @@ -39,7 +44,7 @@ def loss( val_key : str, optional, default: "val_loss" The validation loss key to look for in the history moving_average : bool, optional, default: False - A flag for adding a moving average line of the train_losses. + A flag for adding an exponential moving average line of the train_losses. per_training_step : bool, optional, default: False A flag for making loss trajectory detailed (to training steps) rather than per epoch. ma_window_fraction : int, optional, default: 0.01 @@ -99,27 +104,51 @@ def loss( # Loop through loss entries and populate plot for i, ax in enumerate(axes.flat): # Plot train curve - ax.plot(train_step_index, train_losses.iloc[:, i], color=train_color, lw=lw_train, alpha=0.9, label="Training") - if moving_average and train_losses.columns[i] == "Loss": - moving_average_window = int(train_losses.shape[0] * ma_window_fraction) - smoothed_loss = train_losses.iloc[:, i].rolling(window=moving_average_window).mean() + ax.plot(train_step_index, train_losses.iloc[:, 0], color=train_color, lw=lw_train, alpha=0.2, label="Training") + if moving_average: + smoothed_loss = train_losses.iloc[:, 0].ewm(span=moving_average_span, adjust=True).mean() ax.plot(train_step_index, smoothed_loss, color="grey", lw=lw_train, label="Training (Moving Average)") # Plot optional val curve if val_losses is not None: - if i < val_losses.shape[1]: - ax.plot( - val_step_index, - val_losses.iloc[:, i], - linestyle="--", - marker="o", - color=val_color, - lw=lw_val, - label="Validation", - ) + if val_color is not None: + ax.plot( + val_step_index, + val_losses.iloc[:, 0], + linestyle="--", + marker=val_marker_type, + color=val_color, + lw=lw_val, + label="Validation", + ) + else: + # Create line segments between each epoch + points = np.array([val_step_index, val_losses.iloc[:,0]]).T.reshape(-1, 1, 2) + segments = np.concatenate([points[:-1], points[1:]], axis=1) + + # Normalize color based on loss values + lc = gradient_line( + val_step_index, + val_losses.iloc[:,0], + c=val_step_index, + cmap=val_colormap, + lw=lw_val, + ax=ax + ) + scatter = ax.scatter( + val_step_index, + val_losses.iloc[:,0], + c=val_step_index, + cmap=val_colormap, + marker=val_marker_type, + s=val_marker_size, + zorder=10, + edgecolors='none', + label='Validation' + ) sns.despine(ax=ax) - ax.grid(alpha=0.5) + ax.grid(alpha=grid_alpha) # Only add legend if there is a validation curve if val_losses is not None or moving_average: diff --git a/bayesflow/utils/plot_utils.py b/bayesflow/utils/plot_utils.py index 857ca0b23..cb6c3dddc 100644 --- a/bayesflow/utils/plot_utils.py +++ b/bayesflow/utils/plot_utils.py @@ -4,6 +4,11 @@ import matplotlib.pyplot as plt import seaborn as sns +from matplotlib.collections import LineCollection +from matplotlib.colors import Normalize +from matplotlib.patches import Rectangle +from matplotlib.legend_handler import HandlerPatch + from .validators import check_estimates_prior_shapes from .dict_utils import dicts_to_arrays @@ -260,3 +265,69 @@ def make_quadratic(ax: plt.Axes, x_data: np.ndarray, y_data: np.ndarray): alpha=0.9, linestyle="dashed", ) + + +def gradient_line(x, y, c=None, cmap='viridis', lw=2, ax=None): + """ + Plot a 1D line with color gradient determined by `c` (same shape as x and y). + """ + if ax is None: + ax = plt.gca() + + # Default color value = y + if c is None: + c = y + + # Create segments for LineCollection + points = np.array([x, y]).T.reshape(-1, 1, 2) + segments = np.concatenate([points[:-1], points[1:]], axis=1) + + norm = Normalize(np.min(c), np.max(c)) + lc = LineCollection(segments, array=c, cmap=cmap, norm=norm, linewidth=lw) + + ax.add_collection(lc) + ax.set_xlim(np.min(x), np.max(x)) + ax.set_ylim(np.min(y), np.max(y)) + return lc + + +def gradient_legend(ax, label, cmap, norm, loc='upper right'): + """ + Adds a single gradient swatch to the legend of the given Axes. + + Parameters: + - ax: matplotlib Axes + - label: str, label to display in the legend + - cmap: matplotlib colormap + - norm: matplotlib Normalize object + - loc: legend location (default 'upper right') + """ + + # Custom dummy handle to represent the gradient + class _GradientSwatch(Rectangle): pass + + # Custom legend handler that draws a horizontal gradient + class _HandlerGradient(HandlerPatch): + def create_artists(self, legend, orig_handle, xdescent, ydescent, width, height, fontsize, trans): + gradient = np.linspace(0, 1, 256).reshape(1, -1) + im = ax.imshow( + gradient, + aspect='auto', + extent=[xdescent, xdescent + width, ydescent, ydescent + height], + transform=trans, + cmap=cmap, + norm=norm + ) + return [im] + + # Add to existing legend entries + handles, labels = ax.get_legend_handles_labels() + handles.append(_GradientSwatch((0, 0), 1, 1)) + labels.append(label) + + ax.legend( + handles=handles, + labels=labels, + loc=loc, + handler_map={_GradientSwatch: _HandlerGradient()} + ) From 06c55652f5b78b46d21237a3d87cf1955b97953a Mon Sep 17 00:00:00 2001 From: Jerry Date: Fri, 18 Apr 2025 16:29:01 -0400 Subject: [PATCH 02/14] Add docs, remove unused code --- bayesflow/diagnostics/plots/loss.py | 83 ++++++++++++++--------------- bayesflow/utils/plot_utils.py | 18 +++---- 2 files changed, 46 insertions(+), 55 deletions(-) diff --git a/bayesflow/diagnostics/plots/loss.py b/bayesflow/diagnostics/plots/loss.py index d800e2693..0d6a3f9c8 100644 --- a/bayesflow/diagnostics/plots/loss.py +++ b/bayesflow/diagnostics/plots/loss.py @@ -7,8 +7,7 @@ import keras.src.callbacks -from matplotlib.colors import Normalize -from ...utils.plot_utils import make_figure, add_titles_and_labels, gradient_line, gradient_legend +from ...utils.plot_utils import make_figure, add_titles_and_labels, gradient_line def loss( @@ -16,12 +15,11 @@ def loss( train_key: str = "loss", val_key: str = "val_loss", moving_average: bool = True, - per_training_step: bool = False, moving_average_span: int = 10, figsize: Sequence[float] = None, train_color: str = "#132a70", val_color: str = None, - val_colormap: str = 'viridis', + val_colormap: str = "viridis", lw_train: float = 2.0, lw_val: float = 3.0, val_marker_type: str = "o", @@ -45,9 +43,7 @@ def loss( The validation loss key to look for in the history moving_average : bool, optional, default: False A flag for adding an exponential moving average line of the train_losses. - per_training_step : bool, optional, default: False - A flag for making loss trajectory detailed (to training steps) rather than per epoch. - ma_window_fraction : int, optional, default: 0.01 + moving_average_span : int, optional, default: 0.01 Window size for the moving average as a fraction of total training steps. figsize : tuple or None, optional, default: None @@ -55,12 +51,20 @@ def loss( Inferred if ``None`` train_color : str, optional, default: '#8f2727' The color for the train loss trajectory - val_color : str, optional, default: black + val_color : str, optional, default: None + The color for the optional validation loss trajectory + val_colormap : str, optional, default: "viridis" The color for the optional validation loss trajectory lw_train : int, optional, default: 2 The linewidth for the training loss curve lw_val : int, optional, default: 3 The linewidth for the validation loss curve + val_marker_type : str, optional, default: o + The marker type for the validation loss curve + val_marker_size : int, optional, default: 34 + The marker size for the validation loss curve + grid_alpha : float, optional, default: 0.2 + The transparency of the background grid legend_fontsize : int, optional, default: 14 The font size of the legend text label_fontsize : int, optional, default: 14 @@ -111,41 +115,32 @@ def loss( # Plot optional val curve if val_losses is not None: - if val_color is not None: - ax.plot( - val_step_index, - val_losses.iloc[:, 0], - linestyle="--", - marker=val_marker_type, - color=val_color, - lw=lw_val, - label="Validation", - ) - else: - # Create line segments between each epoch - points = np.array([val_step_index, val_losses.iloc[:,0]]).T.reshape(-1, 1, 2) - segments = np.concatenate([points[:-1], points[1:]], axis=1) - - # Normalize color based on loss values - lc = gradient_line( - val_step_index, - val_losses.iloc[:,0], - c=val_step_index, - cmap=val_colormap, - lw=lw_val, - ax=ax - ) - scatter = ax.scatter( - val_step_index, - val_losses.iloc[:,0], - c=val_step_index, - cmap=val_colormap, - marker=val_marker_type, - s=val_marker_size, - zorder=10, - edgecolors='none', - label='Validation' - ) + if val_color is not None: + ax.plot( + val_step_index, + val_losses.iloc[:, 0], + linestyle="--", + marker=val_marker_type, + color=val_color, + lw=lw_val, + label="Validation", + ) + else: + # Make gradient lines + gradient_line( + val_step_index, val_losses.iloc[:, 0], c=val_step_index, cmap=val_colormap, lw=lw_val, ax=ax + ) + ax.scatter( + val_step_index, + val_losses.iloc[:, 0], + c=val_step_index, + cmap=val_colormap, + marker=val_marker_type, + s=val_marker_size, + zorder=10, + edgecolors="none", + label="Validation", + ) sns.despine(ax=ax) ax.grid(alpha=grid_alpha) @@ -160,7 +155,7 @@ def loss( num_row=num_row, num_col=1, title=["Loss Trajectory"], - xlabel="Training step #" if per_training_step else "Training epoch #", + xlabel="Training epoch #", ylabel="Value", title_fontsize=title_fontsize, label_fontsize=label_fontsize, diff --git a/bayesflow/utils/plot_utils.py b/bayesflow/utils/plot_utils.py index cb6c3dddc..deb5c5526 100644 --- a/bayesflow/utils/plot_utils.py +++ b/bayesflow/utils/plot_utils.py @@ -267,7 +267,7 @@ def make_quadratic(ax: plt.Axes, x_data: np.ndarray, y_data: np.ndarray): ) -def gradient_line(x, y, c=None, cmap='viridis', lw=2, ax=None): +def gradient_line(x, y, c=None, cmap="viridis", lw=2, ax=None): """ Plot a 1D line with color gradient determined by `c` (same shape as x and y). """ @@ -291,7 +291,7 @@ def gradient_line(x, y, c=None, cmap='viridis', lw=2, ax=None): return lc -def gradient_legend(ax, label, cmap, norm, loc='upper right'): +def gradient_legend(ax, label, cmap, norm, loc="upper right"): """ Adds a single gradient swatch to the legend of the given Axes. @@ -304,7 +304,8 @@ def gradient_legend(ax, label, cmap, norm, loc='upper right'): """ # Custom dummy handle to represent the gradient - class _GradientSwatch(Rectangle): pass + class _GradientSwatch(Rectangle): + pass # Custom legend handler that draws a horizontal gradient class _HandlerGradient(HandlerPatch): @@ -312,11 +313,11 @@ def create_artists(self, legend, orig_handle, xdescent, ydescent, width, height, gradient = np.linspace(0, 1, 256).reshape(1, -1) im = ax.imshow( gradient, - aspect='auto', + aspect="auto", extent=[xdescent, xdescent + width, ydescent, ydescent + height], transform=trans, cmap=cmap, - norm=norm + norm=norm, ) return [im] @@ -325,9 +326,4 @@ def create_artists(self, legend, orig_handle, xdescent, ydescent, width, height, handles.append(_GradientSwatch((0, 0), 1, 1)) labels.append(label) - ax.legend( - handles=handles, - labels=labels, - loc=loc, - handler_map={_GradientSwatch: _HandlerGradient()} - ) + ax.legend(handles=handles, labels=labels, loc=loc, handler_map={_GradientSwatch: _HandlerGradient()}) From 722c8618df24433ec91cf46d7a2efc83db7aff9d Mon Sep 17 00:00:00 2001 From: Jerry Date: Fri, 18 Apr 2025 17:26:34 -0400 Subject: [PATCH 03/14] Add smoothing for validation loss too --- bayesflow/diagnostics/plots/loss.py | 68 +++++++++++++++++++---------- bayesflow/utils/plot_utils.py | 37 ++++++++++++++-- 2 files changed, 80 insertions(+), 25 deletions(-) diff --git a/bayesflow/diagnostics/plots/loss.py b/bayesflow/diagnostics/plots/loss.py index 0d6a3f9c8..2f7e0a806 100644 --- a/bayesflow/diagnostics/plots/loss.py +++ b/bayesflow/diagnostics/plots/loss.py @@ -1,4 +1,4 @@ -from collections.abc import Sequence +from typing import Sequence import numpy as np import pandas as pd @@ -7,7 +7,7 @@ import keras.src.callbacks -from ...utils.plot_utils import make_figure, add_titles_and_labels, gradient_line +from ...utils.plot_utils import make_figure, add_titles_and_labels, add_gradient_plot def loss( @@ -15,14 +15,15 @@ def loss( train_key: str = "loss", val_key: str = "val_loss", moving_average: bool = True, - moving_average_span: int = 10, + moving_average_alpha: float = 0.8, figsize: Sequence[float] = None, train_color: str = "#132a70", val_color: str = None, val_colormap: str = "viridis", lw_train: float = 2.0, lw_val: float = 3.0, - val_marker_type: str = "o", + marker: bool = True, + val_marker_type: str = ".", val_marker_size: int = 34, grid_alpha: float = 0.2, legend_fontsize: int = 14, @@ -43,9 +44,8 @@ def loss( The validation loss key to look for in the history moving_average : bool, optional, default: False A flag for adding an exponential moving average line of the train_losses. - moving_average_span : int, optional, default: 0.01 - Window size for the moving average as a fraction of total - training steps. + moving_average_alpha : int, optional, default: 0.8 + Smoothing factor for the moving average. figsize : tuple or None, optional, default: None The figure size passed to the ``matplotlib`` constructor. Inferred if ``None`` @@ -54,11 +54,13 @@ def loss( val_color : str, optional, default: None The color for the optional validation loss trajectory val_colormap : str, optional, default: "viridis" - The color for the optional validation loss trajectory + The colormap for the optional validation loss trajectory lw_train : int, optional, default: 2 The linewidth for the training loss curve lw_val : int, optional, default: 3 The linewidth for the validation loss curve + marker : bool, optional, default: False + A flag for whether marker should be added in the validation loss trajectory val_marker_type : str, optional, default: o The marker type for the validation loss curve val_marker_size : int, optional, default: 34 @@ -108,10 +110,10 @@ def loss( # Loop through loss entries and populate plot for i, ax in enumerate(axes.flat): # Plot train curve - ax.plot(train_step_index, train_losses.iloc[:, 0], color=train_color, lw=lw_train, alpha=0.2, label="Training") + ax.plot(train_step_index, train_losses.iloc[:, 0], color=train_color, lw=lw_train, alpha=0.05, label="Training") if moving_average: - smoothed_loss = train_losses.iloc[:, 0].ewm(span=moving_average_span, adjust=True).mean() - ax.plot(train_step_index, smoothed_loss, color="grey", lw=lw_train, label="Training (Moving Average)") + smoothed_train_loss = train_losses.iloc[:, 0].ewm(alpha=moving_average_alpha, adjust=True).mean() + ax.plot(train_step_index, smoothed_train_loss, color="grey", lw=lw_train, label="Training (Moving Average)") # Plot optional val curve if val_losses is not None: @@ -120,27 +122,49 @@ def loss( val_step_index, val_losses.iloc[:, 0], linestyle="--", - marker=val_marker_type, + marker=val_marker_type if marker else None, color=val_color, lw=lw_val, + alpha=0.2, label="Validation", ) + if moving_average: + smoothed_val_loss = val_losses.iloc[:, 0].ewm(alpha=moving_average_alpha, adjust=True).mean() + ax.plot( + val_step_index, + smoothed_val_loss, + color=val_color, + lw=lw_val, + label="Validation (Moving Average)", + ) else: # Make gradient lines - gradient_line( - val_step_index, val_losses.iloc[:, 0], c=val_step_index, cmap=val_colormap, lw=lw_val, ax=ax - ) - ax.scatter( + add_gradient_plot( val_step_index, val_losses.iloc[:, 0], - c=val_step_index, - cmap=val_colormap, - marker=val_marker_type, - s=val_marker_size, - zorder=10, - edgecolors="none", + ax, + val_colormap, + lw_val, + marker, + val_marker_type, + val_marker_size, + alpha=0.05, label="Validation", ) + if moving_average: + smoothed_val_loss = val_losses.iloc[:, 0].ewm(alpha=moving_average_alpha, adjust=True).mean() + add_gradient_plot( + val_step_index, + smoothed_val_loss, + ax, + val_colormap, + lw_val, + marker, + val_marker_type, + val_marker_size, + alpha=1, + label="Validation (Moving Average)", + ) sns.despine(ax=ax) ax.grid(alpha=grid_alpha) diff --git a/bayesflow/utils/plot_utils.py b/bayesflow/utils/plot_utils.py index deb5c5526..c963b4dfb 100644 --- a/bayesflow/utils/plot_utils.py +++ b/bayesflow/utils/plot_utils.py @@ -267,7 +267,7 @@ def make_quadratic(ax: plt.Axes, x_data: np.ndarray, y_data: np.ndarray): ) -def gradient_line(x, y, c=None, cmap="viridis", lw=2, ax=None): +def gradient_line(x, y, c=None, cmap: str = "viridis", lw: float = 2.0, alpha: float = 1, ax=None): """ Plot a 1D line with color gradient determined by `c` (same shape as x and y). """ @@ -283,7 +283,7 @@ def gradient_line(x, y, c=None, cmap="viridis", lw=2, ax=None): segments = np.concatenate([points[:-1], points[1:]], axis=1) norm = Normalize(np.min(c), np.max(c)) - lc = LineCollection(segments, array=c, cmap=cmap, norm=norm, linewidth=lw) + lc = LineCollection(segments, array=c, cmap=cmap, norm=norm, linewidth=lw, alpha=alpha) ax.add_collection(lc) ax.set_xlim(np.min(x), np.max(x)) @@ -295,7 +295,8 @@ def gradient_legend(ax, label, cmap, norm, loc="upper right"): """ Adds a single gradient swatch to the legend of the given Axes. - Parameters: + Parameters + ---------- - ax: matplotlib Axes - label: str, label to display in the legend - cmap: matplotlib colormap @@ -327,3 +328,33 @@ def create_artists(self, legend, orig_handle, xdescent, ydescent, width, height, labels.append(label) ax.legend(handles=handles, labels=labels, loc=loc, handler_map={_GradientSwatch: _HandlerGradient()}) + + +def add_gradient_plot( + x, + y, + ax, + cmap: str = "viridis", + lw: float = 3.0, + marker: bool = True, + marker_type: str = "o", + marker_size: int = 34, + alpha: float = 1, + label: str = "Validation", +): + gradient_line(x, y, c=x, cmap=cmap, lw=lw, alpha=alpha, ax=ax) + + # Optionally add markers + if marker: + ax.scatter( + x, + y, + c=x, + cmap=cmap, + marker=marker_type, + s=marker_size, + zorder=10, + edgecolors="none", + label=label, + alpha=0.01, + ) From 2fdc5c305eedd532cbf77e072203cb951d6b2776 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Sat, 19 Apr 2025 16:38:45 -0400 Subject: [PATCH 04/14] add rudimentary test for loss plot --- tests/test_diagnostics/conftest.py | 17 +++++++++++++++++ .../test_diagnostics/test_diagnostics_plots.py | 9 +++++++++ 2 files changed, 26 insertions(+) diff --git a/tests/test_diagnostics/conftest.py b/tests/test_diagnostics/conftest.py index 5103ea455..8e77d6729 100644 --- a/tests/test_diagnostics/conftest.py +++ b/tests/test_diagnostics/conftest.py @@ -1,3 +1,4 @@ +import keras import numpy as np import pytest from bayesflow.utils.numpy_utils import softmax @@ -61,3 +62,19 @@ def pred_models(true_models): pred_models = np.random.normal(loc=true_models) pred_models = softmax(pred_models, axis=-1) return pred_models + + +@pytest.fixture() +def history(): + h = keras.callbacks.History() + + step = np.linspace(0, 1, 10_000) + train_loss = (1.0 - step) ** 2 + np.random.normal(loc=0, scale=0.02, size=step.shape) + validation_loss = 0.1 + (0.75 - step) ** 2 + np.random.normal(loc=0, scale=0.02, size=step.shape) + + h.history = { + "loss": train_loss.tolist(), + "val_loss": validation_loss.tolist(), + } + + return h diff --git a/tests/test_diagnostics/test_diagnostics_plots.py b/tests/test_diagnostics/test_diagnostics_plots.py index f832535d1..cda8f9098 100644 --- a/tests/test_diagnostics/test_diagnostics_plots.py +++ b/tests/test_diagnostics/test_diagnostics_plots.py @@ -45,6 +45,15 @@ def test_calibration_histogram(random_estimates, random_targets): assert out.axes[0].title._text == "beta_0" +def test_loss(history): + import matplotlib.pyplot as plt + + out = bf.diagnostics.loss(history) + plt.show() + assert len(out.axes) == 1 + assert out.axes[0].title._text == "Loss Trajectory" + + def test_recovery(random_estimates, random_targets): # basic functionality: automatic variable names out = bf.diagnostics.plots.recovery(random_estimates, random_targets) From f9f807fdab7423dd2f0a6f15cfa0928af57ec79a Mon Sep 17 00:00:00 2001 From: LarsKue Date: Sat, 19 Apr 2025 16:58:54 -0400 Subject: [PATCH 05/14] fix comments --- bayesflow/diagnostics/plots/loss.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/bayesflow/diagnostics/plots/loss.py b/bayesflow/diagnostics/plots/loss.py index 2f7e0a806..d78703726 100644 --- a/bayesflow/diagnostics/plots/loss.py +++ b/bayesflow/diagnostics/plots/loss.py @@ -1,4 +1,4 @@ -from typing import Sequence +from collections.abc import Sequence import numpy as np import pandas as pd @@ -14,8 +14,7 @@ def loss( history: keras.callbacks.History, train_key: str = "loss", val_key: str = "val_loss", - moving_average: bool = True, - moving_average_alpha: float = 0.8, + smoothing_factor: float = 0.8, figsize: Sequence[float] = None, train_color: str = "#132a70", val_color: str = None, From c0a4f25041a777e599db461fd4495b22db31443f Mon Sep 17 00:00:00 2001 From: LarsKue Date: Sat, 19 Apr 2025 16:59:27 -0400 Subject: [PATCH 06/14] Refactor loss plotting logic and remove unused parameters. Simplified the loss plotting code by consolidating duplication and aligning the handling of smoothing logic. Removed unused arguments like markers and colormap, reducing potential confusion in the API. Updated comments and improved code readability for maintainability. --- bayesflow/diagnostics/plots/loss.py | 108 +++++++++++----------------- 1 file changed, 43 insertions(+), 65 deletions(-) diff --git a/bayesflow/diagnostics/plots/loss.py b/bayesflow/diagnostics/plots/loss.py index d78703726..baca73327 100644 --- a/bayesflow/diagnostics/plots/loss.py +++ b/bayesflow/diagnostics/plots/loss.py @@ -7,7 +7,7 @@ import keras.src.callbacks -from ...utils.plot_utils import make_figure, add_titles_and_labels, add_gradient_plot +from ...utils.plot_utils import make_figure, add_titles_and_labels def loss( @@ -17,13 +17,9 @@ def loss( smoothing_factor: float = 0.8, figsize: Sequence[float] = None, train_color: str = "#132a70", - val_color: str = None, - val_colormap: str = "viridis", + val_color: str = "black", lw_train: float = 2.0, - lw_val: float = 3.0, - marker: bool = True, - val_marker_type: str = ".", - val_marker_size: int = 34, + lw_val: float = 2.0, grid_alpha: float = 0.2, legend_fontsize: int = 14, label_fontsize: int = 14, @@ -41,10 +37,8 @@ def loss( The training loss key to look for in the history val_key : str, optional, default: "val_loss" The validation loss key to look for in the history - moving_average : bool, optional, default: False - A flag for adding an exponential moving average line of the train_losses. - moving_average_alpha : int, optional, default: 0.8 - Smoothing factor for the moving average. + smoothing_factor : float, optional, default: 0.8 + If greater than zero, smooth the loss curves by applying an exponential moving average. figsize : tuple or None, optional, default: None The figure size passed to the ``matplotlib`` constructor. Inferred if ``None`` @@ -52,18 +46,10 @@ def loss( The color for the train loss trajectory val_color : str, optional, default: None The color for the optional validation loss trajectory - val_colormap : str, optional, default: "viridis" - The colormap for the optional validation loss trajectory lw_train : int, optional, default: 2 The linewidth for the training loss curve lw_val : int, optional, default: 3 The linewidth for the validation loss curve - marker : bool, optional, default: False - A flag for whether marker should be added in the validation loss trajectory - val_marker_type : str, optional, default: o - The marker type for the validation loss curve - val_marker_size : int, optional, default: 34 - The marker size for the validation loss curve grid_alpha : float, optional, default: 0.2 The transparency of the background grid legend_fontsize : int, optional, default: 14 @@ -108,68 +94,60 @@ def loss( # Loop through loss entries and populate plot for i, ax in enumerate(axes.flat): - # Plot train curve - ax.plot(train_step_index, train_losses.iloc[:, 0], color=train_color, lw=lw_train, alpha=0.05, label="Training") - if moving_average: - smoothed_train_loss = train_losses.iloc[:, 0].ewm(alpha=moving_average_alpha, adjust=True).mean() - ax.plot(train_step_index, smoothed_train_loss, color="grey", lw=lw_train, label="Training (Moving Average)") + if smoothing_factor > 0: + # plot unsmoothed train loss + ax.plot( + train_step_index, train_losses.iloc[:, 0], color=train_color, lw=lw_train, alpha=0.3, label="Training" + ) + + # plot smoothed train loss + smoothed_train_loss = train_losses.iloc[:, 0].ewm(alpha=1.0 - smoothing_factor, adjust=True).mean() + ax.plot( + train_step_index, + smoothed_train_loss, + color=train_color, + lw=lw_train, + alpha=0.8, + label="Training (Moving Average)", + ) + else: + # plot unsmoothed train loss + ax.plot( + train_step_index, train_losses.iloc[:, 0], color=train_color, lw=lw_train, alpha=0.8, label="Training" + ) # Plot optional val curve if val_losses is not None: if val_color is not None: - ax.plot( - val_step_index, - val_losses.iloc[:, 0], - linestyle="--", - marker=val_marker_type if marker else None, - color=val_color, - lw=lw_val, - alpha=0.2, - label="Validation", - ) - if moving_average: - smoothed_val_loss = val_losses.iloc[:, 0].ewm(alpha=moving_average_alpha, adjust=True).mean() + if smoothing_factor > 0: + # plot unsmoothed val loss + ax.plot( + val_step_index, val_losses.iloc[:, 0], color=val_color, lw=lw_val, alpha=0.3, label="Validation" + ) + + # plot smoothed val loss + smoothed_val_loss = val_losses.iloc[:, 0].ewm(alpha=1.0 - smoothing_factor, adjust=True).mean() ax.plot( val_step_index, smoothed_val_loss, color=val_color, lw=lw_val, + alpha=0.8, label="Validation (Moving Average)", ) - else: - # Make gradient lines - add_gradient_plot( - val_step_index, - val_losses.iloc[:, 0], - ax, - val_colormap, - lw_val, - marker, - val_marker_type, - val_marker_size, - alpha=0.05, - label="Validation", - ) - if moving_average: - smoothed_val_loss = val_losses.iloc[:, 0].ewm(alpha=moving_average_alpha, adjust=True).mean() - add_gradient_plot( - val_step_index, - smoothed_val_loss, - ax, - val_colormap, - lw_val, - marker, - val_marker_type, - val_marker_size, - alpha=1, - label="Validation (Moving Average)", + else: + # plot unsmoothed val loss + ax.plot( + val_step_index, val_losses.iloc[:, 0], color=val_color, lw=lw_val, alpha=0.8, label="Validation" ) sns.despine(ax=ax) ax.grid(alpha=grid_alpha) - # Only add legend if there is a validation curve - if val_losses is not None or moving_average: + ax.set_xlim(train_step_index[0], train_step_index[-1]) + + # Only add the legend if there are multiple curves + if val_losses is not None or smoothing_factor > 0: ax.legend(fontsize=legend_fontsize) # Add labels, titles, and set font sizes From f2468c7f37e8d0f40949cffabcdf792120cdbecc Mon Sep 17 00:00:00 2001 From: LarsKue Date: Sun, 20 Apr 2025 14:57:17 -0400 Subject: [PATCH 07/14] Remove unused plt.show and ensure matplotlib figures are closed Removed an unnecessary `plt.show()` call in the test to streamline the code. Added a `pytest_runtest_teardown` hook to automatically close all matplotlib figures after tests to prevent resource leaks and improve test isolation. --- tests/test_diagnostics/conftest.py | 6 ++++++ tests/test_diagnostics/test_diagnostics_plots.py | 3 --- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/test_diagnostics/conftest.py b/tests/test_diagnostics/conftest.py index 8e77d6729..04ad1c008 100644 --- a/tests/test_diagnostics/conftest.py +++ b/tests/test_diagnostics/conftest.py @@ -4,6 +4,12 @@ from bayesflow.utils.numpy_utils import softmax +def pytest_runtest_teardown(item, nextitem): + import matplotlib.pyplot as plt + + plt.close("all") + + @pytest.fixture() def var_names(): return [r"$\beta_0$", r"$\beta_1$", r"$\sigma$"] diff --git a/tests/test_diagnostics/test_diagnostics_plots.py b/tests/test_diagnostics/test_diagnostics_plots.py index cda8f9098..9472544b2 100644 --- a/tests/test_diagnostics/test_diagnostics_plots.py +++ b/tests/test_diagnostics/test_diagnostics_plots.py @@ -46,10 +46,7 @@ def test_calibration_histogram(random_estimates, random_targets): def test_loss(history): - import matplotlib.pyplot as plt - out = bf.diagnostics.loss(history) - plt.show() assert len(out.axes) == 1 assert out.axes[0].title._text == "Loss Trajectory" From d61b3d89a53ff5d174c682cafc861432cd43fd38 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Sun, 20 Apr 2025 16:07:37 -0400 Subject: [PATCH 08/14] fix typo --- bayesflow/workflows/basic_workflow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bayesflow/workflows/basic_workflow.py b/bayesflow/workflows/basic_workflow.py index fb9aa594d..2533bcd2e 100644 --- a/bayesflow/workflows/basic_workflow.py +++ b/bayesflow/workflows/basic_workflow.py @@ -348,7 +348,7 @@ def plot_default_diagnostics( ) -> dict[str, plt.Figure]: """ Generates default diagnostic plots to evaluate the quality of inference. The function produces several - diagnostic plots, including: + diagnostic plots, including - Loss history (if training history is available). - Parameter recovery plots. - Calibration ECDF plots. From 781372ec49326fa1515bb68db7db8b210cffee69 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Sun, 20 Apr 2025 16:08:00 -0400 Subject: [PATCH 09/14] force plotting backend of tests to "agg" --- tests/conftest.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 315efd531..90eb16bf9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,6 @@ -import logging - import keras +import logging +import matplotlib import pytest BACKENDS = ["jax", "numpy", "tensorflow", "torch"] @@ -17,6 +17,16 @@ def pytest_runtest_setup(item): if test_backends and backend not in test_backends: pytest.skip(f"Skipping backend '{backend}' for test {item}, which is registered for backends {test_backends}.") + # use a non-GUI plotting backend for tests + matplotlib.use("Agg") + + +def pytest_runtest_teardown(item, nextitem): + import matplotlib.pyplot as plt + + # close all plots at the end of each test + plt.close("all") + def pytest_make_parametrize_id(config, val, argname): return f"{argname}={repr(val)}" From a93d75e1d67925b2e1709d6a92ae5e5e80e9b4e6 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Sun, 20 Apr 2025 16:08:09 -0400 Subject: [PATCH 10/14] rename workflow test --- tests/test_workflows/test_basic_workflow.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_workflows/test_basic_workflow.py b/tests/test_workflows/test_basic_workflow.py index 6cab3be8e..9a1c7815f 100644 --- a/tests/test_workflows/test_basic_workflow.py +++ b/tests/test_workflows/test_basic_workflow.py @@ -1,7 +1,7 @@ import bayesflow as bf -def test_classifier_two_sample_test(inference_network, summary_network): +def test_basic_workflow(inference_network, summary_network): workflow = bf.BasicWorkflow( inference_network=inference_network, summary_network=summary_network, From b6db4939ab394d935dbd18b6314085e1ea9e7f96 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Sun, 20 Apr 2025 16:08:20 -0400 Subject: [PATCH 11/14] speed up set size test --- tests/test_networks/test_summary_networks.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_networks/test_summary_networks.py b/tests/test_networks/test_summary_networks.py index a8770f3d7..44ff67ff5 100644 --- a/tests/test_networks/test_summary_networks.py +++ b/tests/test_networks/test_summary_networks.py @@ -49,9 +49,8 @@ def test_variable_set_size(summary_network, random_set): summary_network.build(keras.ops.shape(random_set)) # run with another set size - for _ in range(10): + for s in [3, 4, 5]: b = keras.ops.shape(random_set)[0] - s = np.random.randint(1, 10) new_input = keras.ops.zeros((b, s, keras.ops.shape(random_set)[2])) summary_network(new_input) From dc9e0d554df69a914d73885c59b29fd8a1b5c90f Mon Sep 17 00:00:00 2001 From: LarsKue Date: Sun, 20 Apr 2025 16:08:37 -0400 Subject: [PATCH 12/14] improve error message for non-close outputs of inference networks in numerical density test --- tests/test_networks/test_inference_networks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_networks/test_inference_networks.py b/tests/test_networks/test_inference_networks.py index 4c3d0c5da..66015730d 100644 --- a/tests/test_networks/test_inference_networks.py +++ b/tests/test_networks/test_inference_networks.py @@ -98,7 +98,7 @@ def f(x): numerical_output, numerical_jacobian = jacobian(f, random_samples, return_output=True) # output should be identical, otherwise this test does not work (e.g. for stochastic networks) - assert keras.ops.all(keras.ops.isclose(output, numerical_output)) + assert_allclose(output, numerical_output) log_prob = generative_inference_network.base_distribution.log_prob(output) From a03598b609d274946dfa45f40f8e9c5933477d68 Mon Sep 17 00:00:00 2001 From: LarsKue Date: Sun, 20 Apr 2025 16:08:53 -0400 Subject: [PATCH 13/14] move teardown to global conftest --- tests/test_diagnostics/conftest.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/tests/test_diagnostics/conftest.py b/tests/test_diagnostics/conftest.py index 04ad1c008..8e77d6729 100644 --- a/tests/test_diagnostics/conftest.py +++ b/tests/test_diagnostics/conftest.py @@ -4,12 +4,6 @@ from bayesflow.utils.numpy_utils import softmax -def pytest_runtest_teardown(item, nextitem): - import matplotlib.pyplot as plt - - plt.close("all") - - @pytest.fixture() def var_names(): return [r"$\beta_0$", r"$\beta_1$", r"$\sigma$"] From 72d92710f6609e6a19a0d98a62b79ea46e77cd7e Mon Sep 17 00:00:00 2001 From: LarsKue Date: Sun, 20 Apr 2025 16:09:08 -0400 Subject: [PATCH 14/14] add test for non-gui backend --- tests/test_diagnostics/test_diagnostics_plots.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/tests/test_diagnostics/test_diagnostics_plots.py b/tests/test_diagnostics/test_diagnostics_plots.py index 9472544b2..f9f29492f 100644 --- a/tests/test_diagnostics/test_diagnostics_plots.py +++ b/tests/test_diagnostics/test_diagnostics_plots.py @@ -6,6 +6,15 @@ def num_variables(x: dict): return sum(arr.shape[-1] for arr in x.values()) +def test_backend(): + import matplotlib.pyplot as plt + + # if the local testing backend is not Agg + # then you may run into issues once you run workflow tests + # on GitHub, since these use the Agg backend + assert plt.get_backend() == "Agg" + + def test_calibration_ecdf(random_estimates, random_targets, var_names): # basic functionality: automatic variable names out = bf.diagnostics.plots.calibration_ecdf(random_estimates, random_targets)