Skip to content

Commit 769c1c5

Browse files
committed
Fix SubplotGrid indexing and allow legend placement decoupling
1 parent 0cd86ba commit 769c1c5

File tree

2 files changed

+43
-2
lines changed

2 files changed

+43
-2
lines changed

ultraplot/figure.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1417,6 +1417,18 @@ def _add_axes_panel(
14171417
if span_override is not None:
14181418
kw["span_override"] = span_override
14191419

1420+
# Check for position override (row for horizontal panels, col for vertical panels)
1421+
pos_override = None
1422+
if side in ("left", "right"):
1423+
if _not_none(cols, col) is not None:
1424+
pos_override = _not_none(cols, col)
1425+
else:
1426+
if _not_none(rows, row) is not None:
1427+
pos_override = _not_none(rows, row)
1428+
1429+
if pos_override is not None:
1430+
kw["pos_override"] = pos_override
1431+
14201432
ss, share = gs._insert_panel_slot(side, ax, **kw)
14211433
# Guard: GeoAxes with non-rectilinear projections cannot share with panels
14221434
if isinstance(ax, paxes.GeoAxes) and not ax._is_rectilinear():

ultraplot/gridspec.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,12 @@ def _encode_indices(self, *args, which=None, panel=False):
425425
nums = []
426426
idxs = self._get_indices(which=which, panel=panel)
427427
for arg in args:
428+
if isinstance(arg, (list, np.ndarray)):
429+
try:
430+
nums.append([idxs[int(i)] for i in arg])
431+
except (IndexError, TypeError):
432+
raise ValueError(f"Invalid gridspec index {arg}.")
433+
continue
428434
try:
429435
nums.append(idxs[arg])
430436
except (IndexError, TypeError):
@@ -595,6 +601,7 @@ def _parse_panel_arg_with_span(
595601
side: str,
596602
ax: "paxes.Axes",
597603
span_override: Optional[Union[int, Tuple[int, int]]],
604+
pos_override: Optional[Union[int, Tuple[int, int]]] = None,
598605
) -> Tuple[str, int, slice]:
599606
"""
600607
Parse panel arg with span override. Uses ax for position, span for extent.
@@ -607,6 +614,8 @@ def _parse_panel_arg_with_span(
607614
The axes to position the panel relative to
608615
span_override : int or tuple
609616
The span extent (1-indexed like subplot numbers)
617+
pos_override : int or tuple, optional
618+
The row or column index (1-indexed like subplot numbers)
610619
611620
Returns
612621
-------
@@ -621,6 +630,20 @@ def _parse_panel_arg_with_span(
621630
ss = ax.get_subplotspec().get_topmost_subplotspec()
622631
row1, row2, col1, col2 = ss._get_rows_columns()
623632

633+
# Override axes position if requested
634+
if pos_override is not None:
635+
if isinstance(pos_override, Integral):
636+
pos1, pos2 = pos_override - 1, pos_override - 1
637+
else:
638+
pos_override = np.atleast_1d(pos_override)
639+
pos1, pos2 = pos_override[0] - 1, pos_override[-1] - 1
640+
641+
# NOTE: We only need the relevant coordinate (row or col)
642+
if side in ("left", "right"):
643+
col1, col2 = pos1, pos2
644+
else:
645+
row1, row2 = pos1, pos2
646+
624647
# Determine slot and index based on side
625648
slot = side[0]
626649
offset = len(ax._panel_dict[side]) + 1
@@ -663,6 +686,7 @@ def _insert_panel_slot(
663686
pad: Optional[Union[float, str]] = None,
664687
filled: bool = False,
665688
span_override: Optional[Union[int, Tuple[int, int]]] = None,
689+
pos_override: Optional[Union[int, Tuple[int, int]]] = None,
666690
):
667691
"""
668692
Insert a panel slot into the existing gridspec. The `side` is the panel side
@@ -676,7 +700,9 @@ def _insert_panel_slot(
676700
raise ValueError(f"Invalid side {side}.")
677701
# Use span override if provided
678702
if span_override is not None:
679-
slot, idx, span = self._parse_panel_arg_with_span(side, arg, span_override)
703+
slot, idx, span = self._parse_panel_arg_with_span(
704+
side, arg, span_override, pos_override=pos_override
705+
)
680706
else:
681707
slot, idx, span = self._parse_panel_arg(side, arg)
682708
pad = units(pad, "em", "in")
@@ -1612,10 +1638,13 @@ def __getitem__(self, key):
16121638
>>> axs[:, 0] # a SubplotGrid containing the subplots in the first column
16131639
"""
16141640
# Allow 1D list-like indexing
1615-
if isinstance(key, int):
1641+
if isinstance(key, (Integral, np.integer)):
16161642
return list.__getitem__(self, key)
16171643
elif isinstance(key, slice):
16181644
return SubplotGrid(list.__getitem__(self, key))
1645+
elif isinstance(key, (list, np.ndarray)):
1646+
# NOTE: list.__getitem__ does not support numpy integers
1647+
return SubplotGrid([list.__getitem__(self, int(i)) for i in key])
16191648

16201649
# Allow 2D array-like indexing
16211650
# NOTE: We assume this is a 2D array of subplots, because this is

0 commit comments

Comments
 (0)