Skip to content

Commit 06c5565

Browse files
committed
Add docs, remove unused code
1 parent 4e130cd commit 06c5565

File tree

2 files changed

+46
-55
lines changed

2 files changed

+46
-55
lines changed

bayesflow/diagnostics/plots/loss.py

Lines changed: 39 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,19 @@
77

88
import keras.src.callbacks
99

10-
from matplotlib.colors import Normalize
11-
from ...utils.plot_utils import make_figure, add_titles_and_labels, gradient_line, gradient_legend
10+
from ...utils.plot_utils import make_figure, add_titles_and_labels, gradient_line
1211

1312

1413
def loss(
1514
history: keras.callbacks.History,
1615
train_key: str = "loss",
1716
val_key: str = "val_loss",
1817
moving_average: bool = True,
19-
per_training_step: bool = False,
2018
moving_average_span: int = 10,
2119
figsize: Sequence[float] = None,
2220
train_color: str = "#132a70",
2321
val_color: str = None,
24-
val_colormap: str = 'viridis',
22+
val_colormap: str = "viridis",
2523
lw_train: float = 2.0,
2624
lw_val: float = 3.0,
2725
val_marker_type: str = "o",
@@ -45,22 +43,28 @@ def loss(
4543
The validation loss key to look for in the history
4644
moving_average : bool, optional, default: False
4745
A flag for adding an exponential moving average line of the train_losses.
48-
per_training_step : bool, optional, default: False
49-
A flag for making loss trajectory detailed (to training steps) rather than per epoch.
50-
ma_window_fraction : int, optional, default: 0.01
46+
moving_average_span : int, optional, default: 0.01
5147
Window size for the moving average as a fraction of total
5248
training steps.
5349
figsize : tuple or None, optional, default: None
5450
The figure size passed to the ``matplotlib`` constructor.
5551
Inferred if ``None``
5652
train_color : str, optional, default: '#8f2727'
5753
The color for the train loss trajectory
58-
val_color : str, optional, default: black
54+
val_color : str, optional, default: None
55+
The color for the optional validation loss trajectory
56+
val_colormap : str, optional, default: "viridis"
5957
The color for the optional validation loss trajectory
6058
lw_train : int, optional, default: 2
6159
The linewidth for the training loss curve
6260
lw_val : int, optional, default: 3
6361
The linewidth for the validation loss curve
62+
val_marker_type : str, optional, default: o
63+
The marker type for the validation loss curve
64+
val_marker_size : int, optional, default: 34
65+
The marker size for the validation loss curve
66+
grid_alpha : float, optional, default: 0.2
67+
The transparency of the background grid
6468
legend_fontsize : int, optional, default: 14
6569
The font size of the legend text
6670
label_fontsize : int, optional, default: 14
@@ -111,41 +115,32 @@ def loss(
111115

112116
# Plot optional val curve
113117
if val_losses is not None:
114-
if val_color is not None:
115-
ax.plot(
116-
val_step_index,
117-
val_losses.iloc[:, 0],
118-
linestyle="--",
119-
marker=val_marker_type,
120-
color=val_color,
121-
lw=lw_val,
122-
label="Validation",
123-
)
124-
else:
125-
# Create line segments between each epoch
126-
points = np.array([val_step_index, val_losses.iloc[:,0]]).T.reshape(-1, 1, 2)
127-
segments = np.concatenate([points[:-1], points[1:]], axis=1)
128-
129-
# Normalize color based on loss values
130-
lc = gradient_line(
131-
val_step_index,
132-
val_losses.iloc[:,0],
133-
c=val_step_index,
134-
cmap=val_colormap,
135-
lw=lw_val,
136-
ax=ax
137-
)
138-
scatter = ax.scatter(
139-
val_step_index,
140-
val_losses.iloc[:,0],
141-
c=val_step_index,
142-
cmap=val_colormap,
143-
marker=val_marker_type,
144-
s=val_marker_size,
145-
zorder=10,
146-
edgecolors='none',
147-
label='Validation'
148-
)
118+
if val_color is not None:
119+
ax.plot(
120+
val_step_index,
121+
val_losses.iloc[:, 0],
122+
linestyle="--",
123+
marker=val_marker_type,
124+
color=val_color,
125+
lw=lw_val,
126+
label="Validation",
127+
)
128+
else:
129+
# Make gradient lines
130+
gradient_line(
131+
val_step_index, val_losses.iloc[:, 0], c=val_step_index, cmap=val_colormap, lw=lw_val, ax=ax
132+
)
133+
ax.scatter(
134+
val_step_index,
135+
val_losses.iloc[:, 0],
136+
c=val_step_index,
137+
cmap=val_colormap,
138+
marker=val_marker_type,
139+
s=val_marker_size,
140+
zorder=10,
141+
edgecolors="none",
142+
label="Validation",
143+
)
149144

150145
sns.despine(ax=ax)
151146
ax.grid(alpha=grid_alpha)
@@ -160,7 +155,7 @@ def loss(
160155
num_row=num_row,
161156
num_col=1,
162157
title=["Loss Trajectory"],
163-
xlabel="Training step #" if per_training_step else "Training epoch #",
158+
xlabel="Training epoch #",
164159
ylabel="Value",
165160
title_fontsize=title_fontsize,
166161
label_fontsize=label_fontsize,

bayesflow/utils/plot_utils.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ def make_quadratic(ax: plt.Axes, x_data: np.ndarray, y_data: np.ndarray):
267267
)
268268

269269

270-
def gradient_line(x, y, c=None, cmap='viridis', lw=2, ax=None):
270+
def gradient_line(x, y, c=None, cmap="viridis", lw=2, ax=None):
271271
"""
272272
Plot a 1D line with color gradient determined by `c` (same shape as x and y).
273273
"""
@@ -291,7 +291,7 @@ def gradient_line(x, y, c=None, cmap='viridis', lw=2, ax=None):
291291
return lc
292292

293293

294-
def gradient_legend(ax, label, cmap, norm, loc='upper right'):
294+
def gradient_legend(ax, label, cmap, norm, loc="upper right"):
295295
"""
296296
Adds a single gradient swatch to the legend of the given Axes.
297297
@@ -304,19 +304,20 @@ def gradient_legend(ax, label, cmap, norm, loc='upper right'):
304304
"""
305305

306306
# Custom dummy handle to represent the gradient
307-
class _GradientSwatch(Rectangle): pass
307+
class _GradientSwatch(Rectangle):
308+
pass
308309

309310
# Custom legend handler that draws a horizontal gradient
310311
class _HandlerGradient(HandlerPatch):
311312
def create_artists(self, legend, orig_handle, xdescent, ydescent, width, height, fontsize, trans):
312313
gradient = np.linspace(0, 1, 256).reshape(1, -1)
313314
im = ax.imshow(
314315
gradient,
315-
aspect='auto',
316+
aspect="auto",
316317
extent=[xdescent, xdescent + width, ydescent, ydescent + height],
317318
transform=trans,
318319
cmap=cmap,
319-
norm=norm
320+
norm=norm,
320321
)
321322
return [im]
322323

@@ -325,9 +326,4 @@ def create_artists(self, legend, orig_handle, xdescent, ydescent, width, height,
325326
handles.append(_GradientSwatch((0, 0), 1, 1))
326327
labels.append(label)
327328

328-
ax.legend(
329-
handles=handles,
330-
labels=labels,
331-
loc=loc,
332-
handler_map={_GradientSwatch: _HandlerGradient()}
333-
)
329+
ax.legend(handles=handles, labels=labels, loc=loc, handler_map={_GradientSwatch: _HandlerGradient()})

0 commit comments

Comments
 (0)