1313
1414def pairs_samples (
1515 samples : Mapping [str , np .ndarray ] | np .ndarray = None ,
16+ dataset_id : int = None ,
1617 variable_keys : Sequence [str ] = None ,
1718 variable_names : Sequence [str ] = None ,
1819 height : float = 2.5 ,
@@ -22,6 +23,7 @@ def pairs_samples(
2223 label_fontsize : int = 14 ,
2324 tick_fontsize : int = 12 ,
2425 show_single_legend : bool = False ,
26+ markersize : float = 40 ,
2527 ** kwargs ,
2628) -> sns .PairGrid :
2729 """
@@ -32,6 +34,8 @@ def pairs_samples(
3234 ----------
3335 samples : dict[str, Tensor], default: None
3436 Sample draws from any dataset
37+ dataset_id: Optional ID of the dataset for whose posterior the pair plots shall be generated.
38+ Should only be specified if estimates contain posterior draws from multiple datasets.
3539 variable_keys : list or None, optional, default: None
3640 Select keys from the dictionary provided in samples.
3741 By default, select all keys.
@@ -52,15 +56,23 @@ def pairs_samples(
5256 show_single_legend : bool, optional, default: False
5357 Optional toggle for the user to choose whether a single dataset
5458 should also display legend
59+ markersize : float, optional, default: 40
60+ Marker size in points**2 of the scatter plot.
5561 **kwargs : dict, optional
5662 Additional keyword arguments passed to the sns.PairGrid constructor
5763 """
5864
5965 plot_data = dicts_to_arrays (
6066 estimates = samples ,
67+ dataset_ids = dataset_id ,
6168 variable_keys = variable_keys ,
6269 variable_names = variable_names ,
6370 )
71+ # dicts_to_arrays will keep the dataset axis even if it is of length 1
72+ # however, pairs plotting requires the dataset axis to be removed
73+ estimates_shape = plot_data ["estimates" ].shape
74+ if len (estimates_shape ) == 3 and estimates_shape [0 ] == 1 :
75+ plot_data ["estimates" ] = np .squeeze (plot_data ["estimates" ], axis = 0 )
6476
6577 g = _pairs_samples (
6678 plot_data = plot_data ,
@@ -71,6 +83,7 @@ def pairs_samples(
7183 label_fontsize = label_fontsize ,
7284 tick_fontsize = tick_fontsize ,
7385 show_single_legend = show_single_legend ,
86+ markersize = markersize ,
7487 ** kwargs ,
7588 )
7689
@@ -88,6 +101,9 @@ def _pairs_samples(
88101 tick_fontsize : int = 12 ,
89102 legend_fontsize : int = 14 ,
90103 show_single_legend : bool = False ,
104+ markersize : float = 40 ,
105+ target_markersize : float = 40 ,
106+ target_color : str = "red" ,
91107 ** kwargs ,
92108) -> sns .PairGrid :
93109 """
@@ -101,6 +117,12 @@ def _pairs_samples(
101117 color2 : str, optional, default: 'gray'
102118 Secondary color for the pair plots.
103119 This is the color used for the prior draws.
120+ markersize : float, optional, default: 40
121+ Marker size in points**2 of the scatter plot.
122+ target_markersize : float, optional, default: 40
123+ Target marker size in points**2 of the scatter plot.
124+ target_color : str, optional, default: "red"
125+ Target marker color for the legend.
104126
105127 Other arguments are documented in pairs_samples
106128 """
@@ -159,14 +181,14 @@ def _pairs_samples(
159181 )
160182
161183 # add scatter plots to the upper diagonal
162- g .map_upper (sns .scatterplot , alpha = 0.6 , s = 40 , edgecolor = "k" , color = color , lw = 0 )
184+ g .map_upper (sns .scatterplot , alpha = 0.6 , s = markersize , edgecolor = "k" , color = color , lw = 0 )
163185
164186 # add KDEs to the lower diagonal
165187 try :
166188 g .map_lower (sns .kdeplot , fill = True , color = color , alpha = alpha , common_norm = False )
167189 except Exception as e :
168190 logging .exception ("KDE failed due to the following exception:\n " + repr (e ) + "\n Substituting scatter plot." )
169- g .map_lower (sns .scatterplot , alpha = 0.6 , s = 40 , edgecolor = "k" , color = color , lw = 0 )
191+ g .map_lower (sns .scatterplot , alpha = 0.6 , s = markersize , edgecolor = "k" , color = color , lw = 0 )
170192
171193 # Generate grids
172194 dim = g .axes .shape [0 ]
@@ -200,6 +222,9 @@ def _pairs_samples(
200222 legend_fontsize = legend_fontsize ,
201223 label = label ,
202224 show_single_legend = show_single_legend ,
225+ markersize = markersize ,
226+ target_markersize = target_markersize ,
227+ target_color = target_color ,
203228 )
204229
205230 # Return figure
0 commit comments