Skip to content

Commit 438a952

Browse files
committed
Add ref argument to fig.legend for decoupled placement and support 1D slicing in SubplotGrid
1 parent 769c1c5 commit 438a952

File tree

4 files changed

+130
-46
lines changed

4 files changed

+130
-46
lines changed

ultraplot/figure.py

Lines changed: 36 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1417,18 +1417,6 @@ def _add_axes_panel(
14171417
if span_override is not None:
14181418
kw["span_override"] = span_override
14191419

1420-
# Check for position override (row for horizontal panels, col for vertical panels)
1421-
pos_override = None
1422-
if side in ("left", "right"):
1423-
if _not_none(cols, col) is not None:
1424-
pos_override = _not_none(cols, col)
1425-
else:
1426-
if _not_none(rows, row) is not None:
1427-
pos_override = _not_none(rows, row)
1428-
1429-
if pos_override is not None:
1430-
kw["pos_override"] = pos_override
1431-
14321420
ss, share = gs._insert_panel_slot(side, ax, **kw)
14331421
# Guard: GeoAxes with non-rectilinear projections cannot share with panels
14341422
if isinstance(ax, paxes.GeoAxes) and not ax._is_rectilinear():
@@ -2712,27 +2700,51 @@ def legend(
27122700
matplotlib.axes.Axes.legend
27132701
"""
27142702
ax = kwargs.pop("ax", None)
2703+
ref = kwargs.pop("ref", None)
2704+
loc_ax = ref if ref is not None else ax
2705+
27152706
# Axes panel legend
2716-
if ax is not None:
2707+
if loc_ax is not None:
2708+
content_ax = ax if ax is not None else loc_ax
27172709
# Check if span parameters are provided
27182710
has_span = _not_none(span, row, col, rows, cols) is not None
2719-
# Extract a single axes from array if span is provided
2720-
# Otherwise, pass the array as-is for normal legend behavior
2721-
# Automatically collect handles and labels from spanned axes if not provided
2722-
if has_span and np.iterable(ax) and not isinstance(ax, (str, maxes.Axes)):
2723-
# Auto-collect handles and labels if not explicitly provided
2724-
if handles is None and labels is None:
2725-
handles, labels = [], []
2726-
for axi in ax:
2711+
2712+
# Automatically collect handles and labels from content axes if not provided
2713+
# Case 1: content_ax is a list (we must auto-collect)
2714+
# Case 2: content_ax != loc_ax (we must auto-collect because loc_ax.legend won't find content_ax handles)
2715+
must_collect = (
2716+
np.iterable(content_ax)
2717+
and not isinstance(content_ax, (str, maxes.Axes))
2718+
) or (content_ax is not loc_ax)
2719+
2720+
if must_collect and handles is None and labels is None:
2721+
handles, labels = [], []
2722+
# Handle list of axes
2723+
if np.iterable(content_ax) and not isinstance(
2724+
content_ax, (str, maxes.Axes)
2725+
):
2726+
for axi in content_ax:
27272727
h, l = axi.get_legend_handles_labels()
27282728
handles.extend(h)
27292729
labels.extend(l)
2730+
# Handle single axis
2731+
else:
2732+
handles, labels = content_ax.get_legend_handles_labels()
2733+
2734+
# Extract a single axes from array if span is provided (or if ref is a list)
2735+
# Otherwise, pass the array as-is for normal legend behavior (only if loc_ax is list)
2736+
if (
2737+
has_span
2738+
and np.iterable(loc_ax)
2739+
and not isinstance(loc_ax, (str, maxes.Axes))
2740+
):
27302741
try:
2731-
ax_single = next(iter(ax))
2742+
ax_single = next(iter(loc_ax))
27322743
except (TypeError, StopIteration):
2733-
ax_single = ax
2744+
ax_single = loc_ax
27342745
else:
2735-
ax_single = ax
2746+
ax_single = loc_ax
2747+
27362748
leg = ax_single.legend(
27372749
handles,
27382750
labels,

ultraplot/gridspec.py

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -601,7 +601,6 @@ def _parse_panel_arg_with_span(
601601
side: str,
602602
ax: "paxes.Axes",
603603
span_override: Optional[Union[int, Tuple[int, int]]],
604-
pos_override: Optional[Union[int, Tuple[int, int]]] = None,
605604
) -> Tuple[str, int, slice]:
606605
"""
607606
Parse panel arg with span override. Uses ax for position, span for extent.
@@ -614,8 +613,6 @@ def _parse_panel_arg_with_span(
614613
The axes to position the panel relative to
615614
span_override : int or tuple
616615
The span extent (1-indexed like subplot numbers)
617-
pos_override : int or tuple, optional
618-
The row or column index (1-indexed like subplot numbers)
619616
620617
Returns
621618
-------
@@ -630,20 +627,6 @@ def _parse_panel_arg_with_span(
630627
ss = ax.get_subplotspec().get_topmost_subplotspec()
631628
row1, row2, col1, col2 = ss._get_rows_columns()
632629

633-
# Override axes position if requested
634-
if pos_override is not None:
635-
if isinstance(pos_override, Integral):
636-
pos1, pos2 = pos_override - 1, pos_override - 1
637-
else:
638-
pos_override = np.atleast_1d(pos_override)
639-
pos1, pos2 = pos_override[0] - 1, pos_override[-1] - 1
640-
641-
# NOTE: We only need the relevant coordinate (row or col)
642-
if side in ("left", "right"):
643-
col1, col2 = pos1, pos2
644-
else:
645-
row1, row2 = pos1, pos2
646-
647630
# Determine slot and index based on side
648631
slot = side[0]
649632
offset = len(ax._panel_dict[side]) + 1
@@ -686,7 +669,6 @@ def _insert_panel_slot(
686669
pad: Optional[Union[float, str]] = None,
687670
filled: bool = False,
688671
span_override: Optional[Union[int, Tuple[int, int]]] = None,
689-
pos_override: Optional[Union[int, Tuple[int, int]]] = None,
690672
):
691673
"""
692674
Insert a panel slot into the existing gridspec. The `side` is the panel side
@@ -700,9 +682,7 @@ def _insert_panel_slot(
700682
raise ValueError(f"Invalid side {side}.")
701683
# Use span override if provided
702684
if span_override is not None:
703-
slot, idx, span = self._parse_panel_arg_with_span(
704-
side, arg, span_override, pos_override=pos_override
705-
)
685+
slot, idx, span = self._parse_panel_arg_with_span(side, arg, span_override)
706686
else:
707687
slot, idx, span = self._parse_panel_arg(side, arg)
708688
pad = units(pad, "em", "in")

ultraplot/tests/test_gridspec.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
import ultraplot as uplt
21
import pytest
2+
3+
import ultraplot as uplt
34
from ultraplot.gridspec import SubplotGrid
45

56

@@ -72,3 +73,56 @@ def test_tight_layout_disabled():
7273
gs = ax.get_subplotspec().get_gridspec()
7374
with pytest.raises(RuntimeError):
7475
gs.tight_layout(fig)
76+
77+
78+
def test_gridspec_slicing():
79+
"""
80+
Test various slicing methods on SubplotGrid, including 1D list/array indexing.
81+
"""
82+
import numpy as np
83+
84+
fig, axs = uplt.subplots(nrows=4, ncols=4)
85+
86+
# Test 1D integer indexing
87+
assert axs[0].number == 1
88+
assert axs[15].number == 16
89+
90+
# Test 1D slice indexing
91+
subset = axs[0:2]
92+
assert isinstance(subset, SubplotGrid)
93+
assert len(subset) == 2
94+
assert subset[0].number == 1
95+
assert subset[1].number == 2
96+
97+
# Test 1D list indexing (Fix #1)
98+
subset_list = axs[[0, 5]]
99+
assert isinstance(subset_list, SubplotGrid)
100+
assert len(subset_list) == 2
101+
assert subset_list[0].number == 1
102+
assert subset_list[1].number == 6
103+
104+
# Test 1D array indexing
105+
subset_array = axs[np.array([0, 5])]
106+
assert isinstance(subset_array, SubplotGrid)
107+
assert len(subset_array) == 2
108+
assert subset_array[0].number == 1
109+
assert subset_array[1].number == 6
110+
111+
# Test 2D slicing (tuple of slices)
112+
# axs[0:2, :] -> Rows 0 and 1, all cols
113+
subset_2d = axs[0:2, :]
114+
assert isinstance(subset_2d, SubplotGrid)
115+
# 2 rows * 4 cols = 8 axes
116+
assert len(subset_2d) == 8
117+
118+
# Test 2D mixed slicing (list in one dim) (Fix #2 related to _encode_indices)
119+
# axs[[0, 1], :] -> Row indices 0 and 1, all cols
120+
subset_mixed = axs[[0, 1], :]
121+
assert isinstance(subset_mixed, SubplotGrid)
122+
assert len(subset_mixed) == 8
123+
124+
# Verify content
125+
# subset_mixed[0] -> Row 0, Col 0 -> Number 1
126+
# subset_mixed[4] -> Row 1, Col 0 -> Number 5 (since 4 cols per row)
127+
assert subset_mixed[0].number == 1
128+
assert subset_mixed[4].number == 5

ultraplot/tests/test_legend.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -529,3 +529,41 @@ def test_legend_explicit_handles_labels_override_auto_collection():
529529
assert leg is not None
530530
assert len(leg.get_texts()) == 1
531531
assert leg.get_texts()[0].get_text() == "custom_label"
532+
import numpy as np
533+
534+
import ultraplot as uplt
535+
536+
537+
def test_legend_ref_argument():
538+
"""Test using 'ref' to decouple legend location from content axes."""
539+
fig, axs = uplt.subplots(nrows=2, ncols=2)
540+
axs[0, 0].plot([], [], label="line1") # Row 0
541+
axs[1, 0].plot([], [], label="line2") # Row 1
542+
543+
# Place legend below Row 0 (axs[0, :]) using content from Row 1 (axs[1, :])
544+
leg = fig.legend(ax=axs[1, :], ref=axs[0, :], loc="bottom")
545+
546+
assert leg is not None
547+
548+
legs = leg if isinstance(leg, tuple) else (leg,)
549+
550+
for l in legs:
551+
texts = [t.get_text() for t in l.get_texts()]
552+
assert "line2" in texts
553+
assert "line1" not in texts
554+
555+
556+
def test_legend_ref_argument_no_ax():
557+
"""Test using 'ref' where 'ax' is implied to be 'ref'."""
558+
fig, axs = uplt.subplots(nrows=1, ncols=1)
559+
axs[0].plot([], [], label="line1")
560+
561+
# ref provided, ax=None. Should behave like ax=ref.
562+
leg = fig.legend(ref=axs[0], loc="bottom")
563+
assert leg is not None
564+
565+
legs = leg if isinstance(leg, tuple) else (leg,)
566+
567+
for l in legs:
568+
texts = [t.get_text() for t in l.get_texts()]
569+
assert "line1" in texts

0 commit comments

Comments
 (0)