Skip to content

Commit 48b7c98

Browse files
authored
Merge pull request matplotlib#19964 from tacaswell/fix_mosaic_order
FIX: add subplot_mosaic axes in the order the user gave them to us
2 parents e6b8004 + 3dbd8ca commit 48b7c98

File tree

2 files changed

+83
-24
lines changed

2 files changed

+83
-24
lines changed

lib/matplotlib/figure.py

Lines changed: 61 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1789,11 +1789,12 @@ def _identify_keys_and_nested(layout):
17891789
17901790
Returns
17911791
-------
1792-
unique_ids : set
1792+
unique_ids : tuple
17931793
The unique non-sub layout entries in this layout
17941794
nested : dict[tuple[int, int]], 2D object array
17951795
"""
1796-
unique_ids = set()
1796+
# make sure we preserve the user supplied order
1797+
unique_ids = cbook._OrderedSet()
17971798
nested = {}
17981799
for j, row in enumerate(layout):
17991800
for k, v in enumerate(row):
@@ -1804,7 +1805,7 @@ def _identify_keys_and_nested(layout):
18041805
else:
18051806
unique_ids.add(v)
18061807

1807-
return unique_ids, nested
1808+
return tuple(unique_ids), nested
18081809

18091810
def _do_layout(gs, layout, unique_ids, nested):
18101811
"""
@@ -1815,7 +1816,7 @@ def _do_layout(gs, layout, unique_ids, nested):
18151816
gs : GridSpec
18161817
layout : 2D object array
18171818
The input converted to a 2D numpy array for this level.
1818-
unique_ids : set
1819+
unique_ids : tuple
18191820
The identified scalar labels at this level of nesting.
18201821
nested : dict[tuple[int, int]], 2D object array
18211822
The identified nested layouts, if any.
@@ -1828,38 +1829,74 @@ def _do_layout(gs, layout, unique_ids, nested):
18281829
rows, cols = layout.shape
18291830
output = dict()
18301831

1831-
# create the Axes at this level of nesting
1832+
# we need to merge together the Axes at this level and the axes
1833+
# in the (recursively) nested sub-layouts so that we can add
1834+
# them to the figure in the "natural" order if you were to
1835+
# ravel in c-order all of the Axes that will be created
1836+
#
1837+
# This will stash the upper left index of each object (axes or
1838+
# nested layout) at this level
1839+
this_level = dict()
1840+
1841+
# go through the unique keys,
18321842
for name in unique_ids:
1843+
# sort out where each axes starts/ends
18331844
indx = np.argwhere(layout == name)
18341845
start_row, start_col = np.min(indx, axis=0)
18351846
end_row, end_col = np.max(indx, axis=0) + 1
1847+
# and construct the slice object
18361848
slc = (slice(start_row, end_row), slice(start_col, end_col))
1837-
1849+
# some light error checking
18381850
if (layout[slc] != name).any():
18391851
raise ValueError(
18401852
f"While trying to layout\n{layout!r}\n"
18411853
f"we found that the label {name!r} specifies a "
18421854
"non-rectangular or non-contiguous area.")
1855+
# and stash this slice for later
1856+
this_level[(start_row, start_col)] = (name, slc, 'axes')
18431857

1844-
ax = self.add_subplot(
1845-
gs[slc], **{'label': str(name), **subplot_kw}
1846-
)
1847-
output[name] = ax
1848-
1849-
# do any sub-layouts
1858+
# do the same thing for the nested layouts (simpler because these
1859+
# can not be spans yet!)
18501860
for (j, k), nested_layout in nested.items():
1851-
rows, cols = nested_layout.shape
1852-
nested_output = _do_layout(
1853-
gs[j, k].subgridspec(rows, cols, **gridspec_kw),
1854-
nested_layout,
1855-
*_identify_keys_and_nested(nested_layout)
1856-
)
1857-
overlap = set(output) & set(nested_output)
1858-
if overlap:
1859-
raise ValueError(f"There are duplicate keys {overlap} "
1860-
f"between the outer layout\n{layout!r}\n"
1861-
f"and the nested layout\n{nested_layout}")
1862-
output.update(nested_output)
1861+
this_level[(j, k)] = (None, nested_layout, 'nested')
1862+
1863+
# now go through the things in this level and add them
1864+
# in order left-to-right top-to-bottom
1865+
for key in sorted(this_level):
1866+
name, arg, method = this_level[key]
1867+
# we are doing some hokey function dispatch here based
1868+
# on the 'method' string stashed above to sort out if this
1869+
# element is an axes or a nested layout.
1870+
if method == 'axes':
1871+
slc = arg
1872+
# add a single axes
1873+
if name in output:
1874+
raise ValueError(f"There are duplicate keys {name} "
1875+
f"in the layout\n{layout!r}")
1876+
ax = self.add_subplot(
1877+
gs[slc], **{'label': str(name), **subplot_kw}
1878+
)
1879+
output[name] = ax
1880+
elif method == 'nested':
1881+
nested_layout = arg
1882+
j, k = key
1883+
# recursively add the nested layout
1884+
rows, cols = nested_layout.shape
1885+
nested_output = _do_layout(
1886+
gs[j, k].subgridspec(rows, cols, **gridspec_kw),
1887+
nested_layout,
1888+
*_identify_keys_and_nested(nested_layout)
1889+
)
1890+
overlap = set(output) & set(nested_output)
1891+
if overlap:
1892+
raise ValueError(
1893+
f"There are duplicate keys {overlap} "
1894+
f"between the outer layout\n{layout!r}\n"
1895+
f"and the nested layout\n{nested_layout}"
1896+
)
1897+
output.update(nested_output)
1898+
else:
1899+
raise RuntimeError("This should never happen")
18631900
return output
18641901

18651902
layout = _make_array(layout)

lib/matplotlib/tests/test_figure.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -861,6 +861,28 @@ def test_hashable_keys(self, fig_test, fig_ref):
861861
fig_test.subplot_mosaic([[object(), object()]])
862862
fig_ref.subplot_mosaic([["A", "B"]])
863863

864+
@pytest.mark.parametrize('str_pattern',
865+
['abc', 'cab', 'bca', 'cba', 'acb', 'bac'])
866+
def test_user_order(self, str_pattern):
867+
fig = plt.figure()
868+
ax_dict = fig.subplot_mosaic(str_pattern)
869+
assert list(str_pattern) == list(ax_dict)
870+
assert list(fig.axes) == list(ax_dict.values())
871+
872+
def test_nested_user_order(self):
873+
layout = [
874+
["A", [["B", "C"],
875+
["D", "E"]]],
876+
["F", "G"],
877+
[".", [["H", [["I"],
878+
["."]]]]]
879+
]
880+
881+
fig = plt.figure()
882+
ax_dict = fig.subplot_mosaic(layout)
883+
assert list(ax_dict) == list("ABCDEFGHI")
884+
assert list(fig.axes) == list(ax_dict.values())
885+
864886

865887
def test_reused_gridspec():
866888
"""Test that these all use the same gridspec"""

0 commit comments

Comments
 (0)