Skip to content

Commit 9cc88a7

Browse files
cvanelterenCopilot
andauthored
Fix: Corrects SubplotGrid indexing and sequential GeoAxes formatting (#357)
* fix grid indexing * add unittest * Update ultraplot/tests/test_geographic.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * update tests to reflect changes * fix indexing * add unittest that made docs fail * restore indexing * fix indexing * rm dead code * handle index error --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 580b685 commit 9cc88a7

File tree

4 files changed

+99
-48
lines changed

4 files changed

+99
-48
lines changed

ultraplot/axes/geo.py

Lines changed: 18 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -805,15 +805,18 @@ def _to_label_array(arg, lon=True):
805805
array[4] = True # possibly toggle geo spine labels
806806
elif not any(isinstance(_, str) for _ in array):
807807
if len(array) == 1:
808-
array.append(False) # default is to label bottom or left
808+
array.append(None)
809809
if len(array) == 2:
810-
array = [False, False, *array] if lon else [*array, False, False]
810+
array = [None, None, *array] if lon else [*array, None, None]
811811
if len(array) == 4:
812-
b = any(array) if rc["grid.geolabels"] else False
813-
array.append(b) # possibly toggle geo spine labels
812+
b = (
813+
any(a for a in array if a is not None)
814+
if rc["grid.geolabels"]
815+
else None
816+
)
817+
array.append(b)
814818
if len(array) != 5:
815819
raise ValueError(f"Invald boolean label array length {len(array)}.")
816-
array = list(map(bool, array))
817820
else:
818821
raise ValueError(f"Invalid {which}label spec: {arg}.")
819822
return array
@@ -934,9 +937,13 @@ def format(
934937
# NOTE: Cartopy 0.18 and 0.19 inline labels require any of
935938
# top, bottom, left, or right to be toggled then ignores them.
936939
# Later versions of cartopy permit both or neither labels.
937-
labels = _not_none(labels, rc.find("grid.labels", context=True))
938-
lonlabels = _not_none(lonlabels, labels)
939-
latlabels = _not_none(latlabels, labels)
940+
if lonlabels is None and latlabels is None:
941+
labels = _not_none(labels, rc.find("grid.labels", context=True))
942+
lonlabels = labels
943+
latlabels = labels
944+
else:
945+
lonlabels = _not_none(lonlabels, labels)
946+
latlabels = _not_none(latlabels, labels)
940947
# Set the ticks
941948
self._toggle_ticks(lonlabels, "x")
942949
self._toggle_ticks(latlabels, "y")
@@ -1464,8 +1471,9 @@ def _toggle_gridliner_labels(
14641471
side_labels = _CartopyAxes._get_side_labels()
14651472
togglers = (labelleft, labelright, labelbottom, labeltop)
14661473
gl = self.gridlines_major
1474+
14671475
for toggle, side in zip(togglers, side_labels):
1468-
if getattr(gl, side) != toggle:
1476+
if toggle is not None:
14691477
setattr(gl, side, toggle)
14701478
if geo is not None: # only cartopy 0.20 supported but harmless
14711479
setattr(gl, "geo_labels", geo)
@@ -1760,6 +1768,7 @@ def _update_major_gridlines(
17601768
for side, lon, lat in zip(
17611769
"labelleft labelright labelbottom labeltop geo".split(), lonarray, latarray
17621770
):
1771+
sides[side] = None
17631772
if lon and lat:
17641773
sides[side] = True
17651774
elif lon:

ultraplot/gridspec.py

Lines changed: 42 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1536,42 +1536,48 @@ def __getitem__(self, key):
15361536
>>> axs[1, 2] # the subplot in the second row, third column
15371537
>>> axs[:, 0] # a SubplotGrid containing the subplots in the first column
15381538
"""
1539-
if isinstance(key, tuple) and len(key) == 1:
1540-
key = key[0]
1541-
# List-style indexing
1542-
if isinstance(key, (Integral, slice)):
1543-
slices = isinstance(key, slice)
1544-
objs = list.__getitem__(self, key)
1545-
# Gridspec-style indexing
1546-
elif (
1547-
isinstance(key, tuple)
1548-
and len(key) == 2
1549-
and all(isinstance(ikey, (Integral, slice)) for ikey in key)
1550-
):
1551-
# WARNING: Permit no-op slicing of empty grids here
1552-
slices = any(isinstance(ikey, slice) for ikey in key)
1553-
objs = []
1554-
if self:
1555-
gs = self.gridspec
1556-
ss_key = gs._make_subplot_spec(key) # obfuscates panels
1557-
row1_key, col1_key = divmod(ss_key.num1, gs.ncols)
1558-
row2_key, col2_key = divmod(ss_key.num2, gs.ncols)
1559-
for ax in self:
1560-
ss = ax._get_topmost_axes().get_subplotspec().get_topmost_subplotspec()
1561-
row1, col1 = divmod(ss.num1, gs.ncols)
1562-
row2, col2 = divmod(ss.num2, gs.ncols)
1563-
inrow = row1_key <= row1 <= row2_key or row1_key <= row2 <= row2_key
1564-
incol = col1_key <= col1 <= col2_key or col1_key <= col2 <= col2_key
1565-
if inrow and incol:
1566-
objs.append(ax)
1567-
if not slices and len(objs) == 1: # accounts for overlapping subplots
1568-
objs = objs[0]
1569-
else:
1570-
raise IndexError(f"Invalid index {key!r}.")
1571-
if isinstance(objs, list):
1572-
return SubplotGrid(objs)
1573-
else:
1574-
return objs
1539+
# Allow 1D list-like indexing
1540+
if isinstance(key, int):
1541+
return list.__getitem__(self, key)
1542+
elif isinstance(key, slice):
1543+
return SubplotGrid(list.__getitem__(self, key))
1544+
1545+
# Allow 2D array-like indexing
1546+
# NOTE: We assume this is a 2D array of subplots, because this is
1547+
# how it is generated in the first place by ultraplot.figure().
1548+
# But it is possible to append subplots manually.
1549+
gs = self.gridspec
1550+
if gs is None:
1551+
raise IndexError(
1552+
f"{self.__class__.__name__} has no gridspec, cannot index with {key!r}."
1553+
)
1554+
# Build grid with None for empty slots
1555+
grid = np.full((gs.nrows_total, gs.ncols_total), None, dtype=object)
1556+
for ax in self:
1557+
spec = ax.get_subplotspec()
1558+
x1, x2, y1, y2 = spec._get_rows_columns(ncols=gs.ncols_total)
1559+
grid[x1 : x2 + 1, y1 : y2 + 1] = ax
1560+
1561+
new_key = []
1562+
for which, keyi in zip("hw", key):
1563+
try:
1564+
encoded_keyi = gs._encode_indices(keyi, which=which)
1565+
except:
1566+
raise IndexError(
1567+
f"Attempted to access {key=} for gridspec {grid.shape=}"
1568+
)
1569+
new_key.append(encoded_keyi)
1570+
xs, ys = new_key
1571+
objs = grid[xs, ys]
1572+
if hasattr(objs, "flat"):
1573+
objs = [obj for obj in objs.flat if obj is not None]
1574+
elif not isinstance(objs, list):
1575+
objs = [objs]
1576+
1577+
if len(objs) == 1:
1578+
return objs[0]
1579+
objs = [obj for obj in objs if obj is not None]
1580+
return SubplotGrid(objs)
15751581

15761582
def __setitem__(self, key, value):
15771583
"""

ultraplot/tests/test_geographic.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -314,8 +314,8 @@ def test_toggle_gridliner_labels():
314314
gl = ax[0].gridlines_major
315315

316316
assert gl.left_labels == False
317-
assert gl.right_labels == None # initially these are none
318-
assert gl.top_labels == None
317+
assert gl.right_labels == False
318+
assert gl.top_labels == False
319319
assert gl.bottom_labels == False
320320
ax[0]._toggle_gridliner_labels(labeltop=True)
321321
assert gl.top_labels == True
@@ -617,7 +617,7 @@ def test_cartesian_and_geo(rng):
617617
ax[0].pcolormesh(rng.random((10, 10)))
618618
ax[1].scatter(*rng.random((2, 100)))
619619
ax[0]._apply_axis_sharing()
620-
assert mocked.call_count == 1
620+
assert mocked.call_count == 2
621621
return fig
622622

623623

@@ -895,3 +895,26 @@ def test_imshow_with_and_without_transform(rng):
895895
ax[2].imshow(data, transform=uplt.axes.geo.ccrs.PlateCarree())
896896
ax.format(title=["LCC", "No transform", "PlateCarree"])
897897
return fig
898+
899+
900+
@pytest.mark.mpl_image_compare
901+
def test_grid_indexing_formatting(rng):
902+
"""
903+
Check if subplotgrid is correctly selecting
904+
the subplots based on non-shared axis formatting
905+
"""
906+
# See https://github.com/Ultraplot/UltraPlot/issues/356
907+
lon = np.arange(0, 360, 10)
908+
lat = np.arange(-60, 60 + 1, 10)
909+
data = rng.random((len(lat), len(lon)))
910+
911+
fig, axs = uplt.subplots(nrows=3, ncols=2, proj="cyl", share=0)
912+
axs.format(coast=True)
913+
914+
for ax in axs:
915+
m = ax.pcolor(lon, lat, data)
916+
ax.colorbar(m)
917+
918+
axs[-1, :].format(lonlabels=True)
919+
axs[:, 0].format(latlabels=True)
920+
return fig

ultraplot/tests/test_subplots.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,3 +314,16 @@ def test_panel_sharing_top_right(layout):
314314
# The sharing axis is not showing any ticks
315315
assert ax[0]._is_ticklabel_on(dir) == False
316316
return fig
317+
318+
319+
@pytest.mark.mpl_image_compare
320+
def test_uneven_span_subplots(rng):
321+
fig = uplt.figure(refwidth=1, refnum=5, span=False)
322+
axs = fig.subplots([[1, 1, 2], [3, 4, 2], [3, 4, 5]], hratios=[2.2, 1, 1])
323+
axs.format(xlabel="xlabel", ylabel="ylabel", suptitle="Complex SubplotGrid")
324+
axs[0].format(ec="black", fc="gray1", lw=1.4)
325+
axs[1, 1:].format(fc="blush")
326+
axs[1, :1].format(fc="sky blue")
327+
axs[-1, -1].format(fc="gray4", grid=False)
328+
axs[0].plot((rng.random((50, 10)) - 0.5).cumsum(axis=0), cycle="Grays_r", lw=2)
329+
return fig

0 commit comments

Comments
 (0)