Skip to content

Commit b552fd2

Browse files
committed
Function reformatting
1 parent 5432378 commit b552fd2

File tree

1 file changed

+62
-25
lines changed

1 file changed

+62
-25
lines changed

pydmd/plotter.py

Lines changed: 62 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -732,23 +732,27 @@ def plot_summary(
732732
else:
733733
# For all other dmd models, go to the TimeDict for time information,
734734
# that or use the user-provided time information in t if available.
735+
num_samples = dmd.snapshots.shape[-1]
735736
if isinstance(t, (int, float)):
736-
time = np.arange(dmd.snapshots.shape[-1]) * t
737+
time = np.arange(num_samples) * t
737738
dt = t
738739
elif t is not None:
739740
time = np.squeeze(np.array(t))
740741
dt = time[1] - time[0]
741742
if not np.allclose(time[1:] - time[:-1], dt):
742-
raise ValueError("Time step is not uniform. Check t vector.")
743+
warnings.warn(
744+
"Time step is not uniform. DMD might produce unexpected "
745+
"results. Consider using BOP-DMD instead."
746+
)
743747
else:
744748
try:
745749
time = dmd.original_timesteps
746750
dt = dmd.original_time["dt"]
747751
except AttributeError:
748752
warnings.warn(
749-
"No time information available. " "Using dt = 1 and t0 = 0."
753+
"No time information available. Using dt = 1 and t0 = 0."
750754
)
751-
time = np.arange(dmd.snapshots.shape[-1])
755+
time = np.arange(num_samples)
752756
dt = 1.0
753757

754758
if continuous:
@@ -762,7 +766,9 @@ def plot_summary(
762766
if d > 1:
763767
lead_modes = np.average(
764768
lead_modes.reshape(
765-
d, lead_modes.shape[0] // d, lead_modes.shape[1]
769+
d,
770+
lead_modes.shape[0] // d,
771+
lead_modes.shape[1],
766772
),
767773
axis=0,
768774
)
@@ -779,31 +785,48 @@ def plot_summary(
779785
s_var = s * (100 / np.sum(s))
780786
s_var = s_var[:max_sval_plot]
781787

782-
# Build a list of the complex conjugate pairs to be highlighted.
788+
# Build a list of indices of the complex conjugate pairs to highlight.
789+
# Example: If index_modes = [idx1, idx2, idx3, idx4], such that...
790+
# idx1 has no complex conjugate pair
791+
# idx2 and idx3 are complex conjugates
792+
# idx4 and idx5 are complex conjugates
793+
# Then index_modes_cc = [(idx1, idx1), (idx2, idx3), (idx4, idx5)]
783794
index_modes_cc = []
784-
for idx1 in index_modes:
785-
eig = cont_eigs[idx1]
786-
idx2 = list(cont_eigs).index(eig.conj())
795+
for idx in index_modes:
796+
eig = cont_eigs[idx]
787797
if eig.conj() not in cont_eigs:
788-
index_modes_cc.append((idx1,))
789-
elif idx2 not in np.array(index_modes_cc):
790-
index_modes_cc.append((idx1, idx2))
798+
index_modes_cc.append((idx,))
799+
elif idx not in np.array(index_modes_cc):
800+
index_modes_cc.append((idx, list(cont_eigs).index(eig.conj())))
791801
other_eigs = np.setdiff1d(np.arange(rank), np.array(index_modes_cc))
792802

793803
# Generate the summarizing plot.
794804
fig, (eig_axes, mode_axes, dynamics_axes) = plt.subplots(
795-
3, 3, figsize=figsize, dpi=dpi
805+
3,
806+
3,
807+
figsize=figsize,
808+
dpi=dpi,
796809
)
797810

798811
# PLOT 1: Plot the singular value spectrum.
799812
eig_axes[0].set_title("Singular Values", fontsize=title_fontsize)
800813
eig_axes[0].set_ylabel("% variance", fontsize=label_fontsize)
801814
s_t = np.arange(len(s_var)) + 1
802815
eig_axes[0].plot(
803-
s_t[:rank], s_var[:rank], "o", c=rank_color, ms=sval_ms, mec="k"
816+
s_t[:rank],
817+
s_var[:rank],
818+
"o",
819+
c=rank_color,
820+
ms=sval_ms,
821+
mec="k",
804822
)
805823
eig_axes[0].plot(
806-
s_t[rank:], s_var[rank:], "o", c="gray", ms=sval_ms, mec="k"
824+
s_t[rank:],
825+
s_var[rank:],
826+
"o",
827+
c="gray",
828+
ms=sval_ms,
829+
mec="k",
807830
)
808831
eig_axes[0].legend(
809832
handles=[Patch(facecolor=rank_color, label="Rank of fit")]
@@ -830,7 +853,8 @@ def plot_summary(
830853
eig_axes[2].axhline(y=0, c="k", lw=1)
831854
eig_axes[2].axis("equal")
832855
eig_axes[2].set_title(
833-
"Continuous-time Eigenvalues", fontsize=title_fontsize
856+
"Continuous-time Eigenvalues",
857+
fontsize=title_fontsize,
834858
)
835859
if flip_continuous_axes:
836860
eig_axes[2].set_xlabel(r"$Im(\omega)$", fontsize=label_fontsize)
@@ -845,6 +869,7 @@ def plot_summary(
845869
mode_colors = {}
846870
for ax, eigs in zip([eig_axes[1], eig_axes[2]], [disc_eigs, cont_eigs]):
847871
if eigs is not None:
872+
# Plot the main indices and their complex conjugate.
848873
for i, indices in enumerate(index_modes_cc):
849874
for idx in indices:
850875
ax.plot(
@@ -856,6 +881,7 @@ def plot_summary(
856881
mec="k",
857882
)
858883
mode_colors[idx] = main_colors[i]
884+
# Plot all other DMD eigenvalues.
859885
for idx in other_eigs:
860886
ax.plot(
861887
eigs[idx].real,
@@ -866,26 +892,35 @@ def plot_summary(
866892
mec="k",
867893
)
868894

869-
# PLOTS 4-6: Plot the DMD modes.
895+
# Build the spatial grid for the mode plots.
870896
if x is None:
871897
x = np.arange(snapshots_shape[0])
898+
if len(snapshots_shape) == 2:
899+
if y is None:
900+
y = np.arange(snapshots_shape[1])
901+
ygrid, xgrid = np.meshgrid(y, x)
872902

903+
# PLOTS 4-6: Plot the DMD modes.
873904
for i, (ax, idx) in enumerate(zip(mode_axes, index_modes)):
874905
ax.set_title(
875-
f"Mode {idx + 1}", c=mode_colors[idx], fontsize=title_fontsize
906+
f"Mode {idx + 1}",
907+
c=mode_colors[idx],
908+
fontsize=title_fontsize,
876909
)
877-
# Plot modes in 1-D.
878910
if len(snapshots_shape) == 1:
911+
# Plot modes in 1-D.
879912
ax.plot(x, lead_modes[:, idx].real, c=mode_color)
880-
# Plot modes in 2-D.
881913
else:
882-
if y is None:
883-
y = np.arange(snapshots_shape[1])
884-
ygrid, xgrid = np.meshgrid(y, x)
914+
# Plot modes in 2-D.
885915
mode = lead_modes[:, idx].reshape(*snapshots_shape, order=order)
886916
vmax = np.abs(mode.real).max()
887917
im = ax.pcolormesh(
888-
xgrid, ygrid, mode.real, vmax=vmax, vmin=-vmax, cmap=mode_cmap
918+
xgrid,
919+
ygrid,
920+
mode.real,
921+
vmax=vmax,
922+
vmin=-vmax,
923+
cmap=mode_cmap,
889924
)
890925
# Align the colorbar with the plotted image.
891926
divider = make_axes_locatable(ax)
@@ -896,7 +931,9 @@ def plot_summary(
896931
for i, (ax, idx) in enumerate(zip(dynamics_axes, index_modes)):
897932
dynamics_data = lead_dynamics[idx].real
898933
ax.set_title(
899-
"Mode Dynamics", c=mode_colors[idx], fontsize=title_fontsize
934+
"Mode Dynamics",
935+
c=mode_colors[idx],
936+
fontsize=title_fontsize,
900937
)
901938
ax.plot(time, dynamics_data, c=dynamics_color)
902939
ax.set_xlabel("Time", fontsize=label_fontsize)

0 commit comments

Comments
 (0)