Skip to content

Commit cc9cae9

Browse files
committed
Incorporated feedback
1 parent 1f6fd99 commit cc9cae9

File tree

1 file changed

+46
-5
lines changed

1 file changed

+46
-5
lines changed

pydmd/bopdmd.py

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1394,6 +1394,7 @@ def forecast(self, t):
13941394

13951395
def plot_mode_uq(
13961396
self,
1397+
*,
13971398
x=None,
13981399
y=None,
13991400
d=1,
@@ -1402,6 +1403,8 @@ def plot_mode_uq(
14021403
cols=4,
14031404
figsize=None,
14041405
dpi=None,
1406+
plot_modes=None,
1407+
plot_complex_pairs=True,
14051408
):
14061409
"""
14071410
Plot BOP-DMD modes alongside their standard deviations.
@@ -1431,10 +1434,33 @@ def plot_mode_uq(
14311434
:type figsize: iterable
14321435
:param dpi: Figure resolution.
14331436
:type dpi: int
1437+
:param plot_modes: Number of leading modes to plot, or the indices of
1438+
the modes to plot. If `None`, then all available modes are plotted.
1439+
Note that if this parameter is given as a list of indices, it will
1440+
override the `plot_complex_pair` parameter.
1441+
:type plot_modes: int or iterable
1442+
:param plot_complex_pairs: Whether or not to omit one of the modes that
1443+
correspond with a complex conjugate pair of eigenvalues.
1444+
:type plot_complex_pairs: bool
14341445
"""
14351446
if self.modes_std is None:
14361447
raise ValueError("No UQ metrics to plot.")
14371448

1449+
# Get the indices of the modes to plot.
1450+
nd, r = self.modes.shape
1451+
if plot_modes is None or isinstance(plot_modes, int):
1452+
mode_indices = np.arange(r)
1453+
if plot_complex_pairs:
1454+
if r % 2 == 0:
1455+
mode_indices = mode_indices[::2]
1456+
else:
1457+
mode_indices = np.concatenate([(0,), mode_indices[1::2]])
1458+
if isinstance(plot_modes, int):
1459+
mode_indices = mode_indices[:plot_modes]
1460+
else:
1461+
mode_indices = plot_modes
1462+
plot_complex_pairs = True
1463+
14381464
# By default, modes_shape is the flattened space dimension.
14391465
if modes_shape is None:
14401466
modes_shape = (len(self.snapshots) // d,)
@@ -1454,21 +1480,36 @@ def plot_mode_uq(
14541480

14551481
# Collapse the results across time-delays.
14561482
if d > 1:
1457-
nd, r = modes.shape
14581483
modes = np.average(modes.reshape(d, nd // d, r), axis=0)
14591484
modes_std = np.average(modes_std.reshape(d, nd // d, r), axis=0)
14601485

14611486
# Define the subplot grid.
1462-
rows = 2 * int(np.ceil(modes.shape[-1] / cols))
1463-
plt.figure(figsize=figsize, dpi=dpi)
1487+
# Compute the number of subplot rows given the number of columns.
1488+
rows = 2 * int(np.ceil(len(mode_indices) / cols))
1489+
1490+
# Compute a grid of all subplot indices.
14641491
all_inds = np.arange(rows * cols).reshape(rows, cols)
1492+
1493+
# Get the subplot indices at which the mode averages will be plotted.
1494+
# Mode averages are plotted on the 1st, 3rd, 5th, ... rows of the plot.
14651495
avg_inds = all_inds[::2].flatten()
1496+
1497+
# Get the subplot indices at which the mode stds will be plotted.
1498+
# Mode stds are plotted on the 2nd, 4th, 6th, ... rows of the plot.
14661499
std_inds = all_inds[1::2].flatten()
14671500

1468-
for i, (mode, mode_std) in enumerate(zip(modes.T, modes_std.T)):
1501+
plt.figure(figsize=figsize, dpi=dpi)
1502+
1503+
for i, idx in enumerate(mode_indices):
1504+
mode = modes[:, idx]
1505+
mode_std = modes_std[:, idx]
1506+
14691507
# Plot the average mode.
14701508
plt.subplot(rows, cols, avg_inds[i] + 1)
1471-
plt.title(f"Mode {i + 1}")
1509+
if plot_complex_pairs:
1510+
plt.title(f"Mode {idx + 1}")
1511+
if not plot_complex_pairs:
1512+
plt.title(f"Mode {idx + 1}, {idx + 2}")
14721513
if len(modes_shape) == 1:
14731514
# Plot modes in 1-D.
14741515
plt.plot(x, mode.real, c="tab:blue")

0 commit comments

Comments
 (0)