@@ -537,6 +537,7 @@ def plot_summary(
537537 dmd ,
538538 * ,
539539 x = None ,
540+ y = None ,
540541 t = None ,
541542 d = 1 ,
542543 continuous = False ,
@@ -548,13 +549,18 @@ def plot_summary(
548549 dpi = 200 ,
549550 tight_layout_kwargs = None ,
550551 main_colors = ("r" , "b" , "g" ),
551- imshow_kwargs = None ,
552+ mode_color = "k" ,
553+ mode_cmap = "bwr" ,
554+ dynamics_color = "tab:blue" ,
555+ rank_color = "tab:orange" ,
556+ circle_color = "tab:blue" ,
552557 sval_ms = 8 ,
553- max_eig_ms = 12 ,
558+ max_eig_ms = 10 ,
554559 max_sval_plot = 50 ,
555560 title_fontsize = 14 ,
556561 label_fontsize = 12 ,
557562 plot_semilogy = False ,
563+ flip_continuous_axes = False ,
558564):
559565 """
560566 Generate a 3 x 3 summarizing plot that contains the following components:
@@ -570,9 +576,13 @@ def plot_summary(
570576
571577 :param dmd: fitted DMD instance.
572578 :type dmd: pydmd.DMDBase
573- :param x: The points in space where the data has been collected. Note that
574- this parameter is currently only used for plotting modes that are 1-D .
579+ :param x: Points along the 1st spatial dimension where data has been
580+ collected .
575581 :type x: np.ndarray or iterable
582+ :param y: Points along the 2nd spatial dimension where data has been
583+ collected. Note that this parameter is only applicable when the data
584+ snapshots are 2-D, which must be indicated with `snapshots_shape`.
585+ :type y: np.ndarray or iterable
576586 :param t: The times of data collection, or the time-step between snapshots.
577587 Note that time information must be accurate in order to accurately
578588 visualize eigenvalues and times of the dynamics. For non-`BOPDMD`
@@ -592,7 +602,7 @@ def plot_summary(
592602 :type continuous: bool
593603 :param snapshots_shape: Shape of the snapshots. If not provided, the shape
594604 of the snapshots and modes is assumed to be the flattened space dim of
595- the snapshot data.
605+ the snapshot data. Provide as width, height dimension.
596606 :type snapshots_shape: iterable
597607 :param index_modes: Indices of the modes to plot after they have been
598608 sorted based on significance. At most three may be provided.
@@ -617,16 +627,22 @@ def plot_summary(
617627 :type figsize: iterable
618628 :param dpi: Figure resolution.
619629 :type dpi: int
620- :param tight_layout_kwargs: Optional dictionary of
621- `matplotlib.pyplot.tight_layout` parameters.
630+ :param tight_layout_kwargs: Dictionary of `tight_layout` parameters.
622631 :type tight_layout_kwargs: dict
623- :param main_colors: Strings defining the colors used to denote eigenvalue,
624- mode, dynamics associations.
632+ :param main_colors: Colors used to denote eigenvalue, mode, dynamics
633+ associations.
625634 :type main_colors: iterable
626- :param imshow_kwargs: Optional dictionary of `matplotlib.pyplot.imshow`
627- parameters. Use this dictionary to re-define the parameters of 2-D
628- mode plots.
629- :type imshow_kwargs: dict
635+ :param mode_color: Color used to plot the modes, if modes are 1-D.
636+ :type mode_color: str
637+ :param mode_cmap: Colormap used to plot the modes, if modes are 2-D.
638+ :type mode_cmap: str
639+ :param dynamics_color: Color used to plot the dynamics.
640+ :type dynamics_color: str
641+ :param rank_color: Color used to highlight the rank of the DMD fit and
642+ all DMD eigenvalues aside from those highlighted by `index_modes`.
643+ :type rank_color: str
644+ :param circle_color: Color used to plot the unit circle.
645+ :type circle_color: str
630646 :param sval_ms: Marker size of all singular values.
631647 :type sval_ms: int
632648 :param max_eig_ms: Marker size of the most prominent eigenvalue. The marker
@@ -642,6 +658,10 @@ def plot_summary(
642658 :param plot_semilogy: Whether or not to plot the singular values on a
643659 semilogy plot. If `True`, a semilogy plot is used.
644660 :type plot_semilogy: bool
661+ :param flip_continuous_axes: Whether or not to swap the real and imaginary
662+ axes on the continuous eigenvalues plot. If `True`, the real axis will
663+ be vertical and the imaginary axis will be horizontal, and vice versa.
664+ :type flip_continuous_axes: bool
645665 """
646666
647667 # This plotting method is inappropriate for plotting HAVOK results.
@@ -650,21 +670,30 @@ def plot_summary(
650670
651671 # Check that the DMD instance has been fitted.
652672 if dmd .modes is None :
653- raise ValueError (
654- "The modes have not been computed."
655- "You need to perform fit() first."
656- )
673+ raise ValueError ("You need to perform fit() first." )
657674
658675 # By default, snapshots_shape is the flattened space dimension.
659676 if snapshots_shape is None :
660- snapshots_shape = (len (dmd .snapshots ),)
677+ snapshots_shape = (len (dmd .snapshots ) // d ,)
661678 # If provided, snapshots_shape must contain 2 entires.
662679 elif len (snapshots_shape ) != 2 :
663- raise ValueError ("snapshots_shape must be None or 2D." )
680+ raise ValueError ("snapshots_shape must be None or 2-D." )
681+
682+ # Check the length of index_modes.
683+ if len (index_modes ) > 3 :
684+ raise ValueError ("index_modes must have a length of at most 3." )
664685
665686 # Get the actual rank used for the DMD fit.
666687 rank = len (dmd .eigs )
667688
689+ # Ensure that at least rank-many singular values will be plotted.
690+ if rank > max_sval_plot :
691+ raise ValueError (f"max_sval_plot must be at least { rank } ." )
692+
693+ # Indices cannot go past the total number of available modes.
694+ if np .any (np .array (index_modes ) >= rank ):
695+ raise ValueError (f"Cannot view past mode { rank } ." )
696+
668697 # Override index_modes if there are less than 3 modes available.
669698 if rank < 3 :
670699 warnings .warn (
@@ -673,16 +702,6 @@ def plot_summary(
673702 )
674703 index_modes = np .arange (rank )
675704
676- # Check the length of index_modes.
677- if len (index_modes ) > 3 :
678- raise ValueError ("index_modes must have a length of at most 3." )
679-
680- # Indices cannot go past the total number of available or plottable modes.
681- if np .any (np .array (index_modes ) >= min (rank , max_sval_plot )):
682- raise ValueError (
683- f"Cannot view past mode { min (rank , max_sval_plot )} ."
684- )
685-
686705 # Sort eigenvalues, modes, and dynamics according to amplitude magnitude.
687706 mode_order = np .argsort (- np .abs (dmd .amplitudes ))
688707 lead_eigs = dmd .eigs [mode_order ]
@@ -692,7 +711,7 @@ def plot_summary(
692711
693712 # Get time information for eigenvalue conversions.
694713 # The decisions that we make here depend on if we're dealing
695- # with a BOPDMD model or any other type of DMD model.
714+ # with a BOPDMD model or any other type of DMD model...
696715 if isinstance (dmd , BOPDMD ) or (
697716 isinstance (dmd , PrePostProcessingDMD )
698717 and isinstance (dmd .pre_post_processed_dmd , BOPDMD )
@@ -717,16 +736,17 @@ def plot_summary(
717736 time = np .arange (dmd .snapshots .shape [- 1 ]) * t
718737 dt = t
719738 elif t is not None :
720- # Note: assumes uniform spacing in the provided time vector.
721739 time = np .squeeze (np .array (t ))
722740 dt = time [1 ] - time [0 ]
741+ if not np .allclose (time [1 :] - time [:- 1 ], dt ):
742+ raise ValueError ("Time step is not uniform. Check t vector." )
723743 else :
724744 try :
725745 time = dmd .original_timesteps
726746 dt = dmd .original_time ["dt" ]
727747 except AttributeError :
728748 warnings .warn (
729- "No time step information available. "
749+ "No time information available. "
730750 "Using dt = 1 and t0 = 0."
731751 )
732752 time = np .arange (dmd .snapshots .shape [- 1 ])
@@ -760,6 +780,17 @@ def plot_summary(
760780 s_var = s * (100 / np .sum (s ))
761781 s_var = s_var [:max_sval_plot ]
762782
783+ # Build a list of the complex conjugate pairs to be highlighted.
784+ index_modes_cc = []
785+ for idx1 in index_modes :
786+ eig = cont_eigs [idx1 ]
787+ idx2 = list (cont_eigs ).index (eig .conj ())
788+ if eig .conj () not in cont_eigs :
789+ index_modes_cc .append ((idx1 ,))
790+ elif idx2 not in np .array (index_modes_cc ):
791+ index_modes_cc .append ((idx1 , idx2 ))
792+ other_eigs = np .setdiff1d (np .arange (rank ), np .array (index_modes_cc ))
793+
763794 # Generate the summarizing plot.
764795 fig , (eig_axes , mode_axes , dynamics_axes ) = plt .subplots (
765796 3 , 3 , figsize = figsize , dpi = dpi
@@ -769,12 +800,14 @@ def plot_summary(
769800 eig_axes [0 ].set_title ("Singular Values" , fontsize = title_fontsize )
770801 eig_axes [0 ].set_ylabel ("% variance" , fontsize = label_fontsize )
771802 s_t = np .arange (len (s_var )) + 1
772- eig_axes [0 ].plot (s_t , s_var , "o" , c = "gray" , ms = sval_ms , mec = "k" )
773803 eig_axes [0 ].plot (
774- s_t [:rank ], s_var [:rank ], "o" , c = "tab:orange" , ms = sval_ms , mec = "k"
804+ s_t [:rank ], s_var [:rank ], "o" , c = rank_color , ms = sval_ms , mec = "k"
805+ )
806+ eig_axes [0 ].plot (
807+ s_t [rank :], s_var [rank :], "o" , c = "gray" , ms = sval_ms , mec = "k"
775808 )
776809 eig_axes [0 ].legend (
777- handles = [Patch (facecolor = "tab:orange" , label = "Rank of fit" )]
810+ handles = [Patch (facecolor = rank_color , label = "Rank of fit" )]
778811 )
779812 if plot_semilogy :
780813 eig_axes [0 ].semilogy ()
@@ -784,63 +817,65 @@ def plot_summary(
784817 ms_vals = max_eig_ms * np .sqrt (lead_amplitudes / lead_amplitudes [0 ])
785818
786819 # PLOT 2: Plot the discrete-time eigenvalues on the unit circle.
787- # Plot the complex plane axes.
788820 eig_axes [1 ].axvline (x = 0 , c = "k" , lw = 1 )
789821 eig_axes [1 ].axhline (y = 0 , c = "k" , lw = 1 )
790822 eig_axes [1 ].axis ("equal" )
791- # Plot the unit circle.
792823 eig_axes [1 ].set_title ("Discrete-time Eigenvalues" , fontsize = title_fontsize )
793824 t = np .linspace (0 , 2 * np .pi , 100 )
794- eig_axes [1 ].plot (np .cos (t ), np .sin (t ), c = "tab:blue" , ls = "--" )
825+ eig_axes [1 ].plot (np .cos (t ), np .sin (t ), c = circle_color , ls = "--" )
795826 eig_axes [1 ].set_xlabel (r"$Re(\lambda)$" , fontsize = label_fontsize )
796827 eig_axes [1 ].set_ylabel (r"$Im(\lambda)$" , fontsize = label_fontsize )
797- # Plot the eigenvalues.
798- if disc_eigs is not None :
799- for idx , eig in enumerate (disc_eigs ):
800- if idx in index_modes :
801- color = main_colors [index_modes .index (idx )]
802- else :
803- color = "tab:orange"
804- ax .plot (eig .real , eig .imag , "o" , c = color , ms = ms_vals [idx ], mec = "k" )
805828
806829 # PLOT 3: Plot the continuous-time eigenvalues.
807- # Plot the complex plane axes.
808830 eig_axes [2 ].axvline (x = 0 , c = "k" , lw = 1 )
809831 eig_axes [2 ].axhline (y = 0 , c = "k" , lw = 1 )
810- # eig_axes[2].axis("equal")
832+ eig_axes [2 ].axis ("equal" )
811833 eig_axes [2 ].set_title ("Continuous-time Eigenvalues" , fontsize = title_fontsize )
812- eig_axes [2 ].set_xlabel (r"$Im(\omega)$" , fontsize = label_fontsize )
813- eig_axes [2 ].set_ylabel (r"$Re(\omega)$" , fontsize = label_fontsize )
814- eig_axes [2 ].invert_xaxis ()
815- # Plot the eigenvalues.
816- if cont_eigs is not None :
817- for idx , eig in enumerate (cont_eigs ):
818- if idx in index_modes :
819- color = main_colors [index_modes .index (idx )]
820- else :
821- color = "tab:orange"
822- ax .plot (eig .imag , eig .real , "o" , c = color , ms = ms_vals [idx ], mec = "k" )
834+ if flip_continuous_axes :
835+ eig_axes [2 ].set_xlabel (r"$Im(\omega)$" , fontsize = label_fontsize )
836+ eig_axes [2 ].set_ylabel (r"$Re(\omega)$" , fontsize = label_fontsize )
837+ eig_axes [2 ].invert_xaxis ()
838+ cont_eigs = 1j * cont_eigs .real + cont_eigs .imag
839+ else :
840+ eig_axes [2 ].set_xlabel (r"$Re(\omega)$" , fontsize = label_fontsize )
841+ eig_axes [2 ].set_ylabel (r"$Im(\omega)$" , fontsize = label_fontsize )
842+
843+ # Now plot the eigenvalues and record the colors used for each main index.
844+ mode_colors = {}
845+ for ax , eigs in zip ([eig_axes [1 ], eig_axes [2 ]], [disc_eigs , cont_eigs ]):
846+ if eigs is not None :
847+ for i , indices in enumerate (index_modes_cc ):
848+ for idx in indices :
849+ ax .plot (
850+ eigs [idx ].real ,
851+ eigs [idx ].imag ,
852+ "o" , c = main_colors [i ], ms = ms_vals [idx ], mec = "k" ,
853+ )
854+ mode_colors [idx ] = main_colors [i ]
855+ for idx in other_eigs :
856+ ax .plot (
857+ eigs [idx ].real ,
858+ eigs [idx ].imag ,
859+ "o" , c = rank_color , ms = ms_vals [idx ], mec = "k" ,
860+ )
823861
824862 # PLOTS 4-6: Plot the DMD modes.
825- if imshow_kwargs is None :
826- imshow_kwargs = {}
827- if "cmap" not in imshow_kwargs :
828- imshow_kwargs ["cmap" ] = "bwr"
863+ if x is None :
864+ x = np .arange (snapshots_shape [0 ])
829865
830866 for i , (ax , idx ) in enumerate (zip (mode_axes , index_modes )):
831- ax .set_title (
832- f"Mode { idx + 1 } " , c = main_colors [i ], fontsize = title_fontsize
833- )
834- # Plot modes in 1D.
867+ ax .set_title (f"Mode { idx + 1 } " , c = mode_colors [idx ], fontsize = title_fontsize )
868+ # Plot modes in 1-D.
835869 if len (snapshots_shape ) == 1 :
836- if x is None :
837- x = np .arange (len (lead_modes ))
838- ax .plot (x , lead_modes [:, idx ].real , c = "k" )
839- # Plot modes in 2D.
870+ ax .plot (x , lead_modes [:, idx ].real , c = mode_color )
871+ # Plot modes in 2-D.
840872 else :
873+ if y is None :
874+ y = np .arange (snapshots_shape [1 ])
875+ ygrid , xgrid = np .meshgrid (y , x )
841876 mode = lead_modes [:, idx ].reshape (* snapshots_shape , order = order )
842877 vmax = np .abs (mode .real ).max ()
843- im = ax .imshow ( mode .real , vmax = vmax , vmin = - vmax , ** imshow_kwargs )
878+ im = ax .pcolormesh ( xgrid , ygrid , mode .real , vmax = vmax , vmin = - vmax , cmap = mode_cmap )
844879 # Align the colorbar with the plotted image.
845880 divider = make_axes_locatable (ax )
846881 cax = divider .append_axes ("right" , size = "3%" , pad = 0.05 )
@@ -849,8 +884,8 @@ def plot_summary(
849884 # PLOTS 7-9: Plot the DMD mode dynamics.
850885 for i , (ax , idx ) in enumerate (zip (dynamics_axes , index_modes )):
851886 dynamics_data = lead_dynamics [idx ].real
852- ax .set_title ("Mode Dynamics" , c = main_colors [ i ], fontsize = title_fontsize )
853- ax .plot (time , dynamics_data , c = "tab:blue" )
887+ ax .set_title ("Mode Dynamics" , c = mode_colors [ idx ], fontsize = title_fontsize )
888+ ax .plot (time , dynamics_data , c = dynamics_color )
854889 ax .set_xlabel ("Time" , fontsize = label_fontsize )
855890
856891 # Re-adjust ylim if dynamics oscillations are extremely small.
0 commit comments