Skip to content

Commit 4743beb

Browse files
Coords kwarg added to plot_autocorr function (#2404)
* Coordinates keyword added to plot_autocorr along with examples and test cases * Test file modified to pass pylint test * undo example modification --------- Co-authored-by: Oriol (ProDesk) <[email protected]>
1 parent 1d8c010 commit 4743beb

File tree

2 files changed

+29
-15
lines changed

2 files changed

+29
-15
lines changed

arviz/plots/autocorrplot.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from ..labels import BaseLabeller
55
from ..sel_utils import xarray_var_iter
66
from ..rcparams import rcParams
7-
from ..utils import _var_names
7+
from ..utils import _var_names, get_coords
88
from .plot_utils import default_grid, filter_plotters_list, get_plotting_function
99

1010

@@ -14,6 +14,7 @@ def plot_autocorr(
1414
filter_vars=None,
1515
max_lag=None,
1616
combined=False,
17+
coords=None,
1718
grid=None,
1819
figsize=None,
1920
textsize=None,
@@ -42,6 +43,8 @@ def plot_autocorr(
4243
interpret `var_names` as substrings of the real variables names. If "regex",
4344
interpret `var_names` as regular expressions on the real variables names. See
4445
:ref:`this section <common_filter_vars>` for usage examples.
46+
coords: mapping, optional
47+
Coordinates of var_names to be plotted. Passed to :meth:`xarray.Dataset.sel`
4548
max_lag : int, optional
4649
Maximum lag to calculate autocorrelation. By Default, the plot displays the
4750
first 100 lag or the total number of draws, whichever is smaller.
@@ -124,11 +127,18 @@ def plot_autocorr(
124127
if max_lag is None:
125128
max_lag = min(100, data["draw"].shape[0])
126129

130+
if coords is None:
131+
coords = {}
132+
127133
if labeller is None:
128134
labeller = BaseLabeller()
129135

130136
plotters = filter_plotters_list(
131-
list(xarray_var_iter(data, var_names, combined, dim_order=["chain", "draw"])),
137+
list(
138+
xarray_var_iter(
139+
get_coords(data, coords), var_names, combined, dim_order=["chain", "draw"]
140+
)
141+
),
132142
"plot_autocorr",
133143
)
134144
rows, cols = default_grid(len(plotters), grid=grid)

arviz/tests/base_tests/test_plots_matplotlib.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2151,19 +2151,6 @@ def test_plot_lm(data, kind, use_y_model):
21512151
else:
21522152
idata = generate_lm_2d_data()
21532153

2154-
# test_cases = [
2155-
# # Single dimensional cases
2156-
# (data_1d, None, "lines", True, 50), # y_model with lines, default samples
2157-
# (data_1d, None, "hdi", True, None), # y_model with hdi, no samples needed
2158-
# (data_1d, None, "lines", False, 50), # without y_model, lines
2159-
# (data_1d, None, "hdi", False, None), # without y_model, hdi
2160-
# # Multi-dimensional cases with plot_dim
2161-
# (data_2d, "dim1", "lines", True, 20), # y_model with lines, fewer samples
2162-
# (data_2d, "dim1", "hdi", True, None), # y_model with hdi
2163-
# (data_2d, "dim1", "lines", False, 50), # without y_model, lines
2164-
# (data_2d, "dim1", "hdi", False, None), # without y_model, hdi
2165-
# ]
2166-
21672154
kwargs = {"idata": idata, "y": "y", "kind_model": kind}
21682155
if data == "2d":
21692156
kwargs["plot_dim"] = "dim1"
@@ -2174,3 +2161,20 @@ def test_plot_lm(data, kind, use_y_model):
21742161

21752162
ax = plot_lm(**kwargs)
21762163
assert ax is not None
2164+
2165+
2166+
@pytest.mark.parametrize(
2167+
"coords, expected_vars",
2168+
[
2169+
({"school": ["Choate"]}, ["theta"]),
2170+
({"school": ["Lawrenceville"]}, ["theta"]),
2171+
({}, ["theta"]),
2172+
],
2173+
)
2174+
def test_plot_autocorr_coords(coords, expected_vars):
2175+
"""Test plot_autocorr with coords kwarg."""
2176+
idata = load_arviz_data("centered_eight")
2177+
2178+
axes = plot_autocorr(idata, var_names=expected_vars, coords=coords, show=False)
2179+
2180+
assert axes is not None

0 commit comments

Comments
 (0)