Skip to content

Commit b7d8ca1

Browse files
committed
More customization
1 parent 6651550 commit b7d8ca1

File tree

1 file changed

+109
-74
lines changed

1 file changed

+109
-74
lines changed

pydmd/plotter.py

Lines changed: 109 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,7 @@ def plot_summary(
537537
dmd,
538538
*,
539539
x=None,
540+
y=None,
540541
t=None,
541542
d=1,
542543
continuous=False,
@@ -548,13 +549,18 @@ def plot_summary(
548549
dpi=200,
549550
tight_layout_kwargs=None,
550551
main_colors=("r", "b", "g"),
551-
imshow_kwargs=None,
552+
mode_color="k",
553+
mode_cmap="bwr",
554+
dynamics_color="tab:blue",
555+
rank_color="tab:orange",
556+
circle_color="tab:blue",
552557
sval_ms=8,
553-
max_eig_ms=12,
558+
max_eig_ms=10,
554559
max_sval_plot=50,
555560
title_fontsize=14,
556561
label_fontsize=12,
557562
plot_semilogy=False,
563+
flip_continuous_axes=False,
558564
):
559565
"""
560566
Generate a 3 x 3 summarizing plot that contains the following components:
@@ -570,9 +576,13 @@ def plot_summary(
570576
571577
:param dmd: fitted DMD instance.
572578
:type dmd: pydmd.DMDBase
573-
:param x: The points in space where the data has been collected. Note that
574-
this parameter is currently only used for plotting modes that are 1-D.
579+
:param x: Points along the 1st spatial dimension where data has been
580+
collected.
575581
:type x: np.ndarray or iterable
582+
:param y: Points along the 2nd spatial dimension where data has been
583+
collected. Note that this parameter is only applicable when the data
584+
snapshots are 2-D, which must be indicated with `snapshots_shape`.
585+
:type y: np.ndarray or iterable
576586
:param t: The times of data collection, or the time-step between snapshots.
577587
Note that time information must be accurate in order to accurately
578588
visualize eigenvalues and times of the dynamics. For non-`BOPDMD`
@@ -592,7 +602,7 @@ def plot_summary(
592602
:type continuous: bool
593603
:param snapshots_shape: Shape of the snapshots. If not provided, the shape
594604
of the snapshots and modes is assumed to be the flattened space dim of
595-
the snapshot data.
605+
the snapshot data. Provide as width, height dimension.
596606
:type snapshots_shape: iterable
597607
:param index_modes: Indices of the modes to plot after they have been
598608
sorted based on significance. At most three may be provided.
@@ -617,16 +627,22 @@ def plot_summary(
617627
:type figsize: iterable
618628
:param dpi: Figure resolution.
619629
:type dpi: int
620-
:param tight_layout_kwargs: Optional dictionary of
621-
`matplotlib.pyplot.tight_layout` parameters.
630+
:param tight_layout_kwargs: Dictionary of `tight_layout` parameters.
622631
:type tight_layout_kwargs: dict
623-
:param main_colors: Strings defining the colors used to denote eigenvalue,
624-
mode, dynamics associations.
632+
:param main_colors: Colors used to denote eigenvalue, mode, dynamics
633+
associations.
625634
:type main_colors: iterable
626-
:param imshow_kwargs: Optional dictionary of `matplotlib.pyplot.imshow`
627-
parameters. Use this dictionary to re-define the parameters of 2-D
628-
mode plots.
629-
:type imshow_kwargs: dict
635+
:param mode_color: Color used to plot the modes, if modes are 1-D.
636+
:type mode_color: str
637+
:param mode_cmap: Colormap used to plot the modes, if modes are 2-D.
638+
:type mode_cmap: str
639+
:param dynamics_color: Color used to plot the dynamics.
640+
:type dynamics_color: str
641+
:param rank_color: Color used to highlight the rank of the DMD fit and
642+
all DMD eigenvalues aside from those highlighted by `index_modes`.
643+
:type rank_color: str
644+
:param circle_color: Color used to plot the unit circle.
645+
:type circle_color: str
630646
:param sval_ms: Marker size of all singular values.
631647
:type sval_ms: int
632648
:param max_eig_ms: Marker size of the most prominent eigenvalue. The marker
@@ -642,6 +658,10 @@ def plot_summary(
642658
:param plot_semilogy: Whether or not to plot the singular values on a
643659
semilogy plot. If `True`, a semilogy plot is used.
644660
:type plot_semilogy: bool
661+
:param flip_continuous_axes: Whether or not to swap the real and imaginary
662+
axes on the continuous eigenvalues plot. If `True`, the real axis will
663+
be vertical and the imaginary axis will be horizontal, and vice versa.
664+
:type flip_continuous_axes: bool
645665
"""
646666

647667
# This plotting method is inappropriate for plotting HAVOK results.
@@ -650,21 +670,30 @@ def plot_summary(
650670

651671
# Check that the DMD instance has been fitted.
652672
if dmd.modes is None:
653-
raise ValueError(
654-
"The modes have not been computed."
655-
"You need to perform fit() first."
656-
)
673+
raise ValueError("You need to perform fit() first.")
657674

658675
# By default, snapshots_shape is the flattened space dimension.
659676
if snapshots_shape is None:
660-
snapshots_shape = (len(dmd.snapshots),)
677+
snapshots_shape = (len(dmd.snapshots) // d,)
661678
# If provided, snapshots_shape must contain 2 entires.
662679
elif len(snapshots_shape) != 2:
663-
raise ValueError("snapshots_shape must be None or 2D.")
680+
raise ValueError("snapshots_shape must be None or 2-D.")
681+
682+
# Check the length of index_modes.
683+
if len(index_modes) > 3:
684+
raise ValueError("index_modes must have a length of at most 3.")
664685

665686
# Get the actual rank used for the DMD fit.
666687
rank = len(dmd.eigs)
667688

689+
# Ensure that at least rank-many singular values will be plotted.
690+
if rank > max_sval_plot:
691+
raise ValueError(f"max_sval_plot must be at least {rank}.")
692+
693+
# Indices cannot go past the total number of available modes.
694+
if np.any(np.array(index_modes) >= rank):
695+
raise ValueError(f"Cannot view past mode {rank}.")
696+
668697
# Override index_modes if there are less than 3 modes available.
669698
if rank < 3:
670699
warnings.warn(
@@ -673,16 +702,6 @@ def plot_summary(
673702
)
674703
index_modes = np.arange(rank)
675704

676-
# Check the length of index_modes.
677-
if len(index_modes) > 3:
678-
raise ValueError("index_modes must have a length of at most 3.")
679-
680-
# Indices cannot go past the total number of available or plottable modes.
681-
if np.any(np.array(index_modes) >= min(rank, max_sval_plot)):
682-
raise ValueError(
683-
f"Cannot view past mode {min(rank, max_sval_plot)}."
684-
)
685-
686705
# Sort eigenvalues, modes, and dynamics according to amplitude magnitude.
687706
mode_order = np.argsort(-np.abs(dmd.amplitudes))
688707
lead_eigs = dmd.eigs[mode_order]
@@ -692,7 +711,7 @@ def plot_summary(
692711

693712
# Get time information for eigenvalue conversions.
694713
# The decisions that we make here depend on if we're dealing
695-
# with a BOPDMD model or any other type of DMD model.
714+
# with a BOPDMD model or any other type of DMD model...
696715
if isinstance(dmd, BOPDMD) or (
697716
isinstance(dmd, PrePostProcessingDMD)
698717
and isinstance(dmd.pre_post_processed_dmd, BOPDMD)
@@ -717,16 +736,17 @@ def plot_summary(
717736
time = np.arange(dmd.snapshots.shape[-1]) * t
718737
dt = t
719738
elif t is not None:
720-
# Note: assumes uniform spacing in the provided time vector.
721739
time = np.squeeze(np.array(t))
722740
dt = time[1] - time[0]
741+
if not np.allclose(time[1:] - time[:-1], dt):
742+
raise ValueError("Time step is not uniform. Check t vector.")
723743
else:
724744
try:
725745
time = dmd.original_timesteps
726746
dt = dmd.original_time["dt"]
727747
except AttributeError:
728748
warnings.warn(
729-
"No time step information available. "
749+
"No time information available. "
730750
"Using dt = 1 and t0 = 0."
731751
)
732752
time = np.arange(dmd.snapshots.shape[-1])
@@ -760,6 +780,17 @@ def plot_summary(
760780
s_var = s * (100 / np.sum(s))
761781
s_var = s_var[:max_sval_plot]
762782

783+
# Build a list of the complex conjugate pairs to be highlighted.
784+
index_modes_cc = []
785+
for idx1 in index_modes:
786+
eig = cont_eigs[idx1]
787+
idx2 = list(cont_eigs).index(eig.conj())
788+
if eig.conj() not in cont_eigs:
789+
index_modes_cc.append((idx1,))
790+
elif idx2 not in np.array(index_modes_cc):
791+
index_modes_cc.append((idx1, idx2))
792+
other_eigs = np.setdiff1d(np.arange(rank), np.array(index_modes_cc))
793+
763794
# Generate the summarizing plot.
764795
fig, (eig_axes, mode_axes, dynamics_axes) = plt.subplots(
765796
3, 3, figsize=figsize, dpi=dpi
@@ -769,12 +800,14 @@ def plot_summary(
769800
eig_axes[0].set_title("Singular Values", fontsize=title_fontsize)
770801
eig_axes[0].set_ylabel("% variance", fontsize=label_fontsize)
771802
s_t = np.arange(len(s_var)) + 1
772-
eig_axes[0].plot(s_t, s_var, "o", c="gray", ms=sval_ms, mec="k")
773803
eig_axes[0].plot(
774-
s_t[:rank], s_var[:rank], "o", c="tab:orange", ms=sval_ms, mec="k"
804+
s_t[:rank], s_var[:rank], "o", c=rank_color, ms=sval_ms, mec="k"
805+
)
806+
eig_axes[0].plot(
807+
s_t[rank:], s_var[rank:], "o", c="gray", ms=sval_ms, mec="k"
775808
)
776809
eig_axes[0].legend(
777-
handles=[Patch(facecolor="tab:orange", label="Rank of fit")]
810+
handles=[Patch(facecolor=rank_color, label="Rank of fit")]
778811
)
779812
if plot_semilogy:
780813
eig_axes[0].semilogy()
@@ -784,63 +817,65 @@ def plot_summary(
784817
ms_vals = max_eig_ms * np.sqrt(lead_amplitudes / lead_amplitudes[0])
785818

786819
# PLOT 2: Plot the discrete-time eigenvalues on the unit circle.
787-
# Plot the complex plane axes.
788820
eig_axes[1].axvline(x=0, c="k", lw=1)
789821
eig_axes[1].axhline(y=0, c="k", lw=1)
790822
eig_axes[1].axis("equal")
791-
# Plot the unit circle.
792823
eig_axes[1].set_title("Discrete-time Eigenvalues", fontsize=title_fontsize)
793824
t = np.linspace(0, 2 * np.pi, 100)
794-
eig_axes[1].plot(np.cos(t), np.sin(t), c="tab:blue", ls="--")
825+
eig_axes[1].plot(np.cos(t), np.sin(t), c=circle_color, ls="--")
795826
eig_axes[1].set_xlabel(r"$Re(\lambda)$", fontsize=label_fontsize)
796827
eig_axes[1].set_ylabel(r"$Im(\lambda)$", fontsize=label_fontsize)
797-
# Plot the eigenvalues.
798-
if disc_eigs is not None:
799-
for idx, eig in enumerate(disc_eigs):
800-
if idx in index_modes:
801-
color = main_colors[index_modes.index(idx)]
802-
else:
803-
color = "tab:orange"
804-
ax.plot(eig.real, eig.imag, "o", c=color, ms=ms_vals[idx], mec="k")
805828

806829
# PLOT 3: Plot the continuous-time eigenvalues.
807-
# Plot the complex plane axes.
808830
eig_axes[2].axvline(x=0, c="k", lw=1)
809831
eig_axes[2].axhline(y=0, c="k", lw=1)
810-
# eig_axes[2].axis("equal")
832+
eig_axes[2].axis("equal")
811833
eig_axes[2].set_title("Continuous-time Eigenvalues", fontsize=title_fontsize)
812-
eig_axes[2].set_xlabel(r"$Im(\omega)$", fontsize=label_fontsize)
813-
eig_axes[2].set_ylabel(r"$Re(\omega)$", fontsize=label_fontsize)
814-
eig_axes[2].invert_xaxis()
815-
# Plot the eigenvalues.
816-
if cont_eigs is not None:
817-
for idx, eig in enumerate(cont_eigs):
818-
if idx in index_modes:
819-
color = main_colors[index_modes.index(idx)]
820-
else:
821-
color = "tab:orange"
822-
ax.plot(eig.imag, eig.real, "o", c=color, ms=ms_vals[idx], mec="k")
834+
if flip_continuous_axes:
835+
eig_axes[2].set_xlabel(r"$Im(\omega)$", fontsize=label_fontsize)
836+
eig_axes[2].set_ylabel(r"$Re(\omega)$", fontsize=label_fontsize)
837+
eig_axes[2].invert_xaxis()
838+
cont_eigs = 1j * cont_eigs.real + cont_eigs.imag
839+
else:
840+
eig_axes[2].set_xlabel(r"$Re(\omega)$", fontsize=label_fontsize)
841+
eig_axes[2].set_ylabel(r"$Im(\omega)$", fontsize=label_fontsize)
842+
843+
# Now plot the eigenvalues and record the colors used for each main index.
844+
mode_colors = {}
845+
for ax, eigs in zip([eig_axes[1], eig_axes[2]], [disc_eigs, cont_eigs]):
846+
if eigs is not None:
847+
for i, indices in enumerate(index_modes_cc):
848+
for idx in indices:
849+
ax.plot(
850+
eigs[idx].real,
851+
eigs[idx].imag,
852+
"o", c=main_colors[i], ms=ms_vals[idx], mec="k",
853+
)
854+
mode_colors[idx] = main_colors[i]
855+
for idx in other_eigs:
856+
ax.plot(
857+
eigs[idx].real,
858+
eigs[idx].imag,
859+
"o", c=rank_color, ms=ms_vals[idx], mec="k",
860+
)
823861

824862
# PLOTS 4-6: Plot the DMD modes.
825-
if imshow_kwargs is None:
826-
imshow_kwargs = {}
827-
if "cmap" not in imshow_kwargs:
828-
imshow_kwargs["cmap"] = "bwr"
863+
if x is None:
864+
x = np.arange(snapshots_shape[0])
829865

830866
for i, (ax, idx) in enumerate(zip(mode_axes, index_modes)):
831-
ax.set_title(
832-
f"Mode {idx + 1}", c=main_colors[i], fontsize=title_fontsize
833-
)
834-
# Plot modes in 1D.
867+
ax.set_title(f"Mode {idx + 1}", c=mode_colors[idx], fontsize=title_fontsize)
868+
# Plot modes in 1-D.
835869
if len(snapshots_shape) == 1:
836-
if x is None:
837-
x = np.arange(len(lead_modes))
838-
ax.plot(x, lead_modes[:, idx].real, c="k")
839-
# Plot modes in 2D.
870+
ax.plot(x, lead_modes[:, idx].real, c=mode_color)
871+
# Plot modes in 2-D.
840872
else:
873+
if y is None:
874+
y = np.arange(snapshots_shape[1])
875+
ygrid, xgrid = np.meshgrid(y, x)
841876
mode = lead_modes[:, idx].reshape(*snapshots_shape, order=order)
842877
vmax = np.abs(mode.real).max()
843-
im = ax.imshow(mode.real, vmax=vmax, vmin=-vmax, **imshow_kwargs)
878+
im = ax.pcolormesh(xgrid, ygrid, mode.real, vmax=vmax, vmin=-vmax, cmap=mode_cmap)
844879
# Align the colorbar with the plotted image.
845880
divider = make_axes_locatable(ax)
846881
cax = divider.append_axes("right", size="3%", pad=0.05)
@@ -849,8 +884,8 @@ def plot_summary(
849884
# PLOTS 7-9: Plot the DMD mode dynamics.
850885
for i, (ax, idx) in enumerate(zip(dynamics_axes, index_modes)):
851886
dynamics_data = lead_dynamics[idx].real
852-
ax.set_title("Mode Dynamics", c=main_colors[i], fontsize=title_fontsize)
853-
ax.plot(time, dynamics_data, c="tab:blue")
887+
ax.set_title("Mode Dynamics", c=mode_colors[idx], fontsize=title_fontsize)
888+
ax.plot(time, dynamics_data, c=dynamics_color)
854889
ax.set_xlabel("Time", fontsize=label_fontsize)
855890

856891
# Re-adjust ylim if dynamics oscillations are extremely small.

0 commit comments

Comments
 (0)