Skip to content

Commit 8210610

Browse files
authored
Fix single parameter plots (#303)
* Fix single parameter plotting * Remove superfluous atleast_1d
1 parent 5da3759 commit 8210610

File tree

3 files changed

+4
-4
lines changed

3 files changed

+4
-4
lines changed

bayesflow/diagnostics/plots/loss.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,7 @@ def loss(
9696
val_step_index = val_step_index[: val_losses.shape[0]]
9797

9898
# Loop through loss entries and populate plot
99-
looper = [axes] if num_row == 1 else axes.flat
100-
for i, ax in enumerate(looper):
99+
for i, ax in enumerate(axes.flat):
101100
# Plot train curve
102101
ax.plot(train_step_index, train_losses.iloc[:, i], color=train_color, lw=lw_train, alpha=0.9, label="Training")
103102
if moving_average and train_losses.columns[i] == "Loss":
@@ -127,7 +126,7 @@ def loss(
127126

128127
# Add labels, titles, and set font sizes
129128
add_titles_and_labels(
130-
axes=np.atleast_1d(axes),
129+
axes=axes,
131130
num_row=num_row,
132131
num_col=1,
133132
title=["Loss Trajectory"],

bayesflow/diagnostics/plots/recovery.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ def recovery(
7777
if uncertainty_agg is not None:
7878
u = uncertainty_agg(targets, axis=1)
7979

80-
for i, ax in enumerate(np.atleast_1d(plot_data["axes"].flat)):
80+
for i, ax in enumerate(plot_data["axes"].flat):
8181
if i >= plot_data["num_variables"]:
8282
break
8383

bayesflow/utils/plot_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ def make_figure(num_row: int = None, num_col: int = None, figsize: tuple = None)
130130
figsize = (int(5 * num_col), int(5 * num_row))
131131

132132
f, axes = plt.subplots(num_row, num_col, figsize=figsize)
133+
axes = np.atleast_1d(axes)
133134

134135
return f, axes
135136

0 commit comments

Comments
 (0)