88import matplotlib .pyplot as plt
99import numpy as np
1010from mpl_toolkits .axes_grid1 import make_axes_locatable
11+ from matplotlib .patches import Patch
1112
1213from pydmd import MrDMD
1314
@@ -535,49 +536,50 @@ def plot_snapshots_2D(
535536def plot_summary (
536537 dmd ,
537538 * ,
539+ x = None ,
538540 t = None ,
539541 d = 1 ,
540542 continuous = False ,
541543 snapshots_shape = None ,
542- index_modes = None ,
544+ index_modes = ( 0 , 1 , 2 ) ,
543545 filename = None ,
544546 order = "C" ,
545547 figsize = (12 , 8 ),
546548 dpi = 200 ,
547549 tight_layout_kwargs = None ,
548550 main_colors = ("r" , "b" , "g" ),
549- mode_color = "k" ,
550- mode_cmap = "bwr" ,
551- sval_color = "tab:orange" ,
552- dynamics_color = "tab:blue" ,
551+ imshow_kwargs = None ,
553552 sval_ms = 8 ,
554- max_eig_ms = 10 ,
553+ max_eig_ms = 12 ,
555554 max_sval_plot = 50 ,
556555 title_fontsize = 14 ,
557556 label_fontsize = 12 ,
558557 plot_semilogy = False ,
559- remove_cmap_ticks = False ,
560558):
561559 """
562560 Generate a 3 x 3 summarizing plot that contains the following components:
563561 - the singular value spectrum of the data
564562 - the discrete-time and continuous-time DMD eigenvalues
565- - the three DMD modes specified by the `index_modes` parameter
566- - the dynamics corresponding with each plotted mode
567- Eigenvalues, modes, and dynamics are ordered according to the magnitude of
568- their corresponding amplitude value. Singular values and eigenvalues that
569- are associated with plotted modes and dynamics are also highlighted.
570-
571- :param dmd: DMD instance.
563+ - the DMD modes specified by the `index_modes` parameter
564+ - the time dynamics that correspond with each plotted mode
565+ The number of singular values used for the DMD fit are highlighted.
566+ All eigenvalues, modes, and dynamics are sorted according to the magnitude
567+ of their corresponding amplitude value, i.e. their significance in the fit.
568+ Correspondence between eigenvalues, modes, and dynamics is indicated via
569+ color coordination.
570+
571+ :param dmd: fitted DMD instance.
572572 :type dmd: pydmd.DMDBase
573- :param t: the input time vector or uniform time-step between snapshots.
574- Note that the time-step must be accurate in order to visualize accurate
575- discrete and continuous-time eigenvalues, as well as accurate times of
576- the dynamics. For non-`BOPDMD` models, times of data collection must be
577- uniformly-spaced, and if not provided, TimeDict information stored in
578- the provided DMD instance is used instead. This parameter is ignored if
579- an instance of `BOPDMD` is provided.
580- :type t: {numpy.ndarray, list} or {int, float}
573+ :param x: The points in space where the data has been collected. Note that
574+ this parameter is currently only used for plotting modes that are 1-D.
575+ :type x: np.ndarray or iterable
576+ :param t: The times of data collection, or the time-step between snapshots.
577+ Note that time information must be accurate in order to accurately
578+ visualize eigenvalues and times of the dynamics. For non-`BOPDMD`
579+ models, the entries of t are assumed to be uniformly-spaced, and if
580+ not provided, TimeDict information is used. This parameter is ignored
581+ if an instance of `BOPDMD` is provided.
582+ :type t: {numpy.ndarray, iterable} or {int, float}
581583 :param d: Number of delays applied to the data passed to the DMD instance.
582584 If `d` is greater than 1, then each plotted mode will be the average
583585 mode taken across all `d` delays.
@@ -591,10 +593,11 @@ def plot_summary(
591593 :param snapshots_shape: Shape of the snapshots. If not provided, the shape
592594 of the snapshots and modes is assumed to be the flattened space dim of
593595 the snapshot data.
594- :type snapshots_shape: tuple(int, int)
595- :param index_modes: A list of the indices of the modes to plot. By default,
596- the first three leading modes are plotted.
597- :type index_modes: list
596+ :type snapshots_shape: iterable
597+ :param index_modes: Indices of the modes to plot after they have been
598+ sorted based on significance. At most three may be provided.
599+ By default, the first three leading modes are plotted.
600+ :type index_modes: iterable
598601 :param filename: If specified, the plot is saved at `filename`.
599602 :type filename: str
600603 :param order: Read the elements of snapshots using this index order,
@@ -610,27 +613,25 @@ def plot_summary(
610613 order if a is Fortran contiguous in memory, C-like order otherwise.
611614 "C" is used by default.
612615 :type order: {"C", "F", "A"}
613- :param figsize: Tuple in inches defining the figure size .
614- :type figsize: tuple(int, int)
616+ :param figsize: Width, height in inches.
617+ :type figsize: iterable
615618 :param dpi: Figure resolution.
616619 :type dpi: int
617620 :param tight_layout_kwargs: Optional dictionary of
618- `matplotlib.pyplot.tight_layout() ` parameters.
621+ `matplotlib.pyplot.tight_layout` parameters.
619622 :type tight_layout_kwargs: dict
620- :param main_colors: Tuple of strings defining the colors used to denote
621- eigenvalue, mode, dynamics associations.
622- :type main_colors: tuple(str, str, str)
623- :param mode_color: Color used to plot the modes, if modes are 1D.
624- :type mode_color: str
625- :param mode_cmap: Colormap used to plot the modes, if modes are 2D.
626- :type mode_cmap: str
627- :param dynamics_color: Color used to plot the dynamics.
628- :type dynamics_color: str
623+ :param main_colors: Strings defining the colors used to denote eigenvalue,
624+ mode, dynamics associations.
625+ :type main_colors: iterable
626+ :param imshow_kwargs: Optional dictionary of `matplotlib.pyplot.imshow`
627+ parameters. Use this dictionary to re-define the parameters of 2-D
628+ mode plots.
629+ :type imshow_kwargs: dict
629630 :param sval_ms: Marker size of all singular values.
630631 :type sval_ms: int
631632 :param max_eig_ms: Marker size of the most prominent eigenvalue. The marker
632633 sizes of all other eigenvalues are then scaled according to eigenvalue
633- prominence .
634+ significance .
634635 :type max_eig_ms: int
635636 :param max_sval_plot: Maximum number of singular values to plot.
636637 :type max_sval_plot: int
@@ -641,9 +642,6 @@ def plot_summary(
641642 :param plot_semilogy: Whether or not to plot the singular values on a
642643 semilogy plot. If `True`, a semilogy plot is used.
643644 :type plot_semilogy: bool
644- :param remove_cmap_ticks: Whether or not to include the ticks on 2D mode
645- plots. If `True`, ticks are removed from all 2D mode plots.
646- :type remove_cmap_ticks: bool
647645 """
648646
649647 # This plotting method is inappropriate for plotting HAVOK results.
@@ -660,28 +658,27 @@ def plot_summary(
660658 # By default, snapshots_shape is the flattened space dimension.
661659 if snapshots_shape is None :
662660 snapshots_shape = (len (dmd .snapshots ),)
663- # Only 2D tuples are admissible for snapshots_shape .
664- elif not isinstance ( snapshots_shape , tuple ) or len (snapshots_shape ) != 2 :
665- raise ValueError ("snapshots_shape must be None or a 2D tuple ." )
661+ # If provided, snapshots_shape must contain 2 entires .
662+ elif len (snapshots_shape ) != 2 :
663+ raise ValueError ("snapshots_shape must be None or 2D ." )
666664
667665 # Get the actual rank used for the DMD fit.
668666 rank = len (dmd .eigs )
669667
670668 # Override index_modes if there are less than 3 modes available.
671669 if rank < 3 :
672670 warnings .warn (
673- "Provided dmd model has less than 3 modes."
674- "Plotting all available modes."
671+ "Provided DMD model has less than 3 modes."
672+ "Plotting all available modes... "
675673 )
676- index_modes = list (range (rank ))
677- # By default, we plot the 3 leading modes and their dynamics.
678- elif index_modes is None :
679- index_modes = list (range (3 ))
680- # index_modes was provided - check its type and its length.
681- elif not isinstance (index_modes , list ) or len (index_modes ) > 3 :
682- raise ValueError ("index_modes must be a list of length at most 3." )
674+ index_modes = np .arange (rank )
675+
676+ # Check the length of index_modes.
677+ if len (index_modes ) > 3 :
678+ raise ValueError ("index_modes must have a length of at most 3." )
679+
683680 # Indices cannot go past the total number of available or plottable modes.
684- elif np .any (np .array (index_modes ) >= min (rank , max_sval_plot )):
681+ if np .any (np .array (index_modes ) >= min (rank , max_sval_plot )):
685682 raise ValueError (
686683 f"Cannot view past mode { min (rank , max_sval_plot )} ."
687684 )
@@ -694,6 +691,8 @@ def plot_summary(
694691 lead_amplitudes = np .abs (dmd .amplitudes [mode_order ])
695692
696693 # Get time information for eigenvalue conversions.
694+ # The decisions that we make here depend on if we're dealing
695+ # with a BOPDMD model or any other type of DMD model.
697696 if isinstance (dmd , BOPDMD ) or (
698697 isinstance (dmd , PrePostProcessingDMD )
699698 and isinstance (dmd .pre_post_processed_dmd , BOPDMD )
@@ -717,7 +716,7 @@ def plot_summary(
717716 if isinstance (t , (int , float )):
718717 time = np .arange (dmd .snapshots .shape [- 1 ]) * t
719718 dt = t
720- elif isinstance ( t , ( np . ndarray , list )) :
719+ elif t is not None :
721720 # Note: assumes uniform spacing in the provided time vector.
722721 time = np .squeeze (np .array (t ))
723722 dt = time [1 ] - time [0 ]
@@ -759,101 +758,103 @@ def plot_summary(
759758 s = np .linalg .svd (snp , full_matrices = False , compute_uv = False )
760759 # Compute the percent of data variance captured by each singular value.
761760 s_var = s * (100 / np .sum (s ))
761+ s_var = s_var [:max_sval_plot ]
762762
763763 # Generate the summarizing plot.
764764 fig , (eig_axes , mode_axes , dynamics_axes ) = plt .subplots (
765765 3 , 3 , figsize = figsize , dpi = dpi
766766 )
767767
768768 # PLOT 1: Plot the singular value spectrum.
769- s_var_plot = s_var [:max_sval_plot ]
770769 eig_axes [0 ].set_title ("Singular Values" , fontsize = title_fontsize )
771770 eig_axes [0 ].set_ylabel ("% variance" , fontsize = label_fontsize )
772- s_t = np .arange (len (s_var_plot )) + 1
773- eig_axes [0 ].plot (s_t , s_var_plot , "o" , c = "gray" , ms = sval_ms , mec = "k" )
771+ s_t = np .arange (len (s_var )) + 1
772+ eig_axes [0 ].plot (s_t , s_var , "o" , c = "gray" , ms = sval_ms , mec = "k" )
774773 eig_axes [0 ].plot (
775- s_t [:rank ], s_var_plot [:rank ], "o" , c = sval_color , ms = sval_ms , mec = "k"
774+ s_t [:rank ], s_var [:rank ], "o" , c = "tab:orange" , ms = sval_ms , mec = "k"
775+ )
776+ eig_axes [0 ].legend (
777+ handles = [Patch (facecolor = "tab:orange" , label = "Rank of fit" )]
776778 )
777-
778- # for i, idx in enumerate(index_modes):
779- # eig_axes[0].plot(
780- # s_t[idx],
781- # s_var_plot[idx],
782- # "o",
783- # c=main_colors[i],
784- # ms=sval_ms,
785- # mec="k",
786- # )
787-
788779 if plot_semilogy :
789780 eig_axes [0 ].semilogy ()
790781
791782 # PLOTS 2-3: Plot the eigenvalues (discrete-time and continuous-time).
792-
793- # # Scale marker sizes to reflect the amount of variance captured.
794- # ms_vals = max_eig_ms * np.sqrt(s_var / s_var[0])
795-
796783 # Scale marker sizes to reflect their associated amplitude.
797784 ms_vals = max_eig_ms * np .sqrt (lead_amplitudes / lead_amplitudes [0 ])
798785
799- for i , (ax , eigs ) in enumerate (zip (eig_axes [1 :], [disc_eigs , cont_eigs ])):
800- # Plot the complex plane axes.
801- ax .axvline (x = 0 , c = "k" , lw = 1 )
802- ax .axhline (y = 0 , c = "k" , lw = 1 )
803- ax .axis ("equal" )
804- # PLOT 2: Plot the discrete-time eigenvalues on the unit circle.
805- if i == 0 :
806- ax .set_title ("Discrete-time Eigenvalues" , fontsize = title_fontsize )
807- t = np .linspace (0 , 2 * np .pi , 100 )
808- ax .plot (np .cos (t ), np .sin (t ), c = "tab:blue" , ls = "--" )
809- ax .set_xlabel (r"$Re(\lambda)$" , fontsize = label_fontsize )
810- ax .set_ylabel (r"$Im(\lambda)$" , fontsize = label_fontsize )
811- # PLOT 3: Plot the continuous-time eigenvalues.
812- else :
813- ax .set_title ("Continuous-time Eigenvalues" , fontsize = title_fontsize )
814- ax .set_xlabel (r"$Im(\omega)$" , fontsize = label_fontsize )
815- ax .set_ylabel (r"$Re(\omega)$" , fontsize = label_fontsize )
816- # Plot the eigenvalues (discrete or continuous).
817- if eigs is not None :
818- for idx , eig in enumerate (eigs ):
819- if idx in index_modes :
820- color = main_colors [index_modes .index (idx )]
821- else :
822- color = "gray"
823- if i == 0 :
824- ax .plot (eig .real , eig .imag , "o" , c = color , ms = ms_vals [idx ])
825- else :
826- ax .plot (eig .imag , eig .real , "o" , c = color , ms = ms_vals [idx ])
786+ # PLOT 2: Plot the discrete-time eigenvalues on the unit circle.
787+ # Plot the complex plane axes.
788+ eig_axes [1 ].axvline (x = 0 , c = "k" , lw = 1 )
789+ eig_axes [1 ].axhline (y = 0 , c = "k" , lw = 1 )
790+ eig_axes [1 ].axis ("equal" )
791+ # Plot the unit circle.
792+ eig_axes [1 ].set_title ("Discrete-time Eigenvalues" , fontsize = title_fontsize )
793+ t = np .linspace (0 , 2 * np .pi , 100 )
794+ eig_axes [1 ].plot (np .cos (t ), np .sin (t ), c = "tab:blue" , ls = "--" )
795+ eig_axes [1 ].set_xlabel (r"$Re(\lambda)$" , fontsize = label_fontsize )
796+ eig_axes [1 ].set_ylabel (r"$Im(\lambda)$" , fontsize = label_fontsize )
797+ # Plot the eigenvalues.
798+ if disc_eigs is not None :
799+ for idx , eig in enumerate (disc_eigs ):
800+ if idx in index_modes :
801+ color = main_colors [index_modes .index (idx )]
802+ else :
803+ color = "tab:orange"
804+ ax .plot (eig .real , eig .imag , "o" , c = color , ms = ms_vals [idx ], mec = "k" )
805+
806+ # PLOT 3: Plot the continuous-time eigenvalues.
807+ # Plot the complex plane axes.
808+ eig_axes [2 ].axvline (x = 0 , c = "k" , lw = 1 )
809+ eig_axes [2 ].axhline (y = 0 , c = "k" , lw = 1 )
810+ # eig_axes[2].axis("equal")
811+ eig_axes [2 ].set_title ("Continuous-time Eigenvalues" , fontsize = title_fontsize )
812+ eig_axes [2 ].set_xlabel (r"$Im(\omega)$" , fontsize = label_fontsize )
813+ eig_axes [2 ].set_ylabel (r"$Re(\omega)$" , fontsize = label_fontsize )
814+ eig_axes [2 ].invert_xaxis ()
815+ # Plot the eigenvalues.
816+ if cont_eigs is not None :
817+ for idx , eig in enumerate (cont_eigs ):
818+ if idx in index_modes :
819+ color = main_colors [index_modes .index (idx )]
820+ else :
821+ color = "tab:orange"
822+ ax .plot (eig .imag , eig .real , "o" , c = color , ms = ms_vals [idx ], mec = "k" )
827823
828824 # PLOTS 4-6: Plot the DMD modes.
825+ if imshow_kwargs is None :
826+ imshow_kwargs = {}
827+ if "cmap" not in imshow_kwargs :
828+ imshow_kwargs ["cmap" ] = "bwr"
829+
829830 for i , (ax , idx ) in enumerate (zip (mode_axes , index_modes )):
830831 ax .set_title (
831832 f"Mode { idx + 1 } " , c = main_colors [i ], fontsize = title_fontsize
832833 )
833834 # Plot modes in 1D.
834835 if len (snapshots_shape ) == 1 :
835- ax .plot (lead_modes [:, idx ].real , c = mode_color )
836+ if x is None :
837+ x = np .arange (len (lead_modes ))
838+ ax .plot (x , lead_modes [:, idx ].real , c = "k" )
836839 # Plot modes in 2D.
837840 else :
838841 mode = lead_modes [:, idx ].reshape (* snapshots_shape , order = order )
839842 vmax = np .abs (mode .real ).max ()
840- im = ax .imshow (mode .real , vmax = vmax , vmin = - vmax , cmap = mode_cmap )
843+ im = ax .imshow (mode .real , vmax = vmax , vmin = - vmax , ** imshow_kwargs )
841844 # Align the colorbar with the plotted image.
842845 divider = make_axes_locatable (ax )
843846 cax = divider .append_axes ("right" , size = "3%" , pad = 0.05 )
844847 fig .colorbar (im , cax = cax )
845- if remove_cmap_ticks :
846- ax .set_xticks ([])
847- ax .set_yticks ([])
848848
849849 # PLOTS 7-9: Plot the DMD mode dynamics.
850850 for i , (ax , idx ) in enumerate (zip (dynamics_axes , index_modes )):
851851 dynamics_data = lead_dynamics [idx ].real
852852 ax .set_title ("Mode Dynamics" , c = main_colors [i ], fontsize = title_fontsize )
853- ax .plot (time , dynamics_data , c = dynamics_color )
853+ ax .plot (time , dynamics_data , c = "tab:blue" )
854854 ax .set_xlabel ("Time" , fontsize = label_fontsize )
855- dynamics_range = dynamics_data . max () - dynamics_data . min ()
855+
856856 # Re-adjust ylim if dynamics oscillations are extremely small.
857+ dynamics_range = dynamics_data .max () - dynamics_data .min ()
857858 if dynamics_range / np .abs (np .average (dynamics_data )) < 1e-4 :
858859 ax .set_ylim (np .sort ([0.0 , 2 * np .average (dynamics_data )]))
859860
0 commit comments