Skip to content

Commit 5fd44b7

Browse files
committed
Added UQ tools
1 parent 7f4aa77 commit 5fd44b7

File tree

1 file changed

+166
-1
lines changed

1 file changed

+166
-1
lines changed

pydmd/bopdmd.py

Lines changed: 166 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import numpy as np
1919
from scipy.linalg import qr
2020
from scipy.sparse import csr_matrix
21+
import matplotlib.pyplot as plt
2122

2223
from .dmdbase import DMDBase
2324
from .dmdoperator import DMDOperator
@@ -1323,7 +1324,7 @@ def fit(self, X, t):
13231324
self._eig_constraints,
13241325
self._bag_warning,
13251326
self._bag_maxfail,
1326-
**self._varpro_opts_dict
1327+
**self._varpro_opts_dict,
13271328
)
13281329

13291330
# Define the snapshots that will be used for fitting.
@@ -1390,3 +1391,167 @@ def forecast(self, t):
13901391
]
13911392
)
13921393
return x
1394+
1395+
def plot_mode_uq(
1396+
self,
1397+
x=None,
1398+
y=None,
1399+
d=1,
1400+
modes_shape=None,
1401+
order="C",
1402+
cols=4,
1403+
figsize=None,
1404+
dpi=None,
1405+
):
1406+
"""
1407+
Plot BOP-DMD modes alongside their standard deviations.
1408+
1409+
:param x: Points along the 1st spatial dimension where data has
1410+
been collected.
1411+
:type x: np.ndarray or iterable
1412+
:param y: Points along the 2nd spatial dimension where data has
1413+
been collected. This parameter is only applicable when the data
1414+
snapshots are 2-D, which must be indicated with `modes_shape`.
1415+
:type y: np.ndarray or iterable
1416+
:param d: Number of delays applied to the data. If `d` is greater
1417+
than 1, then each plotted mode will be the average mode taken
1418+
across all `d` delays.
1419+
:type d: int
1420+
:param modes_shape: Shape of the modes. If not provided, the shape
1421+
is assumed to be the flattened space dim of the snapshot data.
1422+
Provide as width, height dimension.
1423+
:type modes_shape: iterable
1424+
:param order: Read the elements of snapshots using this index order,
1425+
and place the elements into the reshaped array using this index
1426+
order. It has to be the same used to store the snapshots.
1427+
:type order: {"C", "F", "A"}
1428+
:param cols: Number of columns to use for the subplot grid.
1429+
:type cols: int
1430+
:param figsize: Width, height in inches.
1431+
:type figsize: iterable
1432+
:param dpi: Figure resolution.
1433+
:type dpi: int
1434+
"""
1435+
if self.modes_std is None:
1436+
raise ValueError("No UQ metrics to plot.")
1437+
1438+
# By default, modes_shape is the flattened space dimension.
1439+
if modes_shape is None:
1440+
modes_shape = (len(self.snapshots) // d,)
1441+
1442+
# Order the modes and their standard deviations.
1443+
mode_order = np.argsort(-np.abs(self.amplitudes))
1444+
modes = self.modes[:, mode_order]
1445+
modes_std = self.modes_std[:, mode_order]
1446+
1447+
# Build the spatial grid for the mode plots.
1448+
if x is None:
1449+
x = np.arange(modes_shape[0])
1450+
if len(modes_shape) == 2:
1451+
if y is None:
1452+
y = np.arange(modes_shape[1])
1453+
ygrid, xgrid = np.meshgrid(y, x)
1454+
1455+
# Collapse the results across time-delays.
1456+
if d > 1:
1457+
nd, r = modes.shape
1458+
modes = np.average(modes.reshape(d, nd // d, r), axis=0)
1459+
modes_std = np.average(modes_std.reshape(d, nd // d, r), axis=0)
1460+
1461+
rows = int(np.ceil(modes.shape[-1] / cols))
1462+
fig, axes = plt.subplots(rows, cols, figsize=figsize, dpi=dpi)
1463+
avg_axes = [ax for axes_list in axes[::2] for ax in axes_list]
1464+
std_axes = [ax for axes_list in axes[1::2] for ax in axes_list]
1465+
1466+
for i, (ax_avg, ax_std, mode, mode_std) in enumerate(
1467+
zip(avg_axes, std_axes, modes.T, modes_std.T)
1468+
):
1469+
ax_avg.set_title(f"Mode {i + 1}")
1470+
ax_std.set_title("Mode Standard Deviation")
1471+
1472+
if len(modes_shape) == 1:
1473+
# Plot modes in 1-D.
1474+
ax_avg.plot(x, mode.real, c="tab:blue")
1475+
ax_std.plot(x, mode_std, c="tab:red")
1476+
else:
1477+
# Plot modes in 2-D.
1478+
im_avg = ax_avg.pcolormesh(
1479+
xgrid,
1480+
ygrid,
1481+
mode.reshape(*modes_shape, order=order).real,
1482+
cmap="viridis",
1483+
)
1484+
im_std = ax_std.pcolormesh(
1485+
xgrid,
1486+
ygrid,
1487+
mode_std.reshape(*modes_shape, order=order),
1488+
cmap="inferno",
1489+
)
1490+
fig.colorbar(im_avg, ax=ax_avg)
1491+
fig.colorbar(im_std, ax=ax_std)
1492+
1493+
plt.suptitle("DMD Modes")
1494+
plt.tight_layout()
1495+
plt.show()
1496+
1497+
def plot_eig_uq(
1498+
self,
1499+
eigs_true=None,
1500+
figsize=None,
1501+
dpi=None,
1502+
flip_axes=False,
1503+
):
1504+
"""
1505+
Plot BOP-DMD eigenvalues against 1 and 2 standard deviations.
1506+
1507+
:param eigs_true: True continuous-time eigenvalues, if known.
1508+
:type eigs_true: np.ndarray
1509+
:param figsize: Width, height in inches.
1510+
:type figsize: iterable
1511+
:param dpi: Figure resolution.
1512+
:type dpi: int
1513+
:param flip_axes: Whether or not to swap the real and imaginary axes
1514+
on the eigenvalue plot. If `True`, the real axis will be vertical
1515+
and the imaginary axis will be horizontal.
1516+
:type flip_axes: bool
1517+
"""
1518+
1519+
if self.eigenvalues_std is None:
1520+
raise ValueError("No UQ metrics to plot.")
1521+
1522+
fig, ax = plt.subplots(figsize=figsize, dpi=dpi)
1523+
plt.title("DMD Eigenvalues")
1524+
1525+
if flip_axes:
1526+
eigs = self.eigs.imag + 1j * self.eigs.real
1527+
plt.xlabel("$Im(\omega)$")
1528+
plt.ylabel("$Re(\omega)$")
1529+
else:
1530+
eigs = self.eigs
1531+
plt.xlabel("$Re(\omega)$")
1532+
plt.ylabel("$Im(\omega)$")
1533+
1534+
for e, std in zip(self.eigs, self.eigenvalues_std):
1535+
# Plot 2 standard deviations.
1536+
c_1 = plt.Circle((e.real, e.imag), 2 * std, color="b", alpha=0.2)
1537+
ax.add_patch(c_1)
1538+
# Plot 1 standard deviation.
1539+
c_2 = plt.Circle((e.real, e.imag), std, color="b", alpha=0.5)
1540+
ax.add_patch(c_2)
1541+
1542+
# Plot the average eigenvalues.
1543+
ax.plot(eigs.real, eigs.imag, "o", c="b", label="BOP-DMD")
1544+
1545+
# Plot the true eigenvalues if given.
1546+
if eigs_true is not None:
1547+
if flip_axes:
1548+
ax.plot(
1549+
eigs_true.imag, eigs_true.real, "x", c="k", label="Truth"
1550+
)
1551+
else:
1552+
ax.plot(
1553+
eigs_true.real, eigs_true.imag, "x", c="k", label="Truth"
1554+
)
1555+
1556+
plt.legend()
1557+
plt.show()

0 commit comments

Comments
 (0)