Skip to content

Commit 47a860c

Browse files
authored
Merge pull request PyDMD#506 from sichinaga/bopdmd-uq-tools
BOP-DMD UQ Tools
2 parents f2c5385 + 71026cf commit 47a860c

File tree

1 file changed

+233
-1
lines changed

1 file changed

+233
-1
lines changed

pydmd/bopdmd.py

Lines changed: 233 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import numpy as np
1919
from scipy.linalg import qr
2020
from scipy.sparse import csr_matrix
21+
import matplotlib.pyplot as plt
2122

2223
from .dmdbase import DMDBase
2324
from .dmdoperator import DMDOperator
@@ -1343,7 +1344,7 @@ def fit(self, X, t):
13431344
self._eig_constraints,
13441345
self._bag_warning,
13451346
self._bag_maxfail,
1346-
**self._varpro_opts_dict
1347+
**self._varpro_opts_dict,
13471348
)
13481349

13491350
# Define the snapshots that will be used for fitting.
@@ -1410,3 +1411,234 @@ def forecast(self, t):
14101411
]
14111412
)
14121413
return x
1414+
1415+
def plot_mode_uq(
1416+
self,
1417+
*,
1418+
x=None,
1419+
y=None,
1420+
d=1,
1421+
modes_shape=None,
1422+
order="C",
1423+
cols=4,
1424+
figsize=None,
1425+
dpi=None,
1426+
plot_modes=None,
1427+
plot_conjugate_pairs=True,
1428+
):
1429+
"""
1430+
Plot BOP-DMD modes alongside their standard deviations.
1431+
1432+
:param x: Points along the 1st spatial dimension where data has
1433+
been collected.
1434+
:type x: np.ndarray or iterable
1435+
:param y: Points along the 2nd spatial dimension where data has
1436+
been collected. This parameter is only applicable when the data
1437+
snapshots are 2-D, which must be indicated with `modes_shape`.
1438+
:type y: np.ndarray or iterable
1439+
:param d: Number of delays applied to the data. If `d` is greater
1440+
than 1, then each plotted mode will be the average mode taken
1441+
across all `d` delays.
1442+
:type d: int
1443+
:param modes_shape: Shape of the modes. If not provided, the shape
1444+
is assumed to be the flattened space dim of the snapshot data.
1445+
Provide as width, height dimension.
1446+
:type modes_shape: iterable
1447+
:param order: Read the elements of snapshots using this index order,
1448+
and place the elements into the reshaped array using this index
1449+
order. It has to be the same used to store the snapshots.
1450+
:type order: {"C", "F", "A"}
1451+
:param cols: Number of columns to use for the subplot grid.
1452+
:type cols: int
1453+
:param figsize: Width, height in inches.
1454+
:type figsize: iterable
1455+
:param dpi: Figure resolution.
1456+
:type dpi: int
1457+
:param plot_modes: Number of leading modes to plot, or the indices of
1458+
the modes to plot. If `None`, then all available modes are plotted.
1459+
Note that if this parameter is given as a list of indices, it will
1460+
override the `plot_complex_pair` parameter.
1461+
:type plot_modes: int or iterable
1462+
:param plot_conjugate_pairs: Whether or not to omit one of the modes
1463+
that correspond with a complex conjugate pair of eigenvalues.
1464+
:type plot_conjugate_pairs: bool
1465+
"""
1466+
if self.modes_std is None:
1467+
raise ValueError("No UQ metrics to plot.")
1468+
1469+
# Get the indices of the modes to plot.
1470+
nd, r = self.modes.shape
1471+
if plot_modes is None or isinstance(plot_modes, int):
1472+
mode_indices = np.arange(r)
1473+
if not plot_conjugate_pairs:
1474+
if r % 2 == 0:
1475+
mode_indices = mode_indices[::2]
1476+
else:
1477+
mode_indices = np.concatenate([(0,), mode_indices[1::2]])
1478+
if isinstance(plot_modes, int):
1479+
mode_indices = mode_indices[:plot_modes]
1480+
else:
1481+
mode_indices = plot_modes
1482+
plot_conjugate_pairs = True
1483+
1484+
# By default, modes_shape is the flattened space dimension.
1485+
if modes_shape is None:
1486+
modes_shape = (len(self.snapshots) // d,)
1487+
1488+
# Order the modes and their standard deviations.
1489+
mode_order = np.argsort(-np.abs(self.amplitudes))
1490+
modes = self.modes[:, mode_order]
1491+
modes_std = self.modes_std[:, mode_order]
1492+
1493+
# Build the spatial grid for the mode plots.
1494+
if x is None:
1495+
x = np.arange(modes_shape[0])
1496+
if len(modes_shape) == 2:
1497+
if y is None:
1498+
y = np.arange(modes_shape[1])
1499+
ygrid, xgrid = np.meshgrid(y, x)
1500+
1501+
# Collapse the results across time-delays.
1502+
if d > 1:
1503+
modes = np.average(modes.reshape(d, nd // d, r), axis=0)
1504+
modes_std = np.average(modes_std.reshape(d, nd // d, r), axis=0)
1505+
1506+
# Define the subplot grid.
1507+
# Compute the number of subplot rows given the number of columns.
1508+
rows = 2 * int(np.ceil(len(mode_indices) / cols))
1509+
1510+
# Compute a grid of all subplot indices.
1511+
all_inds = np.arange(rows * cols).reshape(rows, cols)
1512+
1513+
# Get the subplot indices at which the mode averages will be plotted.
1514+
# Mode averages are plotted on the 1st, 3rd, 5th, ... rows of the plot.
1515+
avg_inds = all_inds[::2].flatten()
1516+
1517+
# Get the subplot indices at which the mode stds will be plotted.
1518+
# Mode stds are plotted on the 2nd, 4th, 6th, ... rows of the plot.
1519+
std_inds = all_inds[1::2].flatten()
1520+
1521+
plt.figure(figsize=figsize, dpi=dpi)
1522+
1523+
for i, idx in enumerate(mode_indices):
1524+
mode = modes[:, idx]
1525+
mode_std = modes_std[:, idx]
1526+
1527+
# Plot the average mode.
1528+
plt.subplot(rows, cols, avg_inds[i] + 1)
1529+
if plot_conjugate_pairs or (r % 2 == 1 and i == 0):
1530+
plt.title(f"Mode {idx + 1}")
1531+
else:
1532+
plt.title(f"Modes {idx + 1},{idx + 2}")
1533+
if len(modes_shape) == 1:
1534+
# Plot modes in 1-D.
1535+
plt.plot(x, mode.real, c="tab:blue")
1536+
else:
1537+
# Plot modes in 2-D.
1538+
plt.pcolormesh(
1539+
xgrid,
1540+
ygrid,
1541+
mode.reshape(*modes_shape, order=order).real,
1542+
cmap="viridis",
1543+
)
1544+
plt.colorbar()
1545+
1546+
# Plot the mode standard deviation.
1547+
plt.subplot(rows, cols, std_inds[i] + 1)
1548+
plt.title("Mode Standard Deviation")
1549+
if len(modes_shape) == 1:
1550+
# Plot modes in 1-D.
1551+
plt.plot(x, mode_std, c="tab:red")
1552+
else:
1553+
# Plot modes in 2-D.
1554+
plt.pcolormesh(
1555+
xgrid,
1556+
ygrid,
1557+
mode_std.reshape(*modes_shape, order=order),
1558+
cmap="inferno",
1559+
)
1560+
plt.colorbar()
1561+
1562+
plt.suptitle("DMD Modes")
1563+
plt.tight_layout()
1564+
plt.show()
1565+
1566+
def plot_eig_uq(
1567+
self,
1568+
eigs_true=None,
1569+
xlim=None,
1570+
ylim=None,
1571+
figsize=None,
1572+
dpi=None,
1573+
flip_axes=False,
1574+
draw_axes=False,
1575+
):
1576+
"""
1577+
Plot BOP-DMD eigenvalues against 1 and 2 standard deviations.
1578+
1579+
:param eigs_true: True continuous-time eigenvalues, if known.
1580+
:type eigs_true: np.ndarray or iterable
1581+
:param xlim: Desired limits for the x-axis.
1582+
:type xlim: iterable
1583+
:param ylim: Desired limits for the y-axis.
1584+
:type ylim: iterable
1585+
:param figsize: Width, height in inches.
1586+
:type figsize: iterable
1587+
:param dpi: Figure resolution.
1588+
:type dpi: int
1589+
:param flip_axes: Whether or not to swap the real and imaginary axes
1590+
on the eigenvalue plot. If `True`, the real axis will be vertical
1591+
and the imaginary axis will be horizontal.
1592+
:type flip_axes: bool
1593+
:param draw_axes: Whether or not to draw the real and imaginary axes.
1594+
:type draw_axes: bool
1595+
"""
1596+
1597+
if self.eigenvalues_std is None:
1598+
raise ValueError("No UQ metrics to plot.")
1599+
1600+
if eigs_true is not None:
1601+
eigs_true = np.array(eigs_true)
1602+
1603+
fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
1604+
plt.title("DMD Eigenvalues")
1605+
1606+
if draw_axes:
1607+
ax.axhline(y=0, c="k", lw=1)
1608+
ax.axvline(x=0, c="k", lw=1)
1609+
1610+
if flip_axes:
1611+
eigs = self.eigs.imag + 1j * self.eigs.real
1612+
plt.xlabel("$Im(\omega)$")
1613+
plt.ylabel("$Re(\omega)$")
1614+
1615+
if eigs_true is not None:
1616+
eigs_true = eigs_true.imag + 1j * eigs_true.real
1617+
1618+
else:
1619+
eigs = self.eigs
1620+
plt.xlabel("$Re(\omega)$")
1621+
plt.ylabel("$Im(\omega)$")
1622+
1623+
for e, std in zip(eigs, self.eigenvalues_std):
1624+
# Plot 2 standard deviations.
1625+
c_1 = plt.Circle((e.real, e.imag), 2 * std, color="b", alpha=0.2)
1626+
ax.add_patch(c_1)
1627+
# Plot 1 standard deviation.
1628+
c_2 = plt.Circle((e.real, e.imag), std, color="b", alpha=0.5)
1629+
ax.add_patch(c_2)
1630+
1631+
# Plot the average eigenvalues.
1632+
ax.plot(eigs.real, eigs.imag, "o", c="b", label="BOP-DMD")
1633+
1634+
# Plot the true eigenvalues if given.
1635+
if eigs_true is not None:
1636+
ax.plot(eigs_true.real, eigs_true.imag, "x", c="k", label="Truth")
1637+
1638+
if xlim is not None:
1639+
ax.set_xlim(xlim)
1640+
if ylim is not None:
1641+
ax.set_ylim(ylim)
1642+
1643+
plt.legend()
1644+
plt.show()

0 commit comments

Comments
 (0)