1- from collections .abc import Sequence , Mapping
1+ from collections .abc import Callable , Sequence , Mapping
22
33import matplotlib
44import matplotlib .pyplot as plt
77import pandas as pd
88import seaborn as sns
99
10- from bayesflow .utils .dict_utils import make_variable_array
10+
11+ from .plot_quantity import _prepare_values
1112
1213
1314def pairs_quantity (
14- values : Mapping [str , np .ndarray ] | np .ndarray ,
15+ values : Mapping [str , np .ndarray ] | np .ndarray | Callable ,
16+ * ,
1517 targets : Mapping [str , np .ndarray ] | np .ndarray ,
1618 variable_keys : Sequence [str ] = None ,
1719 variable_names : Sequence [str ] = None ,
20+ estimates : Mapping [str , np .ndarray ] | np .ndarray | None = None ,
21+ test_quantities : dict [str , Callable ] = None ,
1822 height : float = 2.5 ,
1923 cmap : str | matplotlib .colors .Colormap = "viridis" ,
2024 alpha : float = 0.9 ,
21- label : str = "" ,
25+ label : str = None ,
2226 label_fontsize : int = 14 ,
2327 tick_fontsize : int = 12 ,
2428 colorbar_label_fontsize : int = 14 ,
@@ -28,6 +32,7 @@ def pairs_quantity(
2832 colorbar_offset : float = 0.06 ,
2933 vmin : float = None ,
3034 vmax : float = None ,
35+ default_name : str = "v" ,
3136 ** kwargs ,
3237) -> sns .PairGrid :
3338 """
@@ -38,25 +43,59 @@ def pairs_quantity(
3843 each parameter is plotted on the diagonal. Each column displays the
3944 values of corresponding to the parameter in the column.
4045
46+ The function supports the following different combinations to pass
47+ or compute the values:
48+
49+ 1. pass `values` as an array of shape (num_datasets,) or (num_datasets, num_variables)
50+ 2. pass `values` as a dictionary with the keys 'values', 'metric_name' and 'variable_names'
51+ as provided by the metrics functions. Note that the functions have to be called
52+ without aggregation to obtain value per dataset.
53+ 3. pass a function to `values`, as well as `estimates`. The function should have the
54+ signature fn(estimates, targets, [aggregation]) and return an object like the
55+ `values` described in the previous options.
56+
4157 Parameters
4258 ----------
43- values : dict[str, np.ndarray],
44- The value of the quantity to plot.
45- targets : dict[str, np.ndarray],
59+ values : dict[str, np.ndarray] | np.ndarray | Callable,
60+ The value of the quantity to plot. One of the following:
61+
62+ 1. an array of shape (num_datasets,) or (num_datasets, num_variables)
63+ 2. a dictionary with the keys 'values', 'metric_name' and 'variable_names'
64+ as provided by the metrics functions. Note that the functions have to be called
65+ without aggregation to obtain value per dataset.
66+ 3. a callable, requires passing `estimates` as well. The function should have the
67+ signature fn(estimates, targets, [aggregation]) and return an object like the
68+ ones described in the previous options.
69+ targets : dict[str, np.ndarray] | np.ndarray,
4670 The parameter values plotted on the axis.
4771 variable_keys : list or None, optional, default: None
4872 Select keys from the dictionary provided in samples.
4973 By default, select all keys.
5074 variable_names : list or None, optional, default: None
5175 The parameter names for nice plot titles. Inferred if None
76+ estimates : np.ndarray of shape (n_data_sets, n_post_draws, n_params), optional, default: None
77+ The posterior draws obtained from n_data_sets. Can only be supplied if
78+ `values` is of type Callable.
79+ test_quantities : dict or None, optional, default: None
80+ A dict that maps plot titles to functions that compute
81+ test quantities based on estimate/target draws.
82+
83+ The dict keys are automatically added to ``variable_keys``
84+ and ``variable_names``.
85+ Test quantity functions are expected to accept a dict of draws with
86+ shape ``(batch_size, ...)`` as the first (typically only)
87+ positional argument and return an NumPy array of shape
88+ ``(batch_size,)``.
89+ The functions do not have to deal with an additional
90+ sample dimension, as appropriate reshaping is done internally.
5291 height : float, optional, default: 2.5
5392 The height of the pair plot
5493 cmap : str or Colormap, default: "viridis"
5594 The colormap for the plot.
5695 alpha : float in [0, 1], optional, default: 0.9
5796 The opacity of the plot
58- label : str, optional, default: ""
59- Label for the dataset to plot
97+ label : str, optional, default: None
98+ Label for the dataset to plot.
6099 label_fontsize : int, optional, default: 14
61100 The font size of the x and y-label texts (parameter names)
62101 tick_fontsize : int, optional, default: 12
@@ -77,21 +116,44 @@ def pairs_quantity(
77116 vmax : float, optional, default: None
78117 Maximum value for the colormap. If None, the maximum value is
79118 determined from `values`.
119+ default_name : str, optional (default = "v")
120+ The default name to use for estimates if None provided
80121 **kwargs : dict, optional
81122 Additional keyword arguments passed to the sns.PairGrid constructor
123+
124+ Returns
125+ -------
126+ plt.Figure
127+ The figure instance
128+
129+ Raises
130+ ------
131+ ValueError
132+ If a callable is supplied as `values`, but `estimates` is None.
82133 """
83- values = make_variable_array (
84- values ,
134+
135+ if isinstance (values , Callable ) and estimates is None :
136+ raise ValueError ("Supplied a callable as `values`, but not `estimates`." )
137+
138+ d = _prepare_values (
139+ values = values ,
140+ targets = targets ,
141+ estimates = estimates ,
85142 variable_keys = variable_keys ,
86143 variable_names = variable_names ,
144+ test_quantities = test_quantities ,
145+ label = label ,
146+ default_name = default_name ,
87147 )
88- variable_names = values .variable_names
89- variable_keys = values .variable_keys
90- targets = make_variable_array (
91- targets ,
92- variable_keys = variable_keys ,
93- variable_names = variable_names ,
148+ (values , targets , variable_keys , variable_names , test_quantities , label ) = (
149+ d ["values" ],
150+ d ["targets" ],
151+ d ["variable_keys" ],
152+ d ["variable_names" ],
153+ d ["test_quantities" ],
154+ d ["label" ],
94155 )
156+
95157 # Convert samples to pd.DataFrame
96158 data_to_plot = pd .DataFrame (targets , columns = variable_names )
97159
@@ -110,11 +172,12 @@ def pairs_quantity(
110172 dim = g .axes .shape [0 ]
111173 for i in range (dim ):
112174 for j in range (dim ):
175+ # if one value for each variable is supplied, use it for the corresponding column
176+ row_values = values [:, j ] if values .ndim == 2 else values
177+
113178 if i == j :
114179 ax = g .axes [i , j ].twinx ()
115- ax .scatter (
116- targets [:, i ], values [:, i ], c = values [:, i ], cmap = cmap , s = 4 , vmin = vmin , vmax = vmax , alpha = alpha
117- )
180+ ax .scatter (targets [:, i ], values [:, i ], c = row_values , cmap = cmap , s = 4 , vmin = vmin , vmax = vmax , alpha = alpha )
118181 ax .spines ["left" ].set_visible (False )
119182 ax .spines ["top" ].set_visible (False )
120183 ax .tick_params (axis = "both" , which = "major" , labelsize = tick_fontsize )
@@ -132,7 +195,7 @@ def pairs_quantity(
132195 g .axes [i , j ].scatter (
133196 targets [:, j ],
134197 targets [:, i ],
135- c = values [:, j ] ,
198+ c = row_values ,
136199 cmap = cmap ,
137200 s = 4 ,
138201 vmin = vmin ,
0 commit comments