1111from tqdm import tqdm
1212
1313from typing import Any , Dict , Optional , List , Tuple
14- from plotnine import *
14+ import matplotlib .pyplot as plt
15+ import seaborn as sns
1516import warnings
1617
1718
@@ -637,7 +638,6 @@ def flow_select(self,
637638 target : np .ndarray ,
638639 candidate_flows : List ,
639640 max_iter : int = 100 ,
640- n_samples : int = 1000 ,
641641 plot : bool = False ,
642642 figure_size : tuple = (10 , 5 ),
643643 ) -> pd .DataFrame :
@@ -653,8 +653,6 @@ def flow_select(self,
653653 List of candidate normalizing flow specifications.
654654 max_iter: int
655655 Maximum number of iterations for the optimization.
656- n_samples: int
657- Number of samples drawn from the fitted distribution.
658656 plot: bool
659657 If True, a density plot of the actual and fitted distribution is created.
660658 figure_size: tuple
@@ -692,11 +690,11 @@ def flow_select(self,
692690 }
693691 )
694692 flow_list .append (fit_df )
695- fit_df = pd .concat (flow_list ).sort_values (by = flow_sel .loss_fn , ascending = True )
696- fit_df ["rank" ] = fit_df [flow_sel .loss_fn ].rank ().astype (int )
697- fit_df .set_index (fit_df ["rank" ], inplace = True )
698693 pbar .update (1 )
699694 pbar .set_description (f"Fitting of candidate normalizing flows completed" )
695+ fit_df = pd .concat (flow_list ).sort_values (by = flow_sel .loss_fn , ascending = True )
696+ fit_df ["rank" ] = fit_df [flow_sel .loss_fn ].rank ().astype (int )
697+ fit_df .set_index (fit_df ["rank" ], inplace = True )
700698
701699 if plot :
702700 # Select normalizing flow with the lowest loss
@@ -713,29 +711,17 @@ def flow_select(self,
713711 flow_params = torch .tensor (best_flow ["params" ][0 ]).reshape (1 , - 1 )
714712 flow_dist_sel = best_flow_sel .create_spline_flow (input_dim = 1 )
715713 _ , flow_dist_sel = best_flow_sel .replace_parameters (flow_params , flow_dist_sel )
716- flow_samples = pd .DataFrame (flow_dist_sel .sample ((n_samples ,)).squeeze ().detach ().numpy ().T )
714+ n_samples = np .max ([10000 , target .shape [0 ]])
715+ n_samples = np .where (n_samples > 500000 , 100000 , n_samples )
716+ flow_samples = pd .DataFrame (flow_dist_sel .sample ((n_samples ,)).squeeze ().detach ().numpy ().T ).values
717717
718718 # Plot actual and fitted distribution
719- flow_samples ["type" ] = f"Best-Fit: { best_flow ['NormFlow' ].values [0 ]} "
720-
721- df_actual = pd .DataFrame (target )
722- df_actual ["type" ] = "Data"
723-
724- plot_df = pd .concat ([df_actual , flow_samples ]).rename (columns = {0 : "variable" })
725-
726- print (
727- ggplot (plot_df ,
728- aes (x = "variable" ,
729- color = "type" )) +
730- geom_density (size = 1.1 ) +
731- theme_bw (base_size = 15 ) +
732- theme (figure_size = figure_size ,
733- legend_position = "right" ,
734- legend_title = element_blank (),
735- plot_title = element_text (hjust = 0.5 )) +
736- labs (title = f"Actual vs. Fitted Density" ,
737- x = "" )
738- )
719+ plt .figure (figsize = figure_size )
720+ sns .kdeplot (target .reshape (- 1 , ), label = "Actual" )
721+ sns .kdeplot (flow_samples .reshape (- 1 , ), label = f"Best-Fit: { best_flow ['NormFlow' ].values [0 ]} " )
722+ plt .legend ()
723+ plt .title ("Actual vs. Best-Fit Density" , fontweight = "bold" , fontsize = 16 )
724+ plt .show ()
739725
740726 fit_df .drop (columns = ["rank" , "params" ], inplace = True )
741727
0 commit comments