Skip to content

Commit 6651550

Browse files
committed
This function got really goofy
1 parent 7cc01e4 commit 6651550

File tree

1 file changed

+111
-110
lines changed

1 file changed

+111
-110
lines changed

pydmd/plotter.py

Lines changed: 111 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import matplotlib.pyplot as plt
99
import numpy as np
1010
from mpl_toolkits.axes_grid1 import make_axes_locatable
11+
from matplotlib.patches import Patch
1112

1213
from pydmd import MrDMD
1314

@@ -535,49 +536,50 @@ def plot_snapshots_2D(
535536
def plot_summary(
536537
dmd,
537538
*,
539+
x=None,
538540
t=None,
539541
d=1,
540542
continuous=False,
541543
snapshots_shape=None,
542-
index_modes=None,
544+
index_modes=(0, 1, 2),
543545
filename=None,
544546
order="C",
545547
figsize=(12, 8),
546548
dpi=200,
547549
tight_layout_kwargs=None,
548550
main_colors=("r", "b", "g"),
549-
mode_color="k",
550-
mode_cmap="bwr",
551-
sval_color="tab:orange",
552-
dynamics_color="tab:blue",
551+
imshow_kwargs=None,
553552
sval_ms=8,
554-
max_eig_ms=10,
553+
max_eig_ms=12,
555554
max_sval_plot=50,
556555
title_fontsize=14,
557556
label_fontsize=12,
558557
plot_semilogy=False,
559-
remove_cmap_ticks=False,
560558
):
561559
"""
562560
Generate a 3 x 3 summarizing plot that contains the following components:
563561
- the singular value spectrum of the data
564562
- the discrete-time and continuous-time DMD eigenvalues
565-
- the three DMD modes specified by the `index_modes` parameter
566-
- the dynamics corresponding with each plotted mode
567-
Eigenvalues, modes, and dynamics are ordered according to the magnitude of
568-
their corresponding amplitude value. Singular values and eigenvalues that
569-
are associated with plotted modes and dynamics are also highlighted.
570-
571-
:param dmd: DMD instance.
563+
- the DMD modes specified by the `index_modes` parameter
564+
- the time dynamics that correspond with each plotted mode
565+
The number of singular values used for the DMD fit are highlighted.
566+
All eigenvalues, modes, and dynamics are sorted according to the magnitude
567+
of their corresponding amplitude value, i.e. their significance in the fit.
568+
Correspondence between eigenvalues, modes, and dynamics is indicated via
569+
color coordination.
570+
571+
:param dmd: fitted DMD instance.
572572
:type dmd: pydmd.DMDBase
573-
:param t: the input time vector or uniform time-step between snapshots.
574-
Note that the time-step must be accurate in order to visualize accurate
575-
discrete and continuous-time eigenvalues, as well as accurate times of
576-
the dynamics. For non-`BOPDMD` models, times of data collection must be
577-
uniformly-spaced, and if not provided, TimeDict information stored in
578-
the provided DMD instance is used instead. This parameter is ignored if
579-
an instance of `BOPDMD` is provided.
580-
:type t: {numpy.ndarray, list} or {int, float}
573+
:param x: The points in space where the data has been collected. Note that
574+
this parameter is currently only used for plotting modes that are 1-D.
575+
:type x: np.ndarray or iterable
576+
:param t: The times of data collection, or the time-step between snapshots.
577+
Note that time information must be accurate in order to accurately
578+
visualize eigenvalues and times of the dynamics. For non-`BOPDMD`
579+
models, the entries of t are assumed to be uniformly-spaced, and if
580+
not provided, TimeDict information is used. This parameter is ignored
581+
if an instance of `BOPDMD` is provided.
582+
:type t: {numpy.ndarray, iterable} or {int, float}
581583
:param d: Number of delays applied to the data passed to the DMD instance.
582584
If `d` is greater than 1, then each plotted mode will be the average
583585
mode taken across all `d` delays.
@@ -591,10 +593,11 @@ def plot_summary(
591593
:param snapshots_shape: Shape of the snapshots. If not provided, the shape
592594
of the snapshots and modes is assumed to be the flattened space dim of
593595
the snapshot data.
594-
:type snapshots_shape: tuple(int, int)
595-
:param index_modes: A list of the indices of the modes to plot. By default,
596-
the first three leading modes are plotted.
597-
:type index_modes: list
596+
:type snapshots_shape: iterable
597+
:param index_modes: Indices of the modes to plot after they have been
598+
sorted based on significance. At most three may be provided.
599+
By default, the first three leading modes are plotted.
600+
:type index_modes: iterable
598601
:param filename: If specified, the plot is saved at `filename`.
599602
:type filename: str
600603
:param order: Read the elements of snapshots using this index order,
@@ -610,27 +613,25 @@ def plot_summary(
610613
order if a is Fortran contiguous in memory, C-like order otherwise.
611614
"C" is used by default.
612615
:type order: {"C", "F", "A"}
613-
:param figsize: Tuple in inches defining the figure size.
614-
:type figsize: tuple(int, int)
616+
:param figsize: Width, height in inches.
617+
:type figsize: iterable
615618
:param dpi: Figure resolution.
616619
:type dpi: int
617620
:param tight_layout_kwargs: Optional dictionary of
618-
`matplotlib.pyplot.tight_layout()` parameters.
621+
`matplotlib.pyplot.tight_layout` parameters.
619622
:type tight_layout_kwargs: dict
620-
:param main_colors: Tuple of strings defining the colors used to denote
621-
eigenvalue, mode, dynamics associations.
622-
:type main_colors: tuple(str, str, str)
623-
:param mode_color: Color used to plot the modes, if modes are 1D.
624-
:type mode_color: str
625-
:param mode_cmap: Colormap used to plot the modes, if modes are 2D.
626-
:type mode_cmap: str
627-
:param dynamics_color: Color used to plot the dynamics.
628-
:type dynamics_color: str
623+
:param main_colors: Strings defining the colors used to denote eigenvalue,
624+
mode, dynamics associations.
625+
:type main_colors: iterable
626+
:param imshow_kwargs: Optional dictionary of `matplotlib.pyplot.imshow`
627+
parameters. Use this dictionary to re-define the parameters of 2-D
628+
mode plots.
629+
:type imshow_kwargs: dict
629630
:param sval_ms: Marker size of all singular values.
630631
:type sval_ms: int
631632
:param max_eig_ms: Marker size of the most prominent eigenvalue. The marker
632633
sizes of all other eigenvalues are then scaled according to eigenvalue
633-
prominence.
634+
significance.
634635
:type max_eig_ms: int
635636
:param max_sval_plot: Maximum number of singular values to plot.
636637
:type max_sval_plot: int
@@ -641,9 +642,6 @@ def plot_summary(
641642
:param plot_semilogy: Whether or not to plot the singular values on a
642643
semilogy plot. If `True`, a semilogy plot is used.
643644
:type plot_semilogy: bool
644-
:param remove_cmap_ticks: Whether or not to include the ticks on 2D mode
645-
plots. If `True`, ticks are removed from all 2D mode plots.
646-
:type remove_cmap_ticks: bool
647645
"""
648646

649647
# This plotting method is inappropriate for plotting HAVOK results.
@@ -660,28 +658,27 @@ def plot_summary(
660658
# By default, snapshots_shape is the flattened space dimension.
661659
if snapshots_shape is None:
662660
snapshots_shape = (len(dmd.snapshots),)
663-
# Only 2D tuples are admissible for snapshots_shape.
664-
elif not isinstance(snapshots_shape, tuple) or len(snapshots_shape) != 2:
665-
raise ValueError("snapshots_shape must be None or a 2D tuple.")
661+
# If provided, snapshots_shape must contain 2 entires.
662+
elif len(snapshots_shape) != 2:
663+
raise ValueError("snapshots_shape must be None or 2D.")
666664

667665
# Get the actual rank used for the DMD fit.
668666
rank = len(dmd.eigs)
669667

670668
# Override index_modes if there are less than 3 modes available.
671669
if rank < 3:
672670
warnings.warn(
673-
"Provided dmd model has less than 3 modes."
674-
"Plotting all available modes."
671+
"Provided DMD model has less than 3 modes."
672+
"Plotting all available modes..."
675673
)
676-
index_modes = list(range(rank))
677-
# By default, we plot the 3 leading modes and their dynamics.
678-
elif index_modes is None:
679-
index_modes = list(range(3))
680-
# index_modes was provided - check its type and its length.
681-
elif not isinstance(index_modes, list) or len(index_modes) > 3:
682-
raise ValueError("index_modes must be a list of length at most 3.")
674+
index_modes = np.arange(rank)
675+
676+
# Check the length of index_modes.
677+
if len(index_modes) > 3:
678+
raise ValueError("index_modes must have a length of at most 3.")
679+
683680
# Indices cannot go past the total number of available or plottable modes.
684-
elif np.any(np.array(index_modes) >= min(rank, max_sval_plot)):
681+
if np.any(np.array(index_modes) >= min(rank, max_sval_plot)):
685682
raise ValueError(
686683
f"Cannot view past mode {min(rank, max_sval_plot)}."
687684
)
@@ -694,6 +691,8 @@ def plot_summary(
694691
lead_amplitudes = np.abs(dmd.amplitudes[mode_order])
695692

696693
# Get time information for eigenvalue conversions.
694+
# The decisions that we make here depend on if we're dealing
695+
# with a BOPDMD model or any other type of DMD model.
697696
if isinstance(dmd, BOPDMD) or (
698697
isinstance(dmd, PrePostProcessingDMD)
699698
and isinstance(dmd.pre_post_processed_dmd, BOPDMD)
@@ -717,7 +716,7 @@ def plot_summary(
717716
if isinstance(t, (int, float)):
718717
time = np.arange(dmd.snapshots.shape[-1]) * t
719718
dt = t
720-
elif isinstance(t, (np.ndarray, list)):
719+
elif t is not None:
721720
# Note: assumes uniform spacing in the provided time vector.
722721
time = np.squeeze(np.array(t))
723722
dt = time[1] - time[0]
@@ -759,101 +758,103 @@ def plot_summary(
759758
s = np.linalg.svd(snp, full_matrices=False, compute_uv=False)
760759
# Compute the percent of data variance captured by each singular value.
761760
s_var = s * (100 / np.sum(s))
761+
s_var = s_var[:max_sval_plot]
762762

763763
# Generate the summarizing plot.
764764
fig, (eig_axes, mode_axes, dynamics_axes) = plt.subplots(
765765
3, 3, figsize=figsize, dpi=dpi
766766
)
767767

768768
# PLOT 1: Plot the singular value spectrum.
769-
s_var_plot = s_var[:max_sval_plot]
770769
eig_axes[0].set_title("Singular Values", fontsize=title_fontsize)
771770
eig_axes[0].set_ylabel("% variance", fontsize=label_fontsize)
772-
s_t = np.arange(len(s_var_plot)) + 1
773-
eig_axes[0].plot(s_t, s_var_plot, "o", c="gray", ms=sval_ms, mec="k")
771+
s_t = np.arange(len(s_var)) + 1
772+
eig_axes[0].plot(s_t, s_var, "o", c="gray", ms=sval_ms, mec="k")
774773
eig_axes[0].plot(
775-
s_t[:rank], s_var_plot[:rank], "o", c=sval_color, ms=sval_ms, mec="k"
774+
s_t[:rank], s_var[:rank], "o", c="tab:orange", ms=sval_ms, mec="k"
775+
)
776+
eig_axes[0].legend(
777+
handles=[Patch(facecolor="tab:orange", label="Rank of fit")]
776778
)
777-
778-
# for i, idx in enumerate(index_modes):
779-
# eig_axes[0].plot(
780-
# s_t[idx],
781-
# s_var_plot[idx],
782-
# "o",
783-
# c=main_colors[i],
784-
# ms=sval_ms,
785-
# mec="k",
786-
# )
787-
788779
if plot_semilogy:
789780
eig_axes[0].semilogy()
790781

791782
# PLOTS 2-3: Plot the eigenvalues (discrete-time and continuous-time).
792-
793-
# # Scale marker sizes to reflect the amount of variance captured.
794-
# ms_vals = max_eig_ms * np.sqrt(s_var / s_var[0])
795-
796783
# Scale marker sizes to reflect their associated amplitude.
797784
ms_vals = max_eig_ms * np.sqrt(lead_amplitudes / lead_amplitudes[0])
798785

799-
for i, (ax, eigs) in enumerate(zip(eig_axes[1:], [disc_eigs, cont_eigs])):
800-
# Plot the complex plane axes.
801-
ax.axvline(x=0, c="k", lw=1)
802-
ax.axhline(y=0, c="k", lw=1)
803-
ax.axis("equal")
804-
# PLOT 2: Plot the discrete-time eigenvalues on the unit circle.
805-
if i == 0:
806-
ax.set_title("Discrete-time Eigenvalues", fontsize=title_fontsize)
807-
t = np.linspace(0, 2 * np.pi, 100)
808-
ax.plot(np.cos(t), np.sin(t), c="tab:blue", ls="--")
809-
ax.set_xlabel(r"$Re(\lambda)$", fontsize=label_fontsize)
810-
ax.set_ylabel(r"$Im(\lambda)$", fontsize=label_fontsize)
811-
# PLOT 3: Plot the continuous-time eigenvalues.
812-
else:
813-
ax.set_title("Continuous-time Eigenvalues", fontsize=title_fontsize)
814-
ax.set_xlabel(r"$Im(\omega)$", fontsize=label_fontsize)
815-
ax.set_ylabel(r"$Re(\omega)$", fontsize=label_fontsize)
816-
# Plot the eigenvalues (discrete or continuous).
817-
if eigs is not None:
818-
for idx, eig in enumerate(eigs):
819-
if idx in index_modes:
820-
color = main_colors[index_modes.index(idx)]
821-
else:
822-
color = "gray"
823-
if i == 0:
824-
ax.plot(eig.real, eig.imag, "o", c=color, ms=ms_vals[idx])
825-
else:
826-
ax.plot(eig.imag, eig.real, "o", c=color, ms=ms_vals[idx])
786+
# PLOT 2: Plot the discrete-time eigenvalues on the unit circle.
787+
# Plot the complex plane axes.
788+
eig_axes[1].axvline(x=0, c="k", lw=1)
789+
eig_axes[1].axhline(y=0, c="k", lw=1)
790+
eig_axes[1].axis("equal")
791+
# Plot the unit circle.
792+
eig_axes[1].set_title("Discrete-time Eigenvalues", fontsize=title_fontsize)
793+
t = np.linspace(0, 2 * np.pi, 100)
794+
eig_axes[1].plot(np.cos(t), np.sin(t), c="tab:blue", ls="--")
795+
eig_axes[1].set_xlabel(r"$Re(\lambda)$", fontsize=label_fontsize)
796+
eig_axes[1].set_ylabel(r"$Im(\lambda)$", fontsize=label_fontsize)
797+
# Plot the eigenvalues.
798+
if disc_eigs is not None:
799+
for idx, eig in enumerate(disc_eigs):
800+
if idx in index_modes:
801+
color = main_colors[index_modes.index(idx)]
802+
else:
803+
color = "tab:orange"
804+
ax.plot(eig.real, eig.imag, "o", c=color, ms=ms_vals[idx], mec="k")
805+
806+
# PLOT 3: Plot the continuous-time eigenvalues.
807+
# Plot the complex plane axes.
808+
eig_axes[2].axvline(x=0, c="k", lw=1)
809+
eig_axes[2].axhline(y=0, c="k", lw=1)
810+
# eig_axes[2].axis("equal")
811+
eig_axes[2].set_title("Continuous-time Eigenvalues", fontsize=title_fontsize)
812+
eig_axes[2].set_xlabel(r"$Im(\omega)$", fontsize=label_fontsize)
813+
eig_axes[2].set_ylabel(r"$Re(\omega)$", fontsize=label_fontsize)
814+
eig_axes[2].invert_xaxis()
815+
# Plot the eigenvalues.
816+
if cont_eigs is not None:
817+
for idx, eig in enumerate(cont_eigs):
818+
if idx in index_modes:
819+
color = main_colors[index_modes.index(idx)]
820+
else:
821+
color = "tab:orange"
822+
ax.plot(eig.imag, eig.real, "o", c=color, ms=ms_vals[idx], mec="k")
827823

828824
# PLOTS 4-6: Plot the DMD modes.
825+
if imshow_kwargs is None:
826+
imshow_kwargs = {}
827+
if "cmap" not in imshow_kwargs:
828+
imshow_kwargs["cmap"] = "bwr"
829+
829830
for i, (ax, idx) in enumerate(zip(mode_axes, index_modes)):
830831
ax.set_title(
831832
f"Mode {idx + 1}", c=main_colors[i], fontsize=title_fontsize
832833
)
833834
# Plot modes in 1D.
834835
if len(snapshots_shape) == 1:
835-
ax.plot(lead_modes[:, idx].real, c=mode_color)
836+
if x is None:
837+
x = np.arange(len(lead_modes))
838+
ax.plot(x, lead_modes[:, idx].real, c="k")
836839
# Plot modes in 2D.
837840
else:
838841
mode = lead_modes[:, idx].reshape(*snapshots_shape, order=order)
839842
vmax = np.abs(mode.real).max()
840-
im = ax.imshow(mode.real, vmax=vmax, vmin=-vmax, cmap=mode_cmap)
843+
im = ax.imshow(mode.real, vmax=vmax, vmin=-vmax, **imshow_kwargs)
841844
# Align the colorbar with the plotted image.
842845
divider = make_axes_locatable(ax)
843846
cax = divider.append_axes("right", size="3%", pad=0.05)
844847
fig.colorbar(im, cax=cax)
845-
if remove_cmap_ticks:
846-
ax.set_xticks([])
847-
ax.set_yticks([])
848848

849849
# PLOTS 7-9: Plot the DMD mode dynamics.
850850
for i, (ax, idx) in enumerate(zip(dynamics_axes, index_modes)):
851851
dynamics_data = lead_dynamics[idx].real
852852
ax.set_title("Mode Dynamics", c=main_colors[i], fontsize=title_fontsize)
853-
ax.plot(time, dynamics_data, c=dynamics_color)
853+
ax.plot(time, dynamics_data, c="tab:blue")
854854
ax.set_xlabel("Time", fontsize=label_fontsize)
855-
dynamics_range = dynamics_data.max() - dynamics_data.min()
855+
856856
# Re-adjust ylim if dynamics oscillations are extremely small.
857+
dynamics_range = dynamics_data.max() - dynamics_data.min()
857858
if dynamics_range / np.abs(np.average(dynamics_data)) < 1e-4:
858859
ax.set_ylim(np.sort([0.0, 2 * np.average(dynamics_data)]))
859860

0 commit comments

Comments
 (0)