Skip to content

Commit ab996e1

Browse files
committed
Add ref argument to fig.legend, support 1D slicing, and intelligent placement inference
1 parent 769c1c5 commit ab996e1

File tree

5 files changed

+249
-48
lines changed

5 files changed

+249
-48
lines changed

docs/colorbars_legends.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -469,3 +469,44 @@
469469
ax = axs[1]
470470
ax.legend(hs2, loc="b", ncols=3, center=True, title="centered rows")
471471
axs.format(xlabel="xlabel", ylabel="ylabel", suptitle="Legend formatting demo")
472+
# %% [raw] raw_mimetype="text/restructuredtext"
473+
# .. _ug_guides_decouple:
474+
#
475+
# Decoupling legend content and location
476+
# --------------------------------------
477+
#
478+
# Sometimes you may want to generate a legend using handles from specific axes
479+
# but place it relative to other axes. In UltraPlot, you can achieve this by passing
480+
# both the `ax` and `ref` keywords to :func:`~ultraplot.figure.Figure.legend`
481+
# (or :func:`~ultraplot.figure.Figure.colorbar`). The `ax` keyword specifies the
482+
# axes used to generate the legend handles, while the `ref` keyword specifies the
483+
# reference axes used to determine the legend location.
484+
#
485+
# For example, to draw a legend based on the handles in the second row of subplots
486+
# but place it below the first row of subplots, you can use
487+
# ``fig.legend(ax=axs[1, :], ref=axs[0, :], loc='bottom')``. If ``ref`` is a list
488+
# of axes, UltraPlot intelligently infers the span (width or height) and anchors
489+
# the legend to the appropriate outer edge (e.g., the bottom-most axis for ``loc='bottom'``
490+
# or the right-most axis for ``loc='right'``).
491+
492+
# %%
493+
import numpy as np
494+
495+
import ultraplot as uplt
496+
497+
fig, axs = uplt.subplots(nrows=2, ncols=2, refwidth=2, share=False)
498+
axs.format(abc="A.", suptitle="Decoupled legend location demo")
499+
500+
# Plot data on all axes
501+
state = np.random.RandomState(51423)
502+
data = (state.rand(20, 4) - 0.5).cumsum(axis=0)
503+
for ax in axs:
504+
ax.plot(data, cycle="mplotcolors", labels=list("abcd"))
505+
506+
# Legend 1: Content from Row 2 (ax=axs[1, :]), Location below Row 1 (ref=axs[0, :])
507+
# This places a legend describing the bottom row data underneath the top row.
508+
fig.legend(ax=axs[1, :], ref=axs[0, :], loc="bottom", title="Data from Row 2")
509+
510+
# Legend 2: Content from Row 1 (ax=axs[0, :]), Location below Row 2 (ref=axs[1, :])
511+
# This places a legend describing the top row data underneath the bottom row.
512+
fig.legend(ax=axs[0, :], ref=axs[1, :], loc="bottom", title="Data from Row 1")

ultraplot/figure.py

