Skip to content

Commit 2e58100

Browse files
committed
Reformat mode subplots
1 parent dd246c6 commit 2e58100

File tree

1 file changed

+24
-18
lines changed

1 file changed

+24
-18
lines changed

pydmd/bopdmd.py

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1458,39 +1458,45 @@ def plot_mode_uq(
14581458
modes = np.average(modes.reshape(d, nd // d, r), axis=0)
14591459
modes_std = np.average(modes_std.reshape(d, nd // d, r), axis=0)
14601460

1461+
# Define the subplot grid.
14611462
rows = 2 * int(np.ceil(modes.shape[-1] / cols))
1462-
fig, axes = plt.subplots(rows, cols, figsize=figsize, dpi=dpi)
1463-
avg_axes = [ax for axes_list in axes[::2] for ax in axes_list]
1464-
std_axes = [ax for axes_list in axes[1::2] for ax in axes_list]
1465-
avg_axes = avg_axes[:modes.shape[-1]]
1466-
std_axes = std_axes[:modes.shape[-1]]
1467-
1468-
for i, (ax_avg, ax_std, mode, mode_std) in enumerate(
1469-
zip(avg_axes, std_axes, modes.T, modes_std.T)
1470-
):
1471-
ax_avg.set_title(f"Mode {i + 1}")
1472-
ax_std.set_title("Mode Standard Deviation")
1473-
1463+
plt.figure(figsize=figsize, dpi=dpi)
1464+
all_inds = np.arange(rows * cols).reshape(rows, cols)
1465+
avg_inds = all_inds[::2].flatten()
1466+
std_inds = all_inds[1::2].flatten()
1467+
1468+
for i, (mode, mode_std) in enumerate(zip(modes.T, modes_std.T)):
1469+
# Plot the average mode.
1470+
plt.subplot(rows, cols, avg_inds[i])
1471+
plt.title(f"Mode {i + 1}")
14741472
if len(modes_shape) == 1:
14751473
# Plot modes in 1-D.
1476-
ax_avg.plot(x, mode.real, c="tab:blue")
1477-
ax_std.plot(x, mode_std, c="tab:red")
1474+
plt.plot(x, mode.real, c="tab:blue")
14781475
else:
14791476
# Plot modes in 2-D.
1480-
im_avg = ax_avg.pcolormesh(
1477+
plt.pcolormesh(
14811478
xgrid,
14821479
ygrid,
14831480
mode.reshape(*modes_shape, order=order).real,
14841481
cmap="viridis",
14851482
)
1486-
im_std = ax_std.pcolormesh(
1483+
plt.colorbar()
1484+
1485+
# Plot the mode standard deviation.
1486+
plt.subplot(rows, cols, std_inds[i])
1487+
plt.title("Mode Standard Deviation")
1488+
if len(modes_shape) == 1:
1489+
# Plot modes in 1-D.
1490+
plt.plot(x, mode_std, c="tab:red")
1491+
else:
1492+
# Plot modes in 2-D.
1493+
plt.pcolormesh(
14871494
xgrid,
14881495
ygrid,
14891496
mode_std.reshape(*modes_shape, order=order),
14901497
cmap="inferno",
14911498
)
1492-
fig.colorbar(im_avg, ax=ax_avg)
1493-
fig.colorbar(im_std, ax=ax_std)
1499+
plt.colorbar()
14941500

14951501
plt.suptitle("DMD Modes")
14961502
plt.tight_layout()

0 commit comments

Comments
 (0)