@@ -20,14 +20,12 @@ def pairs_posterior(
2020 variable_keys : Sequence [str ] = None ,
2121 variable_names : Sequence [str ] = None ,
2222 height : int = 3 ,
23+ post_color : str | tuple = "#132a70" ,
24+ prior_color : str | tuple = "gray" ,
25+ alpha = 0.9 ,
2326 label_fontsize : int = 14 ,
2427 tick_fontsize : int = 12 ,
25- # arguments related to priors which is currently unused
26- # legend_fontsize: int = 16,
27- # post_color: str | tuple = "#132a70",
28- # prior_color: str | tuple = "gray",
29- # post_alpha: float = 0.9,
30- # prior_alpha: float = 0.7,
28+ legend_fontsize : int = 14 ,
3129 ** kwargs ,
3230) -> sns .PairGrid :
3331 """Generates a bivariate pair plot given posterior draws and optional prior or prior draws.
@@ -57,10 +55,12 @@ def pairs_posterior(
5755 The color for the posterior histograms and KDEs
5856 priors_color : str, optional, default: gray
5957 The color for the optional prior histograms and KDEs
60- post_alpha : float in [0, 1], optonal , default: 0.9
58+ post_alpha : float in [0, 1], optional , default: 0.9
6159 The opacity of the posterior plots
62- prior_alpha : float in [0, 1], optonal , default: 0.7
60+ prior_alpha : float in [0, 1], optional , default: 0.7
6361 The opacity of the prior plots
62+ **kwargs : dict, optional, default: {}
63+ Further optional keyword arguments propagated to `_pairs_samples`
6464
6565 Returns
6666 -------
@@ -75,6 +75,7 @@ def pairs_posterior(
7575 plot_data = dicts_to_arrays (
7676 estimates = estimates ,
7777 targets = targets ,
78+ priors = priors ,
7879 dataset_ids = dataset_id ,
7980 variable_keys = variable_keys ,
8081 variable_names = variable_names ,
@@ -90,52 +91,33 @@ def pairs_posterior(
9091 g = _pairs_samples (
9192 plot_data = plot_data ,
9293 height = height ,
94+ color = post_color ,
95+ color2 = prior_color ,
96+ alpha = alpha ,
9397 label_fontsize = label_fontsize ,
9498 tick_fontsize = tick_fontsize ,
99+ legend_fontsize = legend_fontsize ,
95100 ** kwargs ,
96101 )
97102
98- # add priors
99- if priors is not None :
100- # TODO: integrate priors into plot_data and then use
101- # proper coloring of posterior vs. prior using the hue argument in PairGrid
102- raise ValueError ("Plotting prior samples is not yet implemented." )
103-
104- """
105- # this is currently not working as expected as it doesn't show the off diagonal plots
106- prior_samples_df = pd.DataFrame(priors, columns=plot_data["variable_names"])
107- g.data = prior_samples_df
108- g.map_diag(sns.histplot, fill=True, color=prior_color, alpha=prior_alpha, kde=True, zorder=-1)
109- g.map_lower(sns.kdeplot, fill=True, color=prior_color, alpha=prior_alpha, zorder=-1)
110-
111- # Add legend to differentiate between prior and posterior
112- handles = [
113- Line2D(xdata=[], ydata=[], color=post_color, lw=3, alpha=post_alpha),
114- Line2D(xdata=[], ydata=[], color=prior_color, lw=3, alpha=prior_alpha),
115- ]
116- handles_names = ["Posterior", "Prior"]
117- if targets is not None:
118- handles.append(Line2D(xdata=[], ydata=[], color="black", lw=3, linestyle="--"))
119- handles_names.append("True Parameter")
120- plt.legend(handles=handles, labels=handles_names, fontsize=legend_fontsize, loc="center right")
121- """
122-
123- # add true parameters
124- if plot_data ["targets" ] is not None :
125- # TODO: also add true parameters to the off diagonal plots?
126-
127- # drop dataset axis if it is still present but of length 1
128- targets_shape = plot_data ["targets" ].shape
129- if len (targets_shape ) == 2 and targets_shape [0 ] == 1 :
130- plot_data ["targets" ] = np .squeeze (plot_data ["targets" ], axis = 0 )
131-
132- # Custom function to plot true parameters on the diagonal
133- def plot_true_params (x , ** kwargs ):
134- param = x .iloc [0 ] # Get the single true value for the diagonal
135- plt .axvline (param , color = "black" , linestyle = "--" ) # Add vertical line
136-
137- # only plot on the diagonal a vertical line for the true parameter
138- g .data = pd .DataFrame (plot_data ["targets" ][np .newaxis ], columns = plot_data ["variable_names" ])
103+ targets = plot_data .get ("targets" )
104+ if targets is not None :
105+ # Ensure targets is at least 2D
106+ if targets .ndim == 1 :
107+ targets = np .atleast_2d (targets )
108+
109+ # Create DataFrame with variable names as columns
110+ g .data = pd .DataFrame (targets , columns = targets .variable_names )
111+ g .data ["_source" ] = "True Parameter"
139112 g .map_diag (plot_true_params )
140113
141114 return g
115+
116+
117+ def plot_true_params (x , ** kwargs ):
118+ """Custom function to plot true parameters on the diagonal."""
119+
120+ # hue needs to be added to handle the case of plotting both posterior and prior
121+ param = x .iloc [0 ] # Get the single true value for the diagonal
122+ # only plot on the diagonal a vertical line for the true parameter
123+ plt .axvline (param , color = "black" , linestyle = "--" )
0 commit comments