@@ -40,13 +40,18 @@ def pairs_samples(
4040 height : float, optional, default: 2.5
4141 The height of the pair plot
4242 color : str, optional, default : '#8f2727'
43- The color of the plot
43+ The primary color of the plot
4444 alpha : float in [0, 1], optional, default: 0.9
4545 The opacity of the plot
46+ label : str, optional, default: "Posterior"
47+ Label for the dataset to plot
4648 label_fontsize : int, optional, default: 14
4749 The font size of the x and y-label texts (parameter names)
4850 tick_fontsize : int, optional, default: 12
49- The font size of the axis ticklabels
51+ The font size of the axis tick labels
52+ show_single_legend : bool, optional, default: False
53+ Optional toggle for the user to choose whether a single dataset
54+ should also display legend
5055 **kwargs : dict, optional
5156 Additional keyword arguments passed to the sns.PairGrid constructor
5257 """
@@ -85,12 +90,20 @@ def _pairs_samples(
8590 show_single_legend : bool = False ,
8691 ** kwargs ,
8792) -> sns .PairGrid :
88- # internal version of pairs_samples creating the seaborn plot
93+ """
94+ Internal version of pairs_samples creating the seaborn PairPlot
95+ for both a single dataset and multiple datasets.
8996
90- # Parameters
91- # ----------
92- # plot_data : output of bayesflow.utils.dict_utils.dicts_to_arrays
93- # other arguments are documented in pairs_samples
97+ Parameters
98+ ----------
99+ plot_data : output of bayesflow.utils.dict_utils.dicts_to_arrays
100+ Formatted data to plot from the sample dataset
101+ color2 : str, optional, default: 'gray'
102+ Secondary color for the pair plots.
103+ This is the color used for the prior draws.
104+
105+ Other arguments are documented in pairs_samples
106+ """
94107
95108 estimates_shape = plot_data ["estimates" ].shape
96109 if len (estimates_shape ) != 2 :
@@ -144,7 +157,7 @@ def _pairs_samples(
144157 common_norm = False ,
145158 )
146159
147- # add scatterplots to the upper diagonal
160+ # add scatter plots to the upper diagonal
148161 g .map_upper (sns .scatterplot , alpha = 0.6 , s = 40 , edgecolor = "k" , color = color , lw = 0 )
149162
150163 # add KDEs to the lower diagonal
@@ -168,7 +181,7 @@ def _pairs_samples(
168181 g .axes [i , j ].tick_params (axis = "both" , which = "major" , labelsize = tick_fontsize )
169182 g .axes [i , j ].tick_params (axis = "both" , which = "minor" , labelsize = tick_fontsize )
170183
171- # adjust font size of labels
184+ # adjust the font size of labels
172185 # the labels themselves remain the same as before, i.e., variable_names
173186 g .axes [i , 0 ].set_ylabel (variable_names [i ], fontsize = label_fontsize )
174187 g .axes [dim - 1 , i ].set_xlabel (variable_names [i ], fontsize = label_fontsize )
@@ -196,7 +209,7 @@ def _pairs_samples(
196209
197210def histplot_twinx (x , ** kwargs ):
198211 """
199- # create a histogram plot on a twin y axis
212+ # create a histogram plot on a twin y- axis
200213 # this ensures that the y scaling of the diagonal plots
201214 # in independent of the y scaling of the off-diagonal plots
202215
0 commit comments