Skip to content

Commit f79e625

Browse files
Add plot.max_subplots check before calling create_plotting_grid (#425)
* Format code with black * Disable plot.max_subplots in PlotMatrix tests using rc_context * Limit rc_context scope to PlotMatrix initialization * undo spurious formatter change --------- Co-authored-by: Oriol Abril-Pla <oriol.abril.pla@gmail.com>
1 parent 23d4323 commit f79e625

File tree

3 files changed

+60
-23
lines changed

3 files changed

+60
-23
lines changed

src/arviz_plots/plot_collection.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -763,6 +763,14 @@ def wrap(
763763

764764
n_plots, plots_per_var = process_facet_dims(data, cols)
765765

766+
max_plots = rcParams["plot.max_subplots"]
767+
if max_plots is not None and n_plots > max_plots:
768+
raise ValueError(
769+
f"Requested {n_plots} subplots, which exceeds "
770+
f"rcParams['plot.max_subplots']={max_plots}. "
771+
"Reduce the number of plots or increase this limit."
772+
)
773+
766774
if col_wrap is None:
767775
col_wrap = int(np.ceil(np.sqrt(n_plots)))
768776
else:
@@ -921,6 +929,13 @@ def grid(
921929
n_rows, rows_per_var = process_facet_dims(data, rows)
922930

923931
n_plots = n_cols * n_rows
932+
max_plots = rcParams["plot.max_subplots"]
933+
if max_plots is not None and n_plots > max_plots:
934+
raise ValueError(
935+
f"Requested {n_plots} subplots, which exceeds "
936+
f"rcParams['plot.max_subplots']={max_plots}. "
937+
"Reduce the number of plots or increase this limit."
938+
)
924939
plot_bknd = import_module(f".backend.{backend}", package="arviz_plots")
925940
fig, ax_ary = plot_bknd.create_plotting_grid(
926941
n_plots, n_rows, n_cols, squeeze=False, **figure_kwargs

src/arviz_plots/plot_matrix.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,13 @@ def _generate_viz_dt(self, **figure_kwargs):
136136
)
137137
n_pairs = len(pairs)
138138
n_plots = n_pairs**2
139+
max_plots = rcParams["plot.max_subplots"]
140+
if max_plots is not None and n_plots > max_plots:
141+
raise ValueError(
142+
f"Requested {n_plots} subplots, which exceeds "
143+
f"rcParams['plot.max_subplots']={max_plots}. "
144+
"Reduce the number of plots or increase this limit."
145+
)
139146
plot_bknd = import_module(f".backend.{self.backend}", package="arviz_plots")
140147
fig, ax_ary = plot_bknd.create_plotting_grid(
141148
n_plots, n_pairs, n_pairs, squeeze=False, **figure_kwargs

tests/test_plot_matrix.py

Lines changed: 38 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
import pytest
55
import xarray as xr
6-
from arviz_base import dict_to_dataset
6+
from arviz_base import dict_to_dataset, rc_context
77

88
from arviz_plots import PlotMatrix
99
from arviz_plots.plot_matrix import subset_matrix_da
@@ -55,7 +55,8 @@ def test_subset_matrix_da_offdiag(matrix_da, subset_x, subset_y):
5555

5656

5757
def test_plot_matrix_init(dataset):
58-
pc = PlotMatrix(dataset, ["__variable__", "hierarchy", "group"], backend="none")
58+
with rc_context({"plot.max_subplots": None}):
59+
pc = PlotMatrix(dataset, ["__variable__", "hierarchy", "group"], backend="none")
5960
assert "plot" in pc.viz.data_vars
6061
coord_names = ("var_name_x", "var_name_y", "hierarchy_x", "hierarchy_y", "group_x", "group_y")
6162
missing_coord_names = [name for name in coord_names if name not in pc.viz["plot"].coords]
@@ -64,9 +65,13 @@ def test_plot_matrix_init(dataset):
6465

6566

6667
def test_plot_matrix_aes(dataset):
67-
pc = PlotMatrix(
68-
dataset, ["__variable__", "hierarchy", "group"], backend="none", aes={"color": ["chain"]}
69-
)
68+
with rc_context({"plot.max_subplots": None}):
69+
pc = PlotMatrix(
70+
dataset,
71+
["__variable__", "hierarchy", "group"],
72+
backend="none",
73+
aes={"color": ["chain"]},
74+
)
7075
assert "/color" in pc.aes.groups
7176
assert "mapping" in pc.aes["color"].data_vars
7277
assert "neutral_element" not in pc.aes["color"].data_vars
@@ -87,9 +92,13 @@ def map_auxiliar_couple(da_x, da_y, target, target_list, kwarg_list, **kwargs):
8792

8893

8994
def test_plot_matrix_map(dataset):
90-
pc = PlotMatrix(
91-
dataset, ["__variable__", "hierarchy", "group"], backend="none", aes={"color": ["chain"]}
92-
)
95+
with rc_context({"plot.max_subplots": None}):
96+
pc = PlotMatrix(
97+
dataset,
98+
["__variable__", "hierarchy", "group"],
99+
backend="none",
100+
aes={"color": ["chain"]},
101+
)
93102
target_list = []
94103
kwarg_list = []
95104
pc.map(
@@ -113,12 +122,13 @@ def test_plot_matrix_map(dataset):
113122

114123

115124
def test_plot_matrix_map_scalar_coord(dataset):
116-
pc = PlotMatrix(
117-
dataset.isel(hierarchy=[0]),
118-
["__variable__", "hierarchy", "group"],
119-
backend="none",
120-
aes={"color": ["chain"]},
121-
)
125+
with rc_context({"plot.max_subplots": None}):
126+
pc = PlotMatrix(
127+
dataset.isel(hierarchy=[0]),
128+
["__variable__", "hierarchy", "group"],
129+
backend="none",
130+
aes={"color": ["chain"]},
131+
)
122132
target_list = []
123133
kwarg_list = []
124134
pc.map(
@@ -143,9 +153,13 @@ def test_plot_matrix_map_scalar_coord(dataset):
143153

144154
@pytest.mark.parametrize("triangle", ("both", "lower", "upper"))
145155
def test_plot_matrix_map_triangle(dataset, triangle):
146-
pc = PlotMatrix(
147-
dataset, ["__variable__", "hierarchy", "group"], backend="none", aes={"color": ["chain"]}
148-
)
156+
with rc_context({"plot.max_subplots": None}):
157+
pc = PlotMatrix(
158+
dataset,
159+
["__variable__", "hierarchy", "group"],
160+
backend="none",
161+
aes={"color": ["chain"]},
162+
)
149163
target_list = []
150164
kwarg_list = []
151165
pc.map_triangle(
@@ -179,12 +193,13 @@ def test_plot_matrix_map_triangle(dataset, triangle):
179193

180194
@pytest.mark.parametrize("triangle", ("both", "lower", "upper"))
181195
def test_plot_matrix_map_triangle_scalar_coord(dataset, triangle):
182-
pc = PlotMatrix(
183-
dataset.isel(hierarchy=[0]),
184-
["__variable__", "hierarchy", "group"],
185-
backend="none",
186-
aes={"color": ["chain"]},
187-
)
196+
with rc_context({"plot.max_subplots": None}):
197+
pc = PlotMatrix(
198+
dataset.isel(hierarchy=[0]),
199+
["__variable__", "hierarchy", "group"],
200+
backend="none",
201+
aes={"color": ["chain"]},
202+
)
188203
target_list = []
189204
kwarg_list = []
190205
pc.map_triangle(

0 commit comments

Comments
 (0)