Skip to content

Commit a1f0e66

Browse files
committed
Add ref argument to fig.legend and fig.colorbar, support 1D slicing, and intelligent placement
1 parent ab996e1 commit a1f0e66

File tree

2 files changed

+85
-12
lines changed

2 files changed

+85
-12
lines changed

ultraplot/figure.py

Lines changed: 85 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2594,6 +2594,8 @@ def colorbar(
25942594
"""
25952595
# Backwards compatibility
25962596
ax = kwargs.pop("ax", None)
2597+
ref = kwargs.pop("ref", None)
2598+
loc_ax = ref if ref is not None else ax
25972599
cax = kwargs.pop("cax", None)
25982600
if isinstance(values, maxes.Axes):
25992601
cax = _not_none(cax_positional=values, cax=cax)
@@ -2613,20 +2615,96 @@ def colorbar(
26132615
with context._state_context(cax, _internal_call=True): # do not wrap pcolor
26142616
cb = super().colorbar(mappable, cax=cax, **kwargs)
26152617
# Axes panel colorbar
2616-
elif ax is not None:
2618+
elif loc_ax is not None:
26172619
# Check if span parameters are provided
26182620
has_span = _not_none(span, row, col, rows, cols) is not None
26192621

2622+
# Infer span from loc_ax if it is a list and no span provided
2623+
if (
2624+
not has_span
2625+
and np.iterable(loc_ax)
2626+
and not isinstance(loc_ax, (str, maxes.Axes))
2627+
):
2628+
loc_trans = _translate_loc(loc, "colorbar", default=rc["colorbar.loc"])
2629+
side = (
2630+
loc_trans
2631+
if loc_trans in ("left", "right", "top", "bottom")
2632+
else None
2633+
)
2634+
2635+
if side:
2636+
r_min, r_max = float("inf"), float("-inf")
2637+
c_min, c_max = float("inf"), float("-inf")
2638+
valid_ax = False
2639+
for axi in loc_ax:
2640+
if not hasattr(axi, "get_subplotspec"):
2641+
continue
2642+
ss = axi.get_subplotspec().get_topmost_subplotspec()
2643+
r1, r2, c1, c2 = ss._get_rows_columns()
2644+
r_min = min(r_min, r1)
2645+
r_max = max(r_max, r2)
2646+
c_min = min(c_min, c1)
2647+
c_max = max(c_max, c2)
2648+
valid_ax = True
2649+
2650+
if valid_ax:
2651+
if side in ("left", "right"):
2652+
rows = (r_min + 1, r_max + 1)
2653+
else:
2654+
cols = (c_min + 1, c_max + 1)
2655+
has_span = True
2656+
26202657
# Extract a single axes from array if span is provided
26212658
# Otherwise, pass the array as-is for normal colorbar behavior
2622-
if has_span and np.iterable(ax) and not isinstance(ax, (str, maxes.Axes)):
2623-
try:
2624-
ax_single = next(iter(ax))
2659+
if (
2660+
has_span
2661+
and np.iterable(loc_ax)
2662+
and not isinstance(loc_ax, (str, maxes.Axes))
2663+
):
2664+
# Pick the best axis to anchor to based on the colorbar side
2665+
loc_trans = _translate_loc(loc, "colorbar", default=rc["colorbar.loc"])
2666+
side = (
2667+
loc_trans
2668+
if loc_trans in ("left", "right", "top", "bottom")
2669+
else None
2670+
)
2671+
2672+
best_ax = None
2673+
best_coord = float("-inf")
2674+
2675+
# If side is determined, search for the edge axis
2676+
if side:
2677+
for axi in loc_ax:
2678+
if not hasattr(axi, "get_subplotspec"):
2679+
continue
2680+
ss = axi.get_subplotspec().get_topmost_subplotspec()
2681+
r1, r2, c1, c2 = ss._get_rows_columns()
2682+
2683+
if side == "right":
2684+
val = c2 # Maximize column index
2685+
elif side == "left":
2686+
val = -c1 # Minimize column index
2687+
elif side == "bottom":
2688+
val = r2 # Maximize row index
2689+
elif side == "top":
2690+
val = -r1 # Minimize row index
2691+
else:
2692+
val = 0
2693+
2694+
if val > best_coord:
2695+
best_coord = val
2696+
best_ax = axi
26252697

2626-
except (TypeError, StopIteration):
2627-
ax_single = ax
2698+
# Fallback to first axis
2699+
if best_ax is None:
2700+
try:
2701+
ax_single = next(iter(loc_ax))
2702+
except (TypeError, StopIteration):
2703+
ax_single = loc_ax
2704+
else:
2705+
ax_single = best_ax
26282706
else:
2629-
ax_single = ax
2707+
ax_single = loc_ax
26302708

26312709
# Pass span parameters through to axes colorbar
26322710
cb = ax_single.colorbar(

ultraplot/tests/test_legend.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -531,11 +531,6 @@ def test_legend_explicit_handles_labels_override_auto_collection():
531531
assert leg.get_texts()[0].get_text() == "custom_label"
532532

533533

534-
import numpy as np
535-
536-
import ultraplot as uplt
537-
538-
539534
def test_legend_ref_argument():
540535
"""Test using 'ref' to decouple legend location from content axes."""
541536
fig, axs = uplt.subplots(nrows=2, ncols=2)

0 commit comments

Comments
 (0)