Skip to content

Commit 896da8d

Browse files
tvwengerOriolAbril
andauthored
Improve pair_plot's reference_values compatibility, flexibility, and documentation (#2438)
* assign reference values by var_name instead of var_label * improve pair_plot reference_values flexibility * catch missing reference_values * improve bokeh test coverage * update matplotlib pair_plot tests * add plot_pair reference_values example * add reference_values documentation * update changelog * revert some bokeh pair_plot tests * revert some bokeh pair_plot tests * bokeh reference_values example * update reference_values docs --------- Co-authored-by: Oriol Abril-Pla <[email protected]>
1 parent e16cde2 commit 896da8d

File tree

9 files changed

+161
-74
lines changed

9 files changed

+161
-74
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,17 @@
33
## Unreleased
44

55
### New features
6+
- `plot_pair` now has more flexible support for `reference_values` ([2438](https://github.com/arviz-devs/arviz/pull/2438))
67
- Make `arviz.from_numpyro(..., dims=None)` automatically infer dims from the numpyro model based on its numpyro.plate structure
78

9+
810
### Maintenance and fixes
11+
- `reference_values` and `labeller` now work together in `plot_pair` ([2437](https://github.com/arviz-devs/arviz/issues/2437))
912

1013
- Add [`scipy-stubs`](https://github.com/scipy/scipy-stubs) as a development dependency ([2445](https://github.com/arviz-devs/arviz/pull/2445))
1114

1215
### Documentation
16+
- Added documentation for `reference_values` ([2438](https://github.com/arviz-devs/arviz/pull/2438))
1317

1418
## v0.21.0 (2025 Mar 06)
1519

arviz/plots/backends/bokeh/pairplot.py

Lines changed: 18 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ def plot_pair(
3737
diverging_mask,
3838
divergences_kwargs,
3939
flat_var_names,
40+
flat_ref_slices,
41+
flat_var_labels,
4042
backend_kwargs,
4143
marginal_kwargs,
4244
show,
@@ -72,50 +74,12 @@ def plot_pair(
7274
kde_kwargs["contour_kwargs"].setdefault("line_alpha", 1)
7375

7476
if reference_values:
75-
reference_values_copy = {}
76-
label = []
77-
for variable in list(reference_values.keys()):
78-
if " " in variable:
79-
variable_copy = variable.replace(" ", "\n", 1)
80-
else:
81-
variable_copy = variable
82-
83-
label.append(variable_copy)
84-
reference_values_copy[variable_copy] = reference_values[variable]
85-
86-
difference = set(flat_var_names).difference(set(label))
87-
88-
if difference:
89-
warn = [diff.replace("\n", " ", 1) for diff in difference]
90-
warnings.warn(
91-
"Argument reference_values does not include reference value for: {}".format(
92-
", ".join(warn)
93-
),
94-
UserWarning,
95-
)
96-
97-
if reference_values:
98-
reference_values_copy = {}
99-
label = []
100-
for variable in list(reference_values.keys()):
101-
if " " in variable:
102-
variable_copy = variable.replace(" ", "\n", 1)
103-
else:
104-
variable_copy = variable
105-
106-
label.append(variable_copy)
107-
reference_values_copy[variable_copy] = reference_values[variable]
108-
109-
difference = set(flat_var_names).difference(set(label))
110-
111-
for dif in difference:
112-
reference_values_copy[dif] = None
77+
difference = set(flat_var_names).difference(set(reference_values.keys()))
11378

11479
if difference:
115-
warn = [dif.replace("\n", " ", 1) for dif in difference]
11680
warnings.warn(
11781
"Argument reference_values does not include reference value for: {}".format(
118-
", ".join(warn)
82+
", ".join(difference)
11983
),
12084
UserWarning,
12185
)
@@ -262,8 +226,8 @@ def get_width_and_height(jointplot, rotate):
262226
**marginal_kwargs,
263227
)
264228

265-
ax[j, i].xaxis.axis_label = flat_var_names[i]
266-
ax[j, i].yaxis.axis_label = flat_var_names[j + marginals_offset]
229+
ax[j, i].xaxis.axis_label = flat_var_labels[i]
230+
ax[j, i].yaxis.axis_label = flat_var_labels[j + marginals_offset]
267231

268232
elif j + marginals_offset > i:
269233
if "scatter" in kind:
@@ -346,12 +310,18 @@ def get_width_and_height(jointplot, rotate):
346310
ax[-1, -1].add_layout(ax_pe_hline)
347311

348312
if reference_values:
349-
x = reference_values_copy[flat_var_names[j + marginals_offset]]
350-
y = reference_values_copy[flat_var_names[i]]
351-
if x and y:
352-
ax[j, i].scatter(y, x, **reference_values_kwargs)
353-
ax[j, i].xaxis.axis_label = flat_var_names[i]
354-
ax[j, i].yaxis.axis_label = flat_var_names[j + marginals_offset]
313+
x_name = flat_var_names[j + marginals_offset]
314+
y_name = flat_var_names[i]
315+
if (x_name not in difference) and (y_name not in difference):
316+
ax[j, i].scatter(
317+
np.array(reference_values[y_name])[flat_ref_slices[i]],
318+
np.array(reference_values[x_name])[
319+
flat_ref_slices[j + marginals_offset]
320+
],
321+
**reference_values_kwargs,
322+
)
323+
ax[j, i].xaxis.axis_label = flat_var_labels[i]
324+
ax[j, i].yaxis.axis_label = flat_var_labels[j + marginals_offset]
355325

356326
show_layout(ax, show)
357327

arviz/plots/backends/matplotlib/pairplot.py

Lines changed: 14 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ def plot_pair(
3030
diverging_mask,
3131
divergences_kwargs,
3232
flat_var_names,
33+
flat_ref_slices,
34+
flat_var_labels,
3335
backend_kwargs,
3436
marginal_kwargs,
3537
show,
@@ -77,24 +79,12 @@ def plot_pair(
7779
kde_kwargs["contour_kwargs"].setdefault("colors", "k")
7880

7981
if reference_values:
80-
reference_values_copy = {}
81-
label = []
82-
for variable in list(reference_values.keys()):
83-
if " " in variable:
84-
variable_copy = variable.replace(" ", "\n", 1)
85-
else:
86-
variable_copy = variable
87-
88-
label.append(variable_copy)
89-
reference_values_copy[variable_copy] = reference_values[variable]
90-
91-
difference = set(flat_var_names).difference(set(label))
82+
difference = set(flat_var_names).difference(set(reference_values.keys()))
9283

9384
if difference:
94-
warn = [diff.replace("\n", " ", 1) for diff in difference]
9585
warnings.warn(
9686
"Argument reference_values does not include reference value for: {}".format(
97-
", ".join(warn)
87+
", ".join(difference)
9888
),
9989
UserWarning,
10090
)
@@ -211,12 +201,12 @@ def plot_pair(
211201

212202
if reference_values:
213203
ax.plot(
214-
reference_values_copy[flat_var_names[0]],
215-
reference_values_copy[flat_var_names[1]],
204+
np.array(reference_values[flat_var_names[0]])[flat_ref_slices[0]],
205+
np.array(reference_values[flat_var_names[1]])[flat_ref_slices[1]],
216206
**reference_values_kwargs,
217207
)
218-
ax.set_xlabel(f"{flat_var_names[0]}", fontsize=ax_labelsize, wrap=True)
219-
ax.set_ylabel(f"{flat_var_names[1]}", fontsize=ax_labelsize, wrap=True)
208+
ax.set_xlabel(f"{flat_var_labels[0]}", fontsize=ax_labelsize, wrap=True)
209+
ax.set_ylabel(f"{flat_var_labels[1]}", fontsize=ax_labelsize, wrap=True)
220210
ax.tick_params(labelsize=xt_labelsize)
221211

222212
else:
@@ -336,20 +326,22 @@ def plot_pair(
336326
y_name = flat_var_names[j + not_marginals]
337327
if (x_name not in difference) and (y_name not in difference):
338328
ax[j, i].plot(
339-
reference_values_copy[x_name],
340-
reference_values_copy[y_name],
329+
np.array(reference_values[x_name])[flat_ref_slices[i]],
330+
np.array(reference_values[y_name])[
331+
flat_ref_slices[j + not_marginals]
332+
],
341333
**reference_values_kwargs,
342334
)
343335

344336
if j != vars_to_plot - 1:
345337
plt.setp(ax[j, i].get_xticklabels(), visible=False)
346338
else:
347-
ax[j, i].set_xlabel(f"{flat_var_names[i]}", fontsize=ax_labelsize, wrap=True)
339+
ax[j, i].set_xlabel(f"{flat_var_labels[i]}", fontsize=ax_labelsize, wrap=True)
348340
if i != 0:
349341
plt.setp(ax[j, i].get_yticklabels(), visible=False)
350342
else:
351343
ax[j, i].set_ylabel(
352-
f"{flat_var_names[j + not_marginals]}",
344+
f"{flat_var_labels[j + not_marginals]}",
353345
fontsize=ax_labelsize,
354346
wrap=True,
355347
)

arviz/plots/pairplot.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -196,9 +196,14 @@ def plot_pair(
196196
get_coords(dataset, coords), var_names=var_names, skip_dims=combine_dims, combined=True
197197
)
198198
)
199-
flat_var_names = [
200-
labeller.make_label_vert(var_name, sel, isel) for var_name, sel, isel, _ in plotters
201-
]
199+
flat_var_names = []
200+
flat_ref_slices = []
201+
flat_var_labels = []
202+
for var_name, sel, isel, _ in plotters:
203+
dims = [dim for dim in dataset[var_name].dims if dim not in ["chain", "draw"]]
204+
flat_var_names.append(var_name)
205+
flat_ref_slices.append(tuple(isel[dim] if dim in isel else slice(None) for dim in dims))
206+
flat_var_labels.append(labeller.make_label_vert(var_name, sel, isel))
202207

203208
divergent_data = None
204209
diverging_mask = None
@@ -253,6 +258,8 @@ def plot_pair(
253258
diverging_mask=diverging_mask,
254259
divergences_kwargs=divergences_kwargs,
255260
flat_var_names=flat_var_names,
261+
flat_ref_slices=flat_ref_slices,
262+
flat_var_labels=flat_var_labels,
256263
backend_kwargs=backend_kwargs,
257264
marginal_kwargs=marginal_kwargs,
258265
show=show,

arviz/tests/base_tests/test_plots_bokeh.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from scipy.stats import norm # pylint: disable=wrong-import-position
99

1010
from ...data import from_dict, load_arviz_data # pylint: disable=wrong-import-position
11+
from ...labels import MapLabeller # pylint: disable=wrong-import-position
1112
from ...plots import ( # pylint: disable=wrong-import-position
1213
plot_autocorr,
1314
plot_bpv,
@@ -773,7 +774,6 @@ def test_plot_mcse_no_divergences(models):
773774
{"divergences": True, "var_names": ["theta", "mu"]},
774775
{"kind": "kde", "var_names": ["theta"]},
775776
{"kind": "hexbin", "var_names": ["theta"]},
776-
{"kind": "hexbin", "var_names": ["theta"]},
777777
{
778778
"kind": "hexbin",
779779
"var_names": ["theta"],
@@ -785,6 +785,21 @@ def test_plot_mcse_no_divergences(models):
785785
"reference_values": {"mu": 0, "tau": 0},
786786
"reference_values_kwargs": {"line_color": "blue"},
787787
},
788+
{
789+
"var_names": ["mu", "tau"],
790+
"reference_values": {"mu": 0, "tau": 0},
791+
"labeller": MapLabeller({"mu": r"$\mu$", "theta": r"$\theta"}),
792+
},
793+
{
794+
"var_names": ["theta"],
795+
"reference_values": {"theta": [0.0] * 8},
796+
"labeller": MapLabeller({"theta": r"$\theta$"}),
797+
},
798+
{
799+
"var_names": ["theta"],
800+
"reference_values": {"theta": np.zeros(8)},
801+
"labeller": MapLabeller({"theta": r"$\theta$"}),
802+
},
788803
],
789804
)
790805
def test_plot_pair(models, kwargs):

arviz/tests/base_tests/test_plots_matplotlib.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from scipy.stats import gaussian_kde, norm
1515

1616
from ...data import from_dict, load_arviz_data
17+
from ...labels import MapLabeller
1718
from ...plots import (
1819
plot_autocorr,
1920
plot_bf,
@@ -599,6 +600,21 @@ def test_plot_kde_inference_data(models):
599600
"reference_values": {"mu": 0, "tau": 0},
600601
"reference_values_kwargs": {"c": "C0", "marker": "*"},
601602
},
603+
{
604+
"var_names": ["mu", "tau"],
605+
"reference_values": {"mu": 0, "tau": 0},
606+
"labeller": MapLabeller({"mu": r"$\mu$", "theta": r"$\theta"}),
607+
},
608+
{
609+
"var_names": ["theta"],
610+
"reference_values": {"theta": [0.0] * 8},
611+
"labeller": MapLabeller({"theta": r"$\theta$"}),
612+
},
613+
{
614+
"var_names": ["theta"],
615+
"reference_values": {"theta": np.zeros(8)},
616+
"labeller": MapLabeller({"theta": r"$\theta$"}),
617+
},
602618
],
603619
)
604620
def test_plot_pair(models, kwargs):

doc/source/user_guide/plots_arguments_guide.md

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,3 +390,28 @@ These are kwargs specific to the backend being used, passed to `matplotlib.pyplo
390390
## `show`
391391

392392
Call backend show function.
393+
394+
(common_reference_values)=
395+
## `reference_values`
396+
`plot_pair` accepts `reference_values` to highlight specific values on the probability distributions. The keys of `reference_values` are the associated variable names in `var_names`. The values are the reference values, which must have the same shape as the coordinates selected for plotting since it is indexed as such. For example, here `theta` must have shape `(2,)` since that is the shape of the selected coordinates on `theta`.
397+
398+
```{code-cell} ipython3
399+
coords = {"school": ["Choate", "Deerfield"]}
400+
reference_values = {
401+
"mu": 0.0,
402+
"theta": np.zeros(2),
403+
}
404+
az.plot_pair(centered_eight, var_names=["mu", "theta"], coords=coords, reference_values=reference_values);
405+
```
406+
407+
When used with `combine_dims`, each reference value along the combined dimension is plotted on the same axis.
408+
```{code-cell} ipython3
409+
coords = {"school": ["Choate", "Deerfield"]}
410+
reference_values = {
411+
"theta": [-5.0, 5.0],
412+
"theta_t": [-2.0, 2.0],
413+
}
414+
az.plot_pair(non_centered_eight, var_names=["theta", "theta_t"], coords=coords, reference_values=reference_values, combine_dims={"school"});
415+
```
416+
417+
The values of the `reference_values` dictionary can be scalars (e.g., `0`) or zero-dimensional `numpy` arrays (e.g., `np.array(0)`) for scalar variables, or anything that can be cast to `np.array` (e.g., `[0.0, 0.0]` or `np.array([0.0, 0.0])`) for multi-dimensional variables.
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
"""
2+
Pairplot with Reference Values
3+
==============================
4+
"""
5+
6+
import arviz as az
7+
import numpy as np
8+
9+
data = az.load_arviz_data("centered_eight")
10+
11+
coords = {"school": ["Choate", "Deerfield"]}
12+
reference_values = {
13+
"mu": 0.0,
14+
"theta": np.zeros(2),
15+
}
16+
17+
ax = az.plot_pair(
18+
data,
19+
var_names=["mu", "theta"],
20+
kind=["scatter", "kde"],
21+
kde_kwargs={"fill_last": False},
22+
coords=coords,
23+
reference_values=reference_values,
24+
figsize=(11.5, 5),
25+
backend="bokeh",
26+
)
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
"""
2+
Pairplot with Reference Values
3+
==============================
4+
_gallery_category: Distributions
5+
"""
6+
7+
import matplotlib.pyplot as plt
8+
9+
import arviz as az
10+
import numpy as np
11+
12+
az.style.use("arviz-doc")
13+
14+
data = az.load_arviz_data("centered_eight")
15+
16+
coords = {"school": ["Choate", "Deerfield"]}
17+
reference_values = {
18+
"mu": 0.0,
19+
"theta": np.zeros(2),
20+
}
21+
22+
ax = az.plot_pair(
23+
data,
24+
var_names=["mu", "theta"],
25+
kind=["scatter", "kde"],
26+
kde_kwargs={"fill_last": False},
27+
coords=coords,
28+
reference_values=reference_values,
29+
figsize=(11.5, 5),
30+
)
31+
32+
plt.show()

0 commit comments

Comments
 (0)