88
99from bayesflow .utils import logging
1010from bayesflow .utils .dict_utils import dicts_to_arrays
11+ from bayesflow .utils .plot_utils import create_legends
1112
1213
1314def pairs_samples (
@@ -17,8 +18,10 @@ def pairs_samples(
1718 height : float = 2.5 ,
1819 color : str | tuple = "#132a70" ,
1920 alpha : float = 0.9 ,
21+ label : str = "Posterior" ,
2022 label_fontsize : int = 14 ,
2123 tick_fontsize : int = 12 ,
24+ show_single_legend : bool = False ,
2225 ** kwargs ,
2326) -> sns .PairGrid :
2427 """
@@ -37,13 +40,18 @@ def pairs_samples(
3740 height : float, optional, default: 2.5
3841 The height of the pair plot
3942 color : str, optional, default : '#8f2727'
40- The color of the plot
43+ The primary color of the plot
4144 alpha : float in [0, 1], optional, default: 0.9
4245 The opacity of the plot
46+ label : str, optional, default: "Posterior"
47+ Label for the dataset to plot
4348 label_fontsize : int, optional, default: 14
4449 The font size of the x and y-label texts (parameter names)
4550 tick_fontsize : int, optional, default: 12
46- The font size of the axis ticklabels
51+ The font size of the axis tick labels
52+ show_single_legend : bool, optional, default: False
53+ Optional toggle for the user to choose whether a single dataset
54+ should also display legend
4755 **kwargs : dict, optional
4856 Additional keyword arguments passed to the sns.PairGrid constructor
4957 """
@@ -59,8 +67,11 @@ def pairs_samples(
5967 height = height ,
6068 color = color ,
6169 alpha = alpha ,
70+ label = label ,
6271 label_fontsize = label_fontsize ,
6372 tick_fontsize = tick_fontsize ,
73+ show_single_legend = show_single_legend ,
74+ ** kwargs ,
6475 )
6576
6677 return g
@@ -72,17 +83,27 @@ def _pairs_samples(
7283 color : str | tuple = "#132a70" ,
7384 color2 : str | tuple = "gray" ,
7485 alpha : float = 0.9 ,
86+ label : str = "Posterior" ,
7587 label_fontsize : int = 14 ,
7688 tick_fontsize : int = 12 ,
7789 legend_fontsize : int = 14 ,
90+ show_single_legend : bool = False ,
7891 ** kwargs ,
7992) -> sns .PairGrid :
80- # internal version of pairs_samples creating the seaborn plot
93+ """
94+ Internal version of pairs_samples creating the seaborn PairPlot
95+ for both a single dataset and multiple datasets.
8196
82- # Parameters
83- # ----------
84- # plot_data : output of bayesflow.utils.dict_utils.dicts_to_arrays
85- # other arguments are documented in pairs_samples
97+ Parameters
98+ ----------
99+ plot_data : output of bayesflow.utils.dict_utils.dicts_to_arrays
100+ Formatted data to plot from the sample dataset
101+ color2 : str, optional, default: 'gray'
102+ Secondary color for the pair plots.
103+ This is the color used for the prior draws.
104+
105+ Other arguments are documented in pairs_samples
106+ """
86107
87108 estimates_shape = plot_data ["estimates" ].shape
88109 if len (estimates_shape ) != 2 :
@@ -136,7 +157,7 @@ def _pairs_samples(
136157 common_norm = False ,
137158 )
138159
139- # add scatterplots to the upper diagonal
160+ # add scatter plots to the upper diagonal
140161 g .map_upper (sns .scatterplot , alpha = 0.6 , s = 40 , edgecolor = "k" , color = color , lw = 0 )
141162
142163 # add KDEs to the lower diagonal
@@ -146,11 +167,6 @@ def _pairs_samples(
146167 logging .exception ("KDE failed due to the following exception:\n " + repr (e ) + "\n Substituting scatter plot." )
147168 g .map_lower (sns .scatterplot , alpha = 0.6 , s = 40 , edgecolor = "k" , color = color , lw = 0 )
148169
149- # need to add legend here such that colors are recognized
150- if plot_data ["priors" ] is not None :
151- g .add_legend (fontsize = legend_fontsize , loc = "center right" )
152- g ._legend .set_title (None )
153-
154170 # Generate grids
155171 dim = g .axes .shape [0 ]
156172 for i in range (dim ):
@@ -165,32 +181,48 @@ def _pairs_samples(
165181 g .axes [i , j ].tick_params (axis = "both" , which = "major" , labelsize = tick_fontsize )
166182 g .axes [i , j ].tick_params (axis = "both" , which = "minor" , labelsize = tick_fontsize )
167183
168- # adjust font size of labels
184+ # adjust the font size of labels
169185 # the labels themselves remain the same as before, i.e., variable_names
170186 g .axes [i , 0 ].set_ylabel (variable_names [i ], fontsize = label_fontsize )
171187 g .axes [dim - 1 , i ].set_xlabel (variable_names [i ], fontsize = label_fontsize )
172188
189+ # need to add legend here such that colors are recognized
190+ # if plot_data["priors"] is not None:
191+ # g.add_legend(fontsize=legend_fontsize, loc="center right")
192+ # g._legend.set_title(None)
193+
194+ create_legends (
195+ g ,
196+ plot_data ,
197+ color = color ,
198+ color2 = color2 ,
199+ legend_fontsize = legend_fontsize ,
200+ label = label ,
201+ show_single_legend = show_single_legend ,
202+ )
203+
173204 # Return figure
174205 g .tight_layout ()
175206
176207 return g
177208
178209
179- # create a histogram plot on a twin y axis
180- # this ensures that the y scaling of the diagonal plots
181- # in independent of the y scaling of the off-diagonal plots
182210def histplot_twinx (x , ** kwargs ):
183- # Create a twin axis
184- ax2 = plt .gca ().twinx ()
211+ """
212+ # create a histogram plot on a twin y-axis
213+ # this ensures that the y scaling of the diagonal plots
214+ # in independent of the y scaling of the off-diagonal plots
185215
216+ Parameters
217+ ----------
218+ x : np.ndarray
219+ Data to be plotted.
220+ """
186221 # create a histogram on the twin axis
187- sns .histplot (x , ** kwargs , ax = ax2 )
222+ sns .histplot (x , legend = False , ** kwargs )
188223
189224 # make the twin axis invisible
190225 plt .gca ().spines ["right" ].set_visible (False )
191226 plt .gca ().spines ["top" ].set_visible (False )
192- ax2 .set_ylabel ("" )
193- ax2 .set_yticks ([])
194- ax2 .set_yticklabels ([])
195227
196228 return None
0 commit comments