Lines changed: 112 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1417,18 +1417,6 @@ def _add_axes_panel(
14171417
if span_override is not None:
14181418
kw["span_override"] = span_override
14191419

1420-
# Check for position override (row for horizontal panels, col for vertical panels)
1421-
pos_override = None
1422-
if side in ("left", "right"):
1423-
if _not_none(cols, col) is not None:
1424-
pos_override = _not_none(cols, col)
1425-
else:
1426-
if _not_none(rows, row) is not None:
1427-
pos_override = _not_none(rows, row)
1428-
1429-
if pos_override is not None:
1430-
kw["pos_override"] = pos_override
1431-
14321420
ss, share = gs._insert_panel_slot(side, ax, **kw)
14331421
# Guard: GeoAxes with non-rectilinear projections cannot share with panels
14341422
if isinstance(ax, paxes.GeoAxes) and not ax._is_rectilinear():
@@ -2712,27 +2700,125 @@ def legend(
27122700
matplotlib.axes.Axes.legend
27132701
"""
27142702
ax = kwargs.pop("ax", None)
2703+
ref = kwargs.pop("ref", None)
2704+
loc_ax = ref if ref is not None else ax
2705+
27152706
# Axes panel legend
2716-
if ax is not None:
2707+
if loc_ax is not None:
2708+
content_ax = ax if ax is not None else loc_ax
27172709
# Check if span parameters are provided
27182710
has_span = _not_none(span, row, col, rows, cols) is not None
2719-
# Extract a single axes from array if span is provided
2720-
# Otherwise, pass the array as-is for normal legend behavior
2721-
# Automatically collect handles and labels from spanned axes if not provided
2722-
if has_span and np.iterable(ax) and not isinstance(ax, (str, maxes.Axes)):
2723-
# Auto-collect handles and labels if not explicitly provided
2724-
if handles is None and labels is None:
2725-
handles, labels = [], []
2726-
for axi in ax:
2711+
2712+
# Automatically collect handles and labels from content axes if not provided
2713+
# Case 1: content_ax is a list (we must auto-collect)
2714+
# Case 2: content_ax != loc_ax (we must auto-collect because loc_ax.legend won't find content_ax handles)
2715+
must_collect = (
2716+
np.iterable(content_ax)
2717+
and not isinstance(content_ax, (str, maxes.Axes))
2718+
) or (content_ax is not loc_ax)
2719+
2720+
if must_collect and handles is None and labels is None:
2721+
handles, labels = [], []
2722+
# Handle list of axes
2723+
if np.iterable(content_ax) and not isinstance(
2724+
content_ax, (str, maxes.Axes)
2725+
):
2726+
for axi in content_ax:
27272727
h, l = axi.get_legend_handles_labels()
27282728
handles.extend(h)
27292729
labels.extend(l)
2730-
try:
2731-
ax_single = next(iter(ax))
2732-
except (TypeError, StopIteration):
2733-
ax_single = ax
2730+
# Handle single axis
2731+
else:
2732+
handles, labels = content_ax.get_legend_handles_labels()
2733+
2734+
# Infer span from loc_ax if it is a list and no span provided
2735+
if (
2736+
not has_span
2737+
and np.iterable(loc_ax)
2738+
and not isinstance(loc_ax, (str, maxes.Axes))
2739+
):
2740+
loc_trans = _translate_loc(loc, "legend", default=rc["legend.loc"])
2741+
side = (
2742+
loc_trans
2743+
if loc_trans in ("left", "right", "top", "bottom")
2744+
else None
2745+
)
2746+
2747+
if side:
2748+
r_min, r_max = float("inf"), float("-inf")
2749+
c_min, c_max = float("inf"), float("-inf")
2750+
valid_ax = False
2751+
for axi in loc_ax:
2752+
if not hasattr(axi, "get_subplotspec"):
2753+
continue
2754+
ss = axi.get_subplotspec().get_topmost_subplotspec()
2755+
r1, r2, c1, c2 = ss._get_rows_columns()
2756+
r_min = min(r_min, r1)
2757+
r_max = max(r_max, r2)
2758+
c_min = min(c_min, c1)
2759+
c_max = max(c_max, c2)
2760+
valid_ax = True
2761+
2762+
if valid_ax:
2763+
if side in ("left", "right"):
2764+
rows = (r_min + 1, r_max + 1)
2765+
else:
2766+
cols = (c_min + 1, c_max + 1)
2767+
has_span = True
2768+
2769+
# Extract a single axes from array if span is provided (or if ref is a list)
2770+
# Otherwise, pass the array as-is for normal legend behavior (only if loc_ax is list)
2771+
if (
2772+
has_span
2773+
and np.iterable(loc_ax)
2774+
and not isinstance(loc_ax, (str, maxes.Axes))
2775+
):
2776+
# Pick the best axis to anchor to based on the legend side
2777+
loc_trans = _translate_loc(loc, "legend", default=rc["legend.loc"])
2778+
side = (
2779+
loc_trans
2780+
if loc_trans in ("left", "right", "top", "bottom")
2781+
else None
2782+
)
2783+
2784+
best_ax = None
2785+
best_coord = float("-inf")
2786+
2787+
# If side is determined, search for the edge axis
2788+
if side:
2789+
for axi in loc_ax:
2790+
if not hasattr(axi, "get_subplotspec"):
2791+
continue
2792+
ss = axi.get_subplotspec().get_topmost_subplotspec()
2793+
r1, r2, c1, c2 = ss._get_rows_columns()
2794+
2795+
if side == "right":
2796+
val = c2 # Maximize column index
2797+
elif side == "left":
2798+
val = -c1 # Minimize column index
2799+
elif side == "bottom":
2800+
val = r2 # Maximize row index
2801+
elif side == "top":
2802+
val = -r1 # Minimize row index
2803+
else:
2804+
val = 0
2805+
2806+
if val > best_coord:
2807+
best_coord = val
2808+
best_ax = axi
2809+
2810+
# Fallback to first axis if no best axis found (or side is None)
2811+
if best_ax is None:
2812+
try:
2813+
ax_single = next(iter(loc_ax))
2814+
except (TypeError, StopIteration):
2815+
ax_single = loc_ax
2816+
else:
2817+
ax_single = best_ax
2818+
27342819
else:
2735-
ax_single = ax
2820+
ax_single = loc_ax
2821+
27362822
leg = ax_single.legend(
27372823
handles,
27382824
labels,

ultraplot/gridspec.py

Lines changed: 1 addition & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -601,7 +601,6 @@ def _parse_panel_arg_with_span(
601601
side: str,
602602
ax: "paxes.Axes",
603603
span_override: Optional[Union[int, Tuple[int, int]]],
604-
pos_override: Optional[Union[int, Tuple[int, int]]] = None,
605604
) -> Tuple[str, int, slice]:
606605
"""
607606
Parse panel arg with span override. Uses ax for position, span for extent.
@@ -614,8 +613,6 @@ def _parse_panel_arg_with_span(
614613
The axes to position the panel relative to
615614
span_override : int or tuple
616615
The span extent (1-indexed like subplot numbers)
617-
pos_override : int or tuple, optional
618-
The row or column index (1-indexed like subplot numbers)
619616
620617
Returns
621618
-------
@@ -630,20 +627,6 @@ def _parse_panel_arg_with_span(
630627
ss = ax.get_subplotspec().get_topmost_subplotspec()
631628
row1, row2, col1, col2 = ss._get_rows_columns()
632629

633-
# Override axes position if requested
634-
if pos_override is not None:
635-
if isinstance(pos_override, Integral):
636-
pos1, pos2 = pos_override - 1, pos_override - 1
637-
else:
638-
pos_override = np.atleast_1d(pos_override)
639-
pos1, pos2 = pos_override[0] - 1, pos_override[-1] - 1
640-
641-
# NOTE: We only need the relevant coordinate (row or col)
642-
if side in ("left", "right"):
643-
col1, col2 = pos1, pos2
644-
else:
645-
row1, row2 = pos1, pos2
646-
647630
# Determine slot and index based on side
648631
slot = side[0]
649632
offset = len(ax._panel_dict[side]) + 1
@@ -686,7 +669,6 @@ def _insert_panel_slot(
686669
pad: Optional[Union[float, str]] = None,
687670
filled: bool = False,
688671
span_override: Optional[Union[int, Tuple[int, int]]] = None,
689-
pos_override: Optional[Union[int, Tuple[int, int]]] = None,
690672
):
691673
"""
692674
Insert a panel slot into the existing gridspec. The `side` is the panel side
@@ -700,9 +682,7 @@ def _insert_panel_slot(
700682
raise ValueError(f"Invalid side {side}.")
701683
# Use span override if provided
702684
if span_override is not None:
703-
slot, idx, span = self._parse_panel_arg_with_span(
704-
side, arg, span_override, pos_override=pos_override
705-
)
685+
slot, idx, span = self._parse_panel_arg_with_span(side, arg, span_override)
706686
else:
707687
slot, idx, span = self._parse_panel_arg(side, arg)
708688
pad = units(pad, "em", "in")

ultraplot/tests/test_gridspec.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
import ultraplot as uplt
21
import pytest
2+
3+
import ultraplot as uplt
34
from ultraplot.gridspec import SubplotGrid
45

56

@@ -72,3 +73,56 @@ def test_tight_layout_disabled():
7273
gs = ax.get_subplotspec().get_gridspec()
7374
with pytest.raises(RuntimeError):
7475
gs.tight_layout(fig)
76+
77+
78+
def test_gridspec_slicing():
79+
"""
80+
Test various slicing methods on SubplotGrid, including 1D list/array indexing.
81+
"""
82+
import numpy as np
83+
84+
fig, axs = uplt.subplots(nrows=4, ncols=4)
85+
86+
# Test 1D integer indexing
87+
assert axs[0].number == 1
88+
assert axs[15].number == 16
89+
90+
# Test 1D slice indexing
91+
subset = axs[0:2]
92+
assert isinstance(subset, SubplotGrid)
93+
assert len(subset) == 2
94+
assert subset[0].number == 1
95+
assert subset[1].number == 2
96+
97+
# Test 1D list indexing (Fix #1)
98+
subset_list = axs[[0, 5]]
99+
assert isinstance(subset_list, SubplotGrid)
100+
assert len(subset_list) == 2
101+
assert subset_list[0].number == 1
102+
assert subset_list[1].number == 6
103+
104+
# Test 1D array indexing
105+
subset_array = axs[np.array([0, 5])]
106+
assert isinstance(subset_array, SubplotGrid)
107+
assert len(subset_array) == 2
108+
assert subset_array[0].number == 1
109+
assert subset_array[1].number == 6
110+
111+
# Test 2D slicing (tuple of slices)
112+
# axs[0:2, :] -> Rows 0 and 1, all cols
113+
subset_2d = axs[0:2, :]
114+
assert isinstance(subset_2d, SubplotGrid)
115+
# 2 rows * 4 cols = 8 axes
116+
assert len(subset_2d) == 8
117+
118+
# Test 2D mixed slicing (list in one dim) (Fix #2 related to _encode_indices)
119+
# axs[[0, 1], :] -> Row indices 0 and 1, all cols
120+
subset_mixed = axs[[0, 1], :]
121+
assert isinstance(subset_mixed, SubplotGrid)
122+
assert len(subset_mixed) == 8
123+
124+
# Verify content
125+
# subset_mixed[0] -> Row 0, Col 0 -> Number 1
126+
# subset_mixed[4] -> Row 1, Col 0 -> Number 5 (since 4 cols per row)
127+
assert subset_mixed[0].number == 1
128+
assert subset_mixed[4].number == 5

ultraplot/tests/test_legend.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -529,3 +529,43 @@ def test_legend_explicit_handles_labels_override_auto_collection():
529529
assert leg is not None
530530
assert len(leg.get_texts()) == 1
531531
assert leg.get_texts()[0].get_text() == "custom_label"
532+
533+
534+
import numpy as np
535+
536+
import ultraplot as uplt
537+
538+
539+
def test_legend_ref_argument():
540+
"""Test using 'ref' to decouple legend location from content axes."""
541+
fig, axs = uplt.subplots(nrows=2, ncols=2)
542+
axs[0, 0].plot([], [], label="line1") # Row 0
543+
axs[1, 0].plot([], [], label="line2") # Row 1
544+
545+
# Place legend below Row 0 (axs[0, :]) using content from Row 1 (axs[1, :])
546+
leg = fig.legend(ax=axs[1, :], ref=axs[0, :], loc="bottom")
547+
548+
assert leg is not None
549+
550+
# Should be a single legend because span is inferred from ref
551+
assert not isinstance(leg, tuple)
552+
553+
texts = [t.get_text() for t in leg.get_texts()]
554+
assert "line2" in texts
555+
assert "line1" not in texts
556+
557+
558+
def test_legend_ref_argument_no_ax():
559+
"""Test using 'ref' where 'ax' is implied to be 'ref'."""
560+
fig, axs = uplt.subplots(nrows=1, ncols=1)
561+
axs[0].plot([], [], label="line1")
562+
563+
# ref provided, ax=None. Should behave like ax=ref.
564+
leg = fig.legend(ref=axs[0], loc="bottom")
565+
assert leg is not None
566+
567+
# Should be a single legend
568+
assert not isinstance(leg, tuple)
569+
570+
texts = [t.get_text() for t in leg.get_texts()]
571+
assert "line1" in texts

0 commit comments

Comments
 (0)