1+ import matplotlib .pyplot as plt
2+
13import numpy as np
24import pandas as pd
35import seaborn as sns
810
911
1012def plot_posterior_2d (
11- post_samples : dict [ str , np . ndarray ] | np .ndarray ,
12- prior_samples : dict [ str , np .ndarray ] | np . ndarray ,
13+ post_samples : np .ndarray ,
14+ prior_samples : np .ndarray = None ,
1315 prior = None ,
14- param_names : list = None ,
16+ variable_names : list = None ,
17+ true_params : np .ndarray = None ,
1518 height : int = 3 ,
1619 label_fontsize : int = 14 ,
1720 legend_fontsize : int = 16 ,
@@ -24,15 +27,17 @@ def plot_posterior_2d(
2427) -> sns .PairGrid :
2528 """Generates a bivariate pairplot given posterior draws and optional prior or prior draws.
2629
27- posterior_draws : np.ndarray of shape (n_post_draws, n_params)
30+ post_samples : np.ndarray of shape (n_post_draws, n_params)
2831 The posterior draws obtained for a SINGLE observed data set.
29- prior : bayesflow.forward_inference.Prior instance or None, optional, default: None
30- The optional prior object having an input-output signature as given by ayesflow.forward_inference.Prior
31- prior_draws : np.ndarray of shape (n_prior_draws, n_params) or None, optonal (default: None)
32- The optional prior draws obtained from the prior. If both prior and prior_draws are provided, prior_draws
32+ prior_samples : np.ndarray of shape (n_prior_draws, n_params) or None, optional (default: None)
33+ The optional prior samples obtained from the prior. If both prior and prior_samples are provided, prior_samples
3334 will be used.
34- param_names : list or None, optional, default: None
35+ prior : bayesflow.forward_inference.Prior instance or None, optional, default: None
36+ The optional prior object having an input-output signature as given by bayesflow.forward_inference.Prior
37+ variable_names : list or None, optional, default: None
3538 The parameter names for nice plot titles. Inferred if None
39+ true_params : np.ndarray of shape (n_params,) or None, optional, default: None
40+ The true parameter values to be plotted on the diagonal.
3641 height : float, optional, default: 3
3742 The height of the pairplot
3843 label_fontsize : int, optional, default: 14
@@ -41,7 +46,7 @@ def plot_posterior_2d(
4146 The font size of the legend text
4247 tick_fontsize : int, optional, default: 12
4348 The font size of the axis ticklabels
44- post_color : str, optional, default: '#8f2727 '
49+ post_color : str, optional, default: '#132a70 '
4550 The color for the posterior histograms and KDEs
4651 priors_color : str, optional, default: gray
4752 The color for the optional prior histograms and KDEs
@@ -64,7 +69,10 @@ def plot_posterior_2d(
6469 assert (len (post_samples .shape )) == 2 , "Shape of `posterior_samples` for a single data set should be 2 dimensional!"
6570
6671 # Plot posterior first
67- g = plot_samples_2d (post_samples , context = "\\ theta" , param_names = param_names , render = False , height = height , ** kwargs )
72+ context = ""
73+ g = plot_samples_2d (
74+ post_samples , context = context , variable_names = variable_names , render = False , height = height , ** kwargs
75+ )
6876
6977 # Obtain n_draws and n_params
7078 n_draws , n_params = post_samples .shape
@@ -73,34 +81,54 @@ def plot_posterior_2d(
7381 if prior is not None and prior_samples is None :
7482 draws = prior (n_draws )
7583 if isinstance (draws , dict ):
76- prior_draws = draws ["prior_draws" ]
84+ prior_samples = draws ["prior_draws" ]
7785 else :
78- prior_draws = draws
86+ prior_samples = draws
87+ elif prior_samples is not None :
88+ # trim to the same number of draws as posterior
89+ prior_samples = prior_samples [:n_draws ]
7990
8091 # Attempt to determine parameter names
81- if param_names is None :
92+ if variable_names is None :
8293 if hasattr (prior , "param_names" ):
83- if prior .param_names is not None :
84- param_names = prior .param_names
94+ if prior .variable_names is not None :
95+ variable_names = prior .variable_names
8596 else :
86- param_names = [f"$\\ theta_{{{ i } }}$" for i in range (1 , n_params + 1 )]
97+ variable_names = [f"{ context } $\\ theta_{{{ i } }}$" for i in range (1 , n_params + 1 )]
8798 else :
88- param_names = [f"$\\ theta_{{{ i } }}$" for i in range (1 , n_params + 1 )]
99+ variable_names = [f"{ context } $\\ theta_{{{ i } }}$" for i in range (1 , n_params + 1 )]
100+ else :
101+ variable_names = [f"{ context } { p } " for p in variable_names ]
89102
90103 # Add prior, if given
91- if prior_draws is not None :
92- prior_draws_df = pd .DataFrame (prior_draws , columns = param_names )
93- g .data = prior_draws_df
104+ if prior_samples is not None :
105+ prior_samples_df = pd .DataFrame (prior_samples , columns = variable_names )
106+ g .data = prior_samples_df
94107 g .map_diag (sns .histplot , fill = True , color = prior_color , alpha = prior_alpha , kde = True , zorder = - 1 )
95108 g .map_lower (sns .kdeplot , fill = True , color = prior_color , alpha = prior_alpha , zorder = - 1 )
96109
110+ # Add true parameters
111+ if true_params is not None :
112+ # Custom function to plot true_params on the diagonal
113+ def plot_true_params (x , ** kwargs ):
114+ param = x .iloc [0 ] # Get the single true value for the diagonal
115+ plt .axvline (param , color = "black" , linestyle = "--" ) # Add vertical line
116+
117+ # only plot on the diagonal a vertical line for the true parameter
118+ g .data = pd .DataFrame (true_params [np .newaxis ], columns = variable_names )
119+ g .map_diag (plot_true_params )
120+
97121 # Add legend, if prior also given
98- if prior_draws is not None or prior is not None :
122+ if prior_samples is not None or prior is not None :
99123 handles = [
100124 Line2D (xdata = [], ydata = [], color = post_color , lw = 3 , alpha = post_alpha ),
101125 Line2D (xdata = [], ydata = [], color = prior_color , lw = 3 , alpha = prior_alpha ),
102126 ]
103- g .legend (handles , ["Posterior" , "Prior" ], fontsize = legend_fontsize , loc = "center right" )
127+ handles_names = ["Posterior" , "Prior" ]
128+ if true_params is not None :
129+ handles .append (Line2D (xdata = [], ydata = [], color = "black" , lw = 3 , linestyle = "--" ))
130+ handles_names .append ("True Parameter" )
131+ plt .legend (handles = handles , labels = handles_names , fontsize = legend_fontsize , loc = "center right" )
104132
105133 n_row , n_col = g .axes .shape
106134
@@ -115,9 +143,9 @@ def plot_posterior_2d(
115143 g .axes [i , j ].tick_params (axis = "both" , which = "minor" , labelsize = tick_fontsize )
116144
117145 # Add nice labels
118- for i , param_name in enumerate (param_names ):
146+ for i , param_name in enumerate (variable_names ):
119147 g .axes [i , 0 ].set_ylabel (param_name , fontsize = label_fontsize )
120- g .axes [len (param_names ) - 1 , i ].set_xlabel (param_name , fontsize = label_fontsize )
148+ g .axes [len (variable_names ) - 1 , i ].set_xlabel (param_name , fontsize = label_fontsize )
121149
122150 # Add grids
123151 for i in range (n_params ):
0 commit comments