@@ -1394,6 +1394,7 @@ def forecast(self, t):
13941394
13951395 def plot_mode_uq (
13961396 self ,
1397+ * ,
13971398 x = None ,
13981399 y = None ,
13991400 d = 1 ,
@@ -1402,6 +1403,8 @@ def plot_mode_uq(
14021403 cols = 4 ,
14031404 figsize = None ,
14041405 dpi = None ,
1406+ plot_modes = None ,
1407+ plot_complex_pairs = True ,
14051408 ):
14061409 """
14071410 Plot BOP-DMD modes alongside their standard deviations.
@@ -1431,10 +1434,33 @@ def plot_mode_uq(
14311434 :type figsize: iterable
14321435 :param dpi: Figure resolution.
14331436 :type dpi: int
1437+ :param plot_modes: Number of leading modes to plot, or the indices of
1438+ the modes to plot. If `None`, then all available modes are plotted.
1439+ Note that if this parameter is given as a list of indices, it will
1440+ override the `plot_complex_pair` parameter.
1441+ :type plot_modes: int or iterable
1442+ :param plot_complex_pairs: Whether or not to omit one of the modes that
1443+ correspond with a complex conjugate pair of eigenvalues.
1444+ :type plot_complex_pairs: bool
14341445 """
14351446 if self .modes_std is None :
14361447 raise ValueError ("No UQ metrics to plot." )
14371448
1449+ # Get the indices of the modes to plot.
1450+ nd , r = self .modes .shape
1451+ if plot_modes is None or isinstance (plot_modes , int ):
1452+ mode_indices = np .arange (r )
1453+ if plot_complex_pairs :
1454+ if r % 2 == 0 :
1455+ mode_indices = mode_indices [::2 ]
1456+ else :
1457+ mode_indices = np .concatenate ([(0 ,), mode_indices [1 ::2 ]])
1458+ if isinstance (plot_modes , int ):
1459+ mode_indices = mode_indices [:plot_modes ]
1460+ else :
1461+ mode_indices = plot_modes
1462+ plot_complex_pairs = True
1463+
14381464 # By default, modes_shape is the flattened space dimension.
14391465 if modes_shape is None :
14401466 modes_shape = (len (self .snapshots ) // d ,)
@@ -1454,21 +1480,36 @@ def plot_mode_uq(
14541480
14551481 # Collapse the results across time-delays.
14561482 if d > 1 :
1457- nd , r = modes .shape
14581483 modes = np .average (modes .reshape (d , nd // d , r ), axis = 0 )
14591484 modes_std = np .average (modes_std .reshape (d , nd // d , r ), axis = 0 )
14601485
14611486 # Define the subplot grid.
1462- rows = 2 * int (np .ceil (modes .shape [- 1 ] / cols ))
1463- plt .figure (figsize = figsize , dpi = dpi )
1487+ # Compute the number of subplot rows given the number of columns.
1488+ rows = 2 * int (np .ceil (len (mode_indices ) / cols ))
1489+
1490+ # Compute a grid of all subplot indices.
14641491 all_inds = np .arange (rows * cols ).reshape (rows , cols )
1492+
1493+ # Get the subplot indices at which the mode averages will be plotted.
1494+ # Mode averages are plotted on the 1st, 3rd, 5th, ... rows of the plot.
14651495 avg_inds = all_inds [::2 ].flatten ()
1496+
1497+ # Get the subplot indices at which the mode stds will be plotted.
1498+ # Mode stds are plotted on the 2nd, 4th, 6th, ... rows of the plot.
14661499 std_inds = all_inds [1 ::2 ].flatten ()
14671500
1468- for i , (mode , mode_std ) in enumerate (zip (modes .T , modes_std .T )):
1501+ plt .figure (figsize = figsize , dpi = dpi )
1502+
1503+ for i , idx in enumerate (mode_indices ):
1504+ mode = modes [:, idx ]
1505+ mode_std = modes_std [:, idx ]
1506+
14691507 # Plot the average mode.
14701508 plt .subplot (rows , cols , avg_inds [i ] + 1 )
1471- plt .title (f"Mode { i + 1 } " )
1509+ if plot_complex_pairs :
1510+ plt .title (f"Mode { idx + 1 } " )
1511+ if not plot_complex_pairs :
1512+ plt .title (f"Mode { idx + 1 } , { idx + 2 } " )
14721513 if len (modes_shape ) == 1 :
14731514 # Plot modes in 1-D.
14741515 plt .plot (x , mode .real , c = "tab:blue" )
0 commit comments