@@ -732,23 +732,27 @@ def plot_summary(
732732 else :
733733 # For all other dmd models, go to the TimeDict for time information,
734734 # that or use the user-provided time information in t if available.
735+ num_samples = dmd .snapshots .shape [- 1 ]
735736 if isinstance (t , (int , float )):
736- time = np .arange (dmd . snapshots . shape [ - 1 ] ) * t
737+ time = np .arange (num_samples ) * t
737738 dt = t
738739 elif t is not None :
739740 time = np .squeeze (np .array (t ))
740741 dt = time [1 ] - time [0 ]
741742 if not np .allclose (time [1 :] - time [:- 1 ], dt ):
742- raise ValueError ("Time step is not uniform. Check t vector." )
743+ warnings .warn (
744+ "Time step is not uniform. DMD might produce unexpected "
745+ "results. Consider using BOP-DMD instead."
746+ )
743747 else :
744748 try :
745749 time = dmd .original_timesteps
746750 dt = dmd .original_time ["dt" ]
747751 except AttributeError :
748752 warnings .warn (
749- "No time information available. " " Using dt = 1 and t0 = 0."
753+ "No time information available. Using dt = 1 and t0 = 0."
750754 )
751- time = np .arange (dmd . snapshots . shape [ - 1 ] )
755+ time = np .arange (num_samples )
752756 dt = 1.0
753757
754758 if continuous :
@@ -762,7 +766,9 @@ def plot_summary(
762766 if d > 1 :
763767 lead_modes = np .average (
764768 lead_modes .reshape (
765- d , lead_modes .shape [0 ] // d , lead_modes .shape [1 ]
769+ d ,
770+ lead_modes .shape [0 ] // d ,
771+ lead_modes .shape [1 ],
766772 ),
767773 axis = 0 ,
768774 )
@@ -779,31 +785,48 @@ def plot_summary(
779785 s_var = s * (100 / np .sum (s ))
780786 s_var = s_var [:max_sval_plot ]
781787
782- # Build a list of the complex conjugate pairs to be highlighted.
788+ # Build a list of indices of the complex conjugate pairs to highlight.
789+ # Example: If index_modes = [idx1, idx2, idx3, idx4], such that...
790+ # idx1 has no complex conjugate pair
791+ # idx2 and idx3 are complex conjugates
792+ # idx4 and idx5 are complex conjugates
793+ # Then index_modes_cc = [(idx1, idx1), (idx2, idx3), (idx4, idx5)]
783794 index_modes_cc = []
784- for idx1 in index_modes :
785- eig = cont_eigs [idx1 ]
786- idx2 = list (cont_eigs ).index (eig .conj ())
795+ for idx in index_modes :
796+ eig = cont_eigs [idx ]
787797 if eig .conj () not in cont_eigs :
788- index_modes_cc .append ((idx1 ,))
789- elif idx2 not in np .array (index_modes_cc ):
790- index_modes_cc .append ((idx1 , idx2 ))
798+ index_modes_cc .append ((idx ,))
799+ elif idx not in np .array (index_modes_cc ):
800+ index_modes_cc .append ((idx , list ( cont_eigs ). index ( eig . conj ()) ))
791801 other_eigs = np .setdiff1d (np .arange (rank ), np .array (index_modes_cc ))
792802
793803 # Generate the summarizing plot.
794804 fig , (eig_axes , mode_axes , dynamics_axes ) = plt .subplots (
795- 3 , 3 , figsize = figsize , dpi = dpi
805+ 3 ,
806+ 3 ,
807+ figsize = figsize ,
808+ dpi = dpi ,
796809 )
797810
798811 # PLOT 1: Plot the singular value spectrum.
799812 eig_axes [0 ].set_title ("Singular Values" , fontsize = title_fontsize )
800813 eig_axes [0 ].set_ylabel ("% variance" , fontsize = label_fontsize )
801814 s_t = np .arange (len (s_var )) + 1
802815 eig_axes [0 ].plot (
803- s_t [:rank ], s_var [:rank ], "o" , c = rank_color , ms = sval_ms , mec = "k"
816+ s_t [:rank ],
817+ s_var [:rank ],
818+ "o" ,
819+ c = rank_color ,
820+ ms = sval_ms ,
821+ mec = "k" ,
804822 )
805823 eig_axes [0 ].plot (
806- s_t [rank :], s_var [rank :], "o" , c = "gray" , ms = sval_ms , mec = "k"
824+ s_t [rank :],
825+ s_var [rank :],
826+ "o" ,
827+ c = "gray" ,
828+ ms = sval_ms ,
829+ mec = "k" ,
807830 )
808831 eig_axes [0 ].legend (
809832 handles = [Patch (facecolor = rank_color , label = "Rank of fit" )]
@@ -830,7 +853,8 @@ def plot_summary(
830853 eig_axes [2 ].axhline (y = 0 , c = "k" , lw = 1 )
831854 eig_axes [2 ].axis ("equal" )
832855 eig_axes [2 ].set_title (
833- "Continuous-time Eigenvalues" , fontsize = title_fontsize
856+ "Continuous-time Eigenvalues" ,
857+ fontsize = title_fontsize ,
834858 )
835859 if flip_continuous_axes :
836860 eig_axes [2 ].set_xlabel (r"$Im(\omega)$" , fontsize = label_fontsize )
@@ -845,6 +869,7 @@ def plot_summary(
845869 mode_colors = {}
846870 for ax , eigs in zip ([eig_axes [1 ], eig_axes [2 ]], [disc_eigs , cont_eigs ]):
847871 if eigs is not None :
872+ # Plot the main indices and their complex conjugate.
848873 for i , indices in enumerate (index_modes_cc ):
849874 for idx in indices :
850875 ax .plot (
@@ -856,6 +881,7 @@ def plot_summary(
856881 mec = "k" ,
857882 )
858883 mode_colors [idx ] = main_colors [i ]
884+ # Plot all other DMD eigenvalues.
859885 for idx in other_eigs :
860886 ax .plot (
861887 eigs [idx ].real ,
@@ -866,26 +892,35 @@ def plot_summary(
866892 mec = "k" ,
867893 )
868894
869- # PLOTS 4-6: Plot the DMD modes .
895+ # Build the spatial grid for the mode plots .
870896 if x is None :
871897 x = np .arange (snapshots_shape [0 ])
898+ if len (snapshots_shape ) == 2 :
899+ if y is None :
900+ y = np .arange (snapshots_shape [1 ])
901+ ygrid , xgrid = np .meshgrid (y , x )
872902
903+ # PLOTS 4-6: Plot the DMD modes.
873904 for i , (ax , idx ) in enumerate (zip (mode_axes , index_modes )):
874905 ax .set_title (
875- f"Mode { idx + 1 } " , c = mode_colors [idx ], fontsize = title_fontsize
906+ f"Mode { idx + 1 } " ,
907+ c = mode_colors [idx ],
908+ fontsize = title_fontsize ,
876909 )
877- # Plot modes in 1-D.
878910 if len (snapshots_shape ) == 1 :
911+ # Plot modes in 1-D.
879912 ax .plot (x , lead_modes [:, idx ].real , c = mode_color )
880- # Plot modes in 2-D.
881913 else :
882- if y is None :
883- y = np .arange (snapshots_shape [1 ])
884- ygrid , xgrid = np .meshgrid (y , x )
914+ # Plot modes in 2-D.
885915 mode = lead_modes [:, idx ].reshape (* snapshots_shape , order = order )
886916 vmax = np .abs (mode .real ).max ()
887917 im = ax .pcolormesh (
888- xgrid , ygrid , mode .real , vmax = vmax , vmin = - vmax , cmap = mode_cmap
918+ xgrid ,
919+ ygrid ,
920+ mode .real ,
921+ vmax = vmax ,
922+ vmin = - vmax ,
923+ cmap = mode_cmap ,
889924 )
890925 # Align the colorbar with the plotted image.
891926 divider = make_axes_locatable (ax )
@@ -896,7 +931,9 @@ def plot_summary(
896931 for i , (ax , idx ) in enumerate (zip (dynamics_axes , index_modes )):
897932 dynamics_data = lead_dynamics [idx ].real
898933 ax .set_title (
899- "Mode Dynamics" , c = mode_colors [idx ], fontsize = title_fontsize
934+ "Mode Dynamics" ,
935+ c = mode_colors [idx ],
936+ fontsize = title_fontsize ,
900937 )
901938 ax .plot (time , dynamics_data , c = dynamics_color )
902939 ax .set_xlabel ("Time" , fontsize = label_fontsize )
0 commit comments