@@ -1458,39 +1458,45 @@ def plot_mode_uq(
14581458 modes = np .average (modes .reshape (d , nd // d , r ), axis = 0 )
14591459 modes_std = np .average (modes_std .reshape (d , nd // d , r ), axis = 0 )
14601460
1461+ # Define the subplot grid.
14611462 rows = 2 * int (np .ceil (modes .shape [- 1 ] / cols ))
1462- fig , axes = plt .subplots (rows , cols , figsize = figsize , dpi = dpi )
1463- avg_axes = [ax for axes_list in axes [::2 ] for ax in axes_list ]
1464- std_axes = [ax for axes_list in axes [1 ::2 ] for ax in axes_list ]
1465- avg_axes = avg_axes [:modes .shape [- 1 ]]
1466- std_axes = std_axes [:modes .shape [- 1 ]]
1467-
1468- for i , (ax_avg , ax_std , mode , mode_std ) in enumerate (
1469- zip (avg_axes , std_axes , modes .T , modes_std .T )
1470- ):
1471- ax_avg .set_title (f"Mode { i + 1 } " )
1472- ax_std .set_title ("Mode Standard Deviation" )
1473-
1463+ plt .figure (figsize = figsize , dpi = dpi )
1464+ all_inds = np .arange (rows * cols ).reshape (rows , cols )
1465+ avg_inds = all_inds [::2 ].flatten ()
1466+ std_inds = all_inds [1 ::2 ].flatten ()
1467+
1468+ for i , (mode , mode_std ) in enumerate (zip (modes .T , modes_std .T )):
1469+ # Plot the average mode.
1470+ plt .subplot (rows , cols , avg_inds [i ])
1471+ plt .title (f"Mode { i + 1 } " )
14741472 if len (modes_shape ) == 1 :
14751473 # Plot modes in 1-D.
1476- ax_avg .plot (x , mode .real , c = "tab:blue" )
1477- ax_std .plot (x , mode_std , c = "tab:red" )
1474+ plt .plot (x , mode .real , c = "tab:blue" )
14781475 else :
14791476 # Plot modes in 2-D.
1480- im_avg = ax_avg .pcolormesh (
1477+ plt .pcolormesh (
14811478 xgrid ,
14821479 ygrid ,
14831480 mode .reshape (* modes_shape , order = order ).real ,
14841481 cmap = "viridis" ,
14851482 )
1486- im_std = ax_std .pcolormesh (
1483+ plt .colorbar ()
1484+
1485+ # Plot the mode standard deviation.
1486+ plt .subplot (rows , cols , std_inds [i ])
1487+ plt .title ("Mode Standard Deviation" )
1488+ if len (modes_shape ) == 1 :
1489+ # Plot modes in 1-D.
1490+ plt .plot (x , mode_std , c = "tab:red" )
1491+ else :
1492+ # Plot modes in 2-D.
1493+ plt .pcolormesh (
14871494 xgrid ,
14881495 ygrid ,
14891496 mode_std .reshape (* modes_shape , order = order ),
14901497 cmap = "inferno" ,
14911498 )
1492- fig .colorbar (im_avg , ax = ax_avg )
1493- fig .colorbar (im_std , ax = ax_std )
1499+ plt .colorbar ()
14941500
14951501 plt .suptitle ("DMD Modes" )
14961502 plt .tight_layout ()
0 commit comments