Skip to content

Commit d7ec9ea

Browse files
authored
Fix SubplotGrid indexing and enhance legend placement with 'ref' argument (#461)
* Fix SubplotGrid indexing and allow legend placement decoupling * Add ref argument to fig.legend, support 1D slicing, and intelligent placement inference * Add ref argument to fig.legend and fig.colorbar, support 1D slicing, intelligent placement, and robust checks * Remove xdist from image compare
1 parent d31f165 commit d7ec9ea

File tree

5 files changed

+408
-23
lines changed

5 files changed

+408
-23
lines changed

docs/colorbars_legends.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -469,3 +469,44 @@
469469
ax = axs[1]
470470
ax.legend(hs2, loc="b", ncols=3, center=True, title="centered rows")
471471
axs.format(xlabel="xlabel", ylabel="ylabel", suptitle="Legend formatting demo")
472+
# %% [raw] raw_mimetype="text/restructuredtext"
473+
# .. _ug_guides_decouple:
474+
#
475+
# Decoupling legend content and location
476+
# --------------------------------------
477+
#
478+
# Sometimes you may want to generate a legend using handles from specific axes
479+
# but place it relative to other axes. In UltraPlot, you can achieve this by passing
480+
# both the `ax` and `ref` keywords to :func:`~ultraplot.figure.Figure.legend`
481+
# (or :func:`~ultraplot.figure.Figure.colorbar`). The `ax` keyword specifies the
482+
# axes used to generate the legend handles, while the `ref` keyword specifies the
483+
# reference axes used to determine the legend location.
484+
#
485+
# For example, to draw a legend based on the handles in the second row of subplots
486+
# but place it below the first row of subplots, you can use
487+
# ``fig.legend(ax=axs[1, :], ref=axs[0, :], loc='bottom')``. If ``ref`` is a list
488+
# of axes, UltraPlot intelligently infers the span (width or height) and anchors
489+
# the legend to the appropriate outer edge (e.g., the bottom-most axis for ``loc='bottom'``
490+
# or the right-most axis for ``loc='right'``).
491+
492+
# %%
493+
import numpy as np
494+
495+
import ultraplot as uplt
496+
497+
fig, axs = uplt.subplots(nrows=2, ncols=2, refwidth=2, share=False)
498+
axs.format(abc="A.", suptitle="Decoupled legend location demo")
499+
500+
# Plot data on all axes
501+
state = np.random.RandomState(51423)
502+
data = (state.rand(20, 4) - 0.5).cumsum(axis=0)
503+
for ax in axs:
504+
ax.plot(data, cycle="mplotcolors", labels=list("abcd"))
505+
506+
# Legend 1: Content from Row 2 (ax=axs[1, :]), Location below Row 1 (ref=axs[0, :])
507+
# This places a legend describing the bottom row data underneath the top row.
508+
fig.legend(ax=axs[1, :], ref=axs[0, :], loc="bottom", title="Data from Row 2")
509+
510+
# Legend 2: Content from Row 1 (ax=axs[0, :]), Location below Row 2 (ref=axs[1, :])
511+
# This places a legend describing the top row data underneath the bottom row.
512+
fig.legend(ax=axs[0, :], ref=axs[1, :], loc="bottom", title="Data from Row 1")

ultraplot/figure.py

Lines changed: 214 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2594,6 +2594,8 @@ def colorbar(
25942594
"""
25952595
# Backwards compatibility
25962596
ax = kwargs.pop("ax", None)
2597+
ref = kwargs.pop("ref", None)
2598+
loc_ax = ref if ref is not None else ax
25972599
cax = kwargs.pop("cax", None)
25982600
if isinstance(values, maxes.Axes):
25992601
cax = _not_none(cax_positional=values, cax=cax)
@@ -2613,20 +2615,102 @@ def colorbar(
26132615
with context._state_context(cax, _internal_call=True): # do not wrap pcolor
26142616
cb = super().colorbar(mappable, cax=cax, **kwargs)
26152617
# Axes panel colorbar
2616-
elif ax is not None:
2618+
elif loc_ax is not None:
26172619
# Check if span parameters are provided
26182620
has_span = _not_none(span, row, col, rows, cols) is not None
26192621

2622+
# Infer span from loc_ax if it is a list and no span provided
2623+
if (
2624+
not has_span
2625+
and np.iterable(loc_ax)
2626+
and not isinstance(loc_ax, (str, maxes.Axes))
2627+
):
2628+
loc_trans = _translate_loc(loc, "colorbar", default=rc["colorbar.loc"])
2629+
side = (
2630+
loc_trans
2631+
if loc_trans in ("left", "right", "top", "bottom")
2632+
else None
2633+
)
2634+
2635+
if side:
2636+
r_min, r_max = float("inf"), float("-inf")
2637+
c_min, c_max = float("inf"), float("-inf")
2638+
valid_ax = False
2639+
for axi in loc_ax:
2640+
if not hasattr(axi, "get_subplotspec"):
2641+
continue
2642+
ss = axi.get_subplotspec()
2643+
if ss is None:
2644+
continue
2645+
ss = ss.get_topmost_subplotspec()
2646+
r1, r2, c1, c2 = ss._get_rows_columns()
2647+
r_min = min(r_min, r1)
2648+
r_max = max(r_max, r2)
2649+
c_min = min(c_min, c1)
2650+
c_max = max(c_max, c2)
2651+
valid_ax = True
2652+
2653+
if valid_ax:
2654+
if side in ("left", "right"):
2655+
rows = (r_min + 1, r_max + 1)
2656+
else:
2657+
cols = (c_min + 1, c_max + 1)
2658+
has_span = True
2659+
26202660
# Extract a single axes from array if span is provided
26212661
# Otherwise, pass the array as-is for normal colorbar behavior
2622-
if has_span and np.iterable(ax) and not isinstance(ax, (str, maxes.Axes)):
2623-
try:
2624-
ax_single = next(iter(ax))
2662+
if (
2663+
has_span
2664+
and np.iterable(loc_ax)
2665+
and not isinstance(loc_ax, (str, maxes.Axes))
2666+
):
2667+
# Pick the best axis to anchor to based on the colorbar side
2668+
loc_trans = _translate_loc(loc, "colorbar", default=rc["colorbar.loc"])
2669+
side = (
2670+
loc_trans
2671+
if loc_trans in ("left", "right", "top", "bottom")
2672+
else None
2673+
)
26252674

2626-
except (TypeError, StopIteration):
2627-
ax_single = ax
2675+
best_ax = None
2676+
best_coord = float("-inf")
2677+
2678+
# If side is determined, search for the edge axis
2679+
if side:
2680+
for axi in loc_ax:
2681+
if not hasattr(axi, "get_subplotspec"):
2682+
continue
2683+
ss = axi.get_subplotspec()
2684+
if ss is None:
2685+
continue
2686+
ss = ss.get_topmost_subplotspec()
2687+
r1, r2, c1, c2 = ss._get_rows_columns()
2688+
2689+
if side == "right":
2690+
val = c2 # Maximize column index
2691+
elif side == "left":
2692+
val = -c1 # Minimize column index
2693+
elif side == "bottom":
2694+
val = r2 # Maximize row index
2695+
elif side == "top":
2696+
val = -r1 # Minimize row index
2697+
else:
2698+
val = 0
2699+
2700+
if val > best_coord:
2701+
best_coord = val
2702+
best_ax = axi
2703+
2704+
# Fallback to first axis
2705+
if best_ax is None:
2706+
try:
2707+
ax_single = next(iter(loc_ax))
2708+
except (TypeError, StopIteration):
2709+
ax_single = loc_ax
2710+
else:
2711+
ax_single = best_ax
26282712
else:
2629-
ax_single = ax
2713+
ax_single = loc_ax
26302714

26312715
# Pass span parameters through to axes colorbar
26322716
cb = ax_single.colorbar(
@@ -2700,27 +2784,136 @@ def legend(
27002784
matplotlib.axes.Axes.legend
27012785
"""
27022786
ax = kwargs.pop("ax", None)
2787+
ref = kwargs.pop("ref", None)
2788+
loc_ax = ref if ref is not None else ax
2789+
27032790
# Axes panel legend
2704-
if ax is not None:
2791+
if loc_ax is not None:
2792+
content_ax = ax if ax is not None else loc_ax
27052793
# Check if span parameters are provided
27062794
has_span = _not_none(span, row, col, rows, cols) is not None
2707-
# Extract a single axes from array if span is provided
2708-
# Otherwise, pass the array as-is for normal legend behavior
2709-
# Automatically collect handles and labels from spanned axes if not provided
2710-
if has_span and np.iterable(ax) and not isinstance(ax, (str, maxes.Axes)):
2711-
# Auto-collect handles and labels if not explicitly provided
2712-
if handles is None and labels is None:
2713-
handles, labels = [], []
2714-
for axi in ax:
2795+
2796+
# Automatically collect handles and labels from content axes if not provided
2797+
# Case 1: content_ax is a list (we must auto-collect)
2798+
# Case 2: content_ax != loc_ax (we must auto-collect because loc_ax.legend won't find content_ax handles)
2799+
must_collect = (
2800+
np.iterable(content_ax)
2801+
and not isinstance(content_ax, (str, maxes.Axes))
2802+
) or (content_ax is not loc_ax)
2803+
2804+
if must_collect and handles is None and labels is None:
2805+
handles, labels = [], []
2806+
# Handle list of axes
2807+
if np.iterable(content_ax) and not isinstance(
2808+
content_ax, (str, maxes.Axes)
2809+
):
2810+
for axi in content_ax:
27152811
h, l = axi.get_legend_handles_labels()
27162812
handles.extend(h)
27172813
labels.extend(l)
2718-
try:
2719-
ax_single = next(iter(ax))
2720-
except (TypeError, StopIteration):
2721-
ax_single = ax
2814+
# Handle single axis
2815+
else:
2816+
handles, labels = content_ax.get_legend_handles_labels()
2817+
2818+
# Infer span from loc_ax if it is a list and no span provided
2819+
if (
2820+
not has_span
2821+
and np.iterable(loc_ax)
2822+
and not isinstance(loc_ax, (str, maxes.Axes))
2823+
):
2824+
loc_trans = _translate_loc(loc, "legend", default=rc["legend.loc"])
2825+
side = (
2826+
loc_trans
2827+
if loc_trans in ("left", "right", "top", "bottom")
2828+
else None
2829+
)
2830+
2831+
if side:
2832+
r_min, r_max = float("inf"), float("-inf")
2833+
c_min, c_max = float("inf"), float("-inf")
2834+
valid_ax = False
2835+
for axi in loc_ax:
2836+
if not hasattr(axi, "get_subplotspec"):
2837+
continue
2838+
ss = axi.get_subplotspec()
2839+
if ss is None:
2840+
continue
2841+
ss = ss.get_topmost_subplotspec()
2842+
r1, r2, c1, c2 = ss._get_rows_columns()
2843+
r_min = min(r_min, r1)
2844+
r_max = max(r_max, r2)
2845+
c_min = min(c_min, c1)
2846+
c_max = max(c_max, c2)
2847+
valid_ax = True
2848+
2849+
if valid_ax:
2850+
if side in ("left", "right"):
2851+
rows = (r_min + 1, r_max + 1)
2852+
else:
2853+
cols = (c_min + 1, c_max + 1)
2854+
has_span = True
2855+
2856+
# Extract a single axes from array if span is provided (or if ref is a list)
2857+
# Otherwise, pass the array as-is for normal legend behavior (only if loc_ax is list)
2858+
if (
2859+
has_span
2860+
and np.iterable(loc_ax)
2861+
and not isinstance(loc_ax, (str, maxes.Axes))
2862+
):
2863+
# Pick the best axis to anchor to based on the legend side
2864+
loc_trans = _translate_loc(loc, "legend", default=rc["legend.loc"])
2865+
side = (
2866+
loc_trans
2867+
if loc_trans in ("left", "right", "top", "bottom")
2868+
else None
2869+
)
2870+
2871+
best_ax = None
2872+
best_coord = float("-inf")
2873+
2874+
# If side is determined, search for the edge axis
2875+
if side:
2876+
for axi in loc_ax:
2877+
if not hasattr(axi, "get_subplotspec"):
2878+
continue
2879+
ss = axi.get_subplotspec()
2880+
if ss is None:
2881+
continue
2882+
ss = ss.get_topmost_subplotspec()
2883+
r1, r2, c1, c2 = ss._get_rows_columns()
2884+
2885+
if side == "right":
2886+
val = c2 # Maximize column index
2887+
elif side == "left":
2888+
val = -c1 # Minimize column index
2889+
elif side == "bottom":
2890+
val = r2 # Maximize row index
2891+
elif side == "top":
2892+
val = -r1 # Minimize row index
2893+
else:
2894+
val = 0
2895+
2896+
if val > best_coord:
2897+
best_coord = val
2898+
best_ax = axi
2899+
2900+
# Fallback to first axis if no best axis found (or side is None)
2901+
if best_ax is None:
2902+
try:
2903+
ax_single = next(iter(loc_ax))
2904+
except (TypeError, StopIteration):
2905+
ax_single = loc_ax
2906+
else:
2907+
ax_single = best_ax
2908+
27222909
else:
2723-
ax_single = ax
2910+
ax_single = loc_ax
2911+
if isinstance(ax_single, list):
2912+
try:
2913+
ax_single = pgridspec.SubplotGrid(ax_single)
2914+
except ValueError:
2915+
ax_single = ax_single[0]
2916+
27242917
leg = ax_single.legend(
27252918
handles,
27262919
labels,

ultraplot/gridspec.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,12 @@ def _encode_indices(self, *args, which=None, panel=False):
425425
nums = []
426426
idxs = self._get_indices(which=which, panel=panel)
427427
for arg in args:
428+
if isinstance(arg, (list, np.ndarray)):
429+
try:
430+
nums.append([idxs[int(i)] for i in arg])
431+
except (IndexError, TypeError):
432+
raise ValueError(f"Invalid gridspec index {arg}.")
433+
continue
428434
try:
429435
nums.append(idxs[arg])
430436
except (IndexError, TypeError):
@@ -1612,10 +1618,13 @@ def __getitem__(self, key):
16121618
>>> axs[:, 0] # a SubplotGrid containing the subplots in the first column
16131619
"""
16141620
# Allow 1D list-like indexing
1615-
if isinstance(key, int):
1621+
if isinstance(key, (Integral, np.integer)):
16161622
return list.__getitem__(self, key)
16171623
elif isinstance(key, slice):
16181624
return SubplotGrid(list.__getitem__(self, key))
1625+
elif isinstance(key, (list, np.ndarray)):
1626+
# NOTE: list.__getitem__ does not support numpy integers
1627+
return SubplotGrid([list.__getitem__(self, int(i)) for i in key])
16191628

16201629
# Allow 2D array-like indexing
16211630
# NOTE: We assume this is a 2D array of subplots, because this is

0 commit comments

Comments
 (0)