|
18 | 18 | import numpy as np |
19 | 19 | from scipy.linalg import qr |
20 | 20 | from scipy.sparse import csr_matrix |
| 21 | +import matplotlib.pyplot as plt |
21 | 22 |
|
22 | 23 | from .dmdbase import DMDBase |
23 | 24 | from .dmdoperator import DMDOperator |
@@ -1343,7 +1344,7 @@ def fit(self, X, t): |
1343 | 1344 | self._eig_constraints, |
1344 | 1345 | self._bag_warning, |
1345 | 1346 | self._bag_maxfail, |
1346 | | - **self._varpro_opts_dict |
| 1347 | + **self._varpro_opts_dict, |
1347 | 1348 | ) |
1348 | 1349 |
|
1349 | 1350 | # Define the snapshots that will be used for fitting. |
@@ -1410,3 +1411,234 @@ def forecast(self, t): |
1410 | 1411 | ] |
1411 | 1412 | ) |
1412 | 1413 | return x |
| 1414 | + |
| 1415 | + def plot_mode_uq( |
| 1416 | + self, |
| 1417 | + *, |
| 1418 | + x=None, |
| 1419 | + y=None, |
| 1420 | + d=1, |
| 1421 | + modes_shape=None, |
| 1422 | + order="C", |
| 1423 | + cols=4, |
| 1424 | + figsize=None, |
| 1425 | + dpi=None, |
| 1426 | + plot_modes=None, |
| 1427 | + plot_conjugate_pairs=True, |
| 1428 | + ): |
| 1429 | + """ |
| 1430 | + Plot BOP-DMD modes alongside their standard deviations. |
| 1431 | +
|
| 1432 | + :param x: Points along the 1st spatial dimension where data has |
| 1433 | + been collected. |
| 1434 | + :type x: np.ndarray or iterable |
| 1435 | + :param y: Points along the 2nd spatial dimension where data has |
| 1436 | + been collected. This parameter is only applicable when the data |
| 1437 | + snapshots are 2-D, which must be indicated with `modes_shape`. |
| 1438 | + :type y: np.ndarray or iterable |
| 1439 | + :param d: Number of delays applied to the data. If `d` is greater |
| 1440 | + than 1, then each plotted mode will be the average mode taken |
| 1441 | + across all `d` delays. |
| 1442 | + :type d: int |
| 1443 | + :param modes_shape: Shape of the modes. If not provided, the shape |
| 1444 | + is assumed to be the flattened space dim of the snapshot data. |
| 1445 | + Provide as width, height dimension. |
| 1446 | + :type modes_shape: iterable |
| 1447 | + :param order: Read the elements of snapshots using this index order, |
| 1448 | + and place the elements into the reshaped array using this index |
| 1449 | + order. It has to be the same used to store the snapshots. |
| 1450 | + :type order: {"C", "F", "A"} |
| 1451 | + :param cols: Number of columns to use for the subplot grid. |
| 1452 | + :type cols: int |
| 1453 | + :param figsize: Width, height in inches. |
| 1454 | + :type figsize: iterable |
| 1455 | + :param dpi: Figure resolution. |
| 1456 | + :type dpi: int |
| 1457 | + :param plot_modes: Number of leading modes to plot, or the indices of |
| 1458 | + the modes to plot. If `None`, then all available modes are plotted. |
| 1459 | + Note that if this parameter is given as a list of indices, it will |
| 1460 | + override the `plot_complex_pair` parameter. |
| 1461 | + :type plot_modes: int or iterable |
| 1462 | + :param plot_conjugate_pairs: Whether or not to omit one of the modes |
| 1463 | + that correspond with a complex conjugate pair of eigenvalues. |
| 1464 | + :type plot_conjugate_pairs: bool |
| 1465 | + """ |
| 1466 | + if self.modes_std is None: |
| 1467 | + raise ValueError("No UQ metrics to plot.") |
| 1468 | + |
| 1469 | + # Get the indices of the modes to plot. |
| 1470 | + nd, r = self.modes.shape |
| 1471 | + if plot_modes is None or isinstance(plot_modes, int): |
| 1472 | + mode_indices = np.arange(r) |
| 1473 | + if not plot_conjugate_pairs: |
| 1474 | + if r % 2 == 0: |
| 1475 | + mode_indices = mode_indices[::2] |
| 1476 | + else: |
| 1477 | + mode_indices = np.concatenate([(0,), mode_indices[1::2]]) |
| 1478 | + if isinstance(plot_modes, int): |
| 1479 | + mode_indices = mode_indices[:plot_modes] |
| 1480 | + else: |
| 1481 | + mode_indices = plot_modes |
| 1482 | + plot_conjugate_pairs = True |
| 1483 | + |
| 1484 | + # By default, modes_shape is the flattened space dimension. |
| 1485 | + if modes_shape is None: |
| 1486 | + modes_shape = (len(self.snapshots) // d,) |
| 1487 | + |
| 1488 | + # Order the modes and their standard deviations. |
| 1489 | + mode_order = np.argsort(-np.abs(self.amplitudes)) |
| 1490 | + modes = self.modes[:, mode_order] |
| 1491 | + modes_std = self.modes_std[:, mode_order] |
| 1492 | + |
| 1493 | + # Build the spatial grid for the mode plots. |
| 1494 | + if x is None: |
| 1495 | + x = np.arange(modes_shape[0]) |
| 1496 | + if len(modes_shape) == 2: |
| 1497 | + if y is None: |
| 1498 | + y = np.arange(modes_shape[1]) |
| 1499 | + ygrid, xgrid = np.meshgrid(y, x) |
| 1500 | + |
| 1501 | + # Collapse the results across time-delays. |
| 1502 | + if d > 1: |
| 1503 | + modes = np.average(modes.reshape(d, nd // d, r), axis=0) |
| 1504 | + modes_std = np.average(modes_std.reshape(d, nd // d, r), axis=0) |
| 1505 | + |
| 1506 | + # Define the subplot grid. |
| 1507 | + # Compute the number of subplot rows given the number of columns. |
| 1508 | + rows = 2 * int(np.ceil(len(mode_indices) / cols)) |
| 1509 | + |
| 1510 | + # Compute a grid of all subplot indices. |
| 1511 | + all_inds = np.arange(rows * cols).reshape(rows, cols) |
| 1512 | + |
| 1513 | + # Get the subplot indices at which the mode averages will be plotted. |
| 1514 | + # Mode averages are plotted on the 1st, 3rd, 5th, ... rows of the plot. |
| 1515 | + avg_inds = all_inds[::2].flatten() |
| 1516 | + |
| 1517 | + # Get the subplot indices at which the mode stds will be plotted. |
| 1518 | + # Mode stds are plotted on the 2nd, 4th, 6th, ... rows of the plot. |
| 1519 | + std_inds = all_inds[1::2].flatten() |
| 1520 | + |
| 1521 | + plt.figure(figsize=figsize, dpi=dpi) |
| 1522 | + |
| 1523 | + for i, idx in enumerate(mode_indices): |
| 1524 | + mode = modes[:, idx] |
| 1525 | + mode_std = modes_std[:, idx] |
| 1526 | + |
| 1527 | + # Plot the average mode. |
| 1528 | + plt.subplot(rows, cols, avg_inds[i] + 1) |
| 1529 | + if plot_conjugate_pairs or (r % 2 == 1 and i == 0): |
| 1530 | + plt.title(f"Mode {idx + 1}") |
| 1531 | + else: |
| 1532 | + plt.title(f"Modes {idx + 1},{idx + 2}") |
| 1533 | + if len(modes_shape) == 1: |
| 1534 | + # Plot modes in 1-D. |
| 1535 | + plt.plot(x, mode.real, c="tab:blue") |
| 1536 | + else: |
| 1537 | + # Plot modes in 2-D. |
| 1538 | + plt.pcolormesh( |
| 1539 | + xgrid, |
| 1540 | + ygrid, |
| 1541 | + mode.reshape(*modes_shape, order=order).real, |
| 1542 | + cmap="viridis", |
| 1543 | + ) |
| 1544 | + plt.colorbar() |
| 1545 | + |
| 1546 | + # Plot the mode standard deviation. |
| 1547 | + plt.subplot(rows, cols, std_inds[i] + 1) |
| 1548 | + plt.title("Mode Standard Deviation") |
| 1549 | + if len(modes_shape) == 1: |
| 1550 | + # Plot modes in 1-D. |
| 1551 | + plt.plot(x, mode_std, c="tab:red") |
| 1552 | + else: |
| 1553 | + # Plot modes in 2-D. |
| 1554 | + plt.pcolormesh( |
| 1555 | + xgrid, |
| 1556 | + ygrid, |
| 1557 | + mode_std.reshape(*modes_shape, order=order), |
| 1558 | + cmap="inferno", |
| 1559 | + ) |
| 1560 | + plt.colorbar() |
| 1561 | + |
| 1562 | + plt.suptitle("DMD Modes") |
| 1563 | + plt.tight_layout() |
| 1564 | + plt.show() |
| 1565 | + |
| 1566 | + def plot_eig_uq( |
| 1567 | + self, |
| 1568 | + eigs_true=None, |
| 1569 | + xlim=None, |
| 1570 | + ylim=None, |
| 1571 | + figsize=None, |
| 1572 | + dpi=None, |
| 1573 | + flip_axes=False, |
| 1574 | + draw_axes=False, |
| 1575 | + ): |
| 1576 | + """ |
| 1577 | + Plot BOP-DMD eigenvalues against 1 and 2 standard deviations. |
| 1578 | +
|
| 1579 | + :param eigs_true: True continuous-time eigenvalues, if known. |
| 1580 | + :type eigs_true: np.ndarray or iterable |
| 1581 | + :param xlim: Desired limits for the x-axis. |
| 1582 | + :type xlim: iterable |
| 1583 | + :param ylim: Desired limits for the y-axis. |
| 1584 | + :type ylim: iterable |
| 1585 | + :param figsize: Width, height in inches. |
| 1586 | + :type figsize: iterable |
| 1587 | + :param dpi: Figure resolution. |
| 1588 | + :type dpi: int |
| 1589 | + :param flip_axes: Whether or not to swap the real and imaginary axes |
| 1590 | + on the eigenvalue plot. If `True`, the real axis will be vertical |
| 1591 | + and the imaginary axis will be horizontal. |
| 1592 | + :type flip_axes: bool |
| 1593 | + :param draw_axes: Whether or not to draw the real and imaginary axes. |
| 1594 | + :type draw_axes: bool |
| 1595 | + """ |
| 1596 | + |
| 1597 | + if self.eigenvalues_std is None: |
| 1598 | + raise ValueError("No UQ metrics to plot.") |
| 1599 | + |
| 1600 | + if eigs_true is not None: |
| 1601 | + eigs_true = np.array(eigs_true) |
| 1602 | + |
| 1603 | + fig, ax = plt.subplots(figsize=figsize, dpi=dpi) |
| 1604 | + plt.title("DMD Eigenvalues") |
| 1605 | + |
| 1606 | + if draw_axes: |
| 1607 | + ax.axhline(y=0, c="k", lw=1) |
| 1608 | + ax.axvline(x=0, c="k", lw=1) |
| 1609 | + |
| 1610 | + if flip_axes: |
| 1611 | + eigs = self.eigs.imag + 1j * self.eigs.real |
| 1612 | + plt.xlabel("$Im(\omega)$") |
| 1613 | + plt.ylabel("$Re(\omega)$") |
| 1614 | + |
| 1615 | + if eigs_true is not None: |
| 1616 | + eigs_true = eigs_true.imag + 1j * eigs_true.real |
| 1617 | + |
| 1618 | + else: |
| 1619 | + eigs = self.eigs |
| 1620 | + plt.xlabel("$Re(\omega)$") |
| 1621 | + plt.ylabel("$Im(\omega)$") |
| 1622 | + |
| 1623 | + for e, std in zip(eigs, self.eigenvalues_std): |
| 1624 | + # Plot 2 standard deviations. |
| 1625 | + c_1 = plt.Circle((e.real, e.imag), 2 * std, color="b", alpha=0.2) |
| 1626 | + ax.add_patch(c_1) |
| 1627 | + # Plot 1 standard deviation. |
| 1628 | + c_2 = plt.Circle((e.real, e.imag), std, color="b", alpha=0.5) |
| 1629 | + ax.add_patch(c_2) |
| 1630 | + |
| 1631 | + # Plot the average eigenvalues. |
| 1632 | + ax.plot(eigs.real, eigs.imag, "o", c="b", label="BOP-DMD") |
| 1633 | + |
| 1634 | + # Plot the true eigenvalues if given. |
| 1635 | + if eigs_true is not None: |
| 1636 | + ax.plot(eigs_true.real, eigs_true.imag, "x", c="k", label="Truth") |
| 1637 | + |
| 1638 | + if xlim is not None: |
| 1639 | + ax.set_xlim(xlim) |
| 1640 | + if ylim is not None: |
| 1641 | + ax.set_ylim(ylim) |
| 1642 | + |
| 1643 | + plt.legend() |
| 1644 | + plt.show() |
0 commit comments