Skip to content

Commit 4f2f8c0

Browse files
authored
added inference of labels for spanning legends (#447)
1 parent 5895f8a commit 4f2f8c0

File tree

2 files changed

+54
-1
lines changed

2 files changed

+54
-1
lines changed

ultraplot/figure.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2704,10 +2704,17 @@ def legend(
27042704
if ax is not None:
27052705
# Check if span parameters are provided
27062706
has_span = _not_none(span, row, col, rows, cols) is not None
2707-
27082707
# Extract a single axes from array if span is provided
27092708
# Otherwise, pass the array as-is for normal legend behavior
2709+
# Automatically collect handles and labels from spanned axes if not provided
27102710
if has_span and np.iterable(ax) and not isinstance(ax, (str, maxes.Axes)):
2711+
# Auto-collect handles and labels if not explicitly provided
2712+
if handles is None and labels is None:
2713+
handles, labels = [], []
2714+
for axi in ax:
2715+
h, l = axi.get_legend_handles_labels()
2716+
handles.extend(h)
2717+
labels.extend(l)
27112718
try:
27122719
ax_single = next(iter(ax))
27132720
except (TypeError, StopIteration):

ultraplot/tests/test_legend.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -483,3 +483,49 @@ def test_legend_multiple_sides_with_span():
483483
assert leg_top is not None
484484
assert leg_right is not None
485485
assert leg_left is not None
486+
487+
488+
def test_legend_auto_collect_handles_labels_with_span():
489+
"""Test automatic collection of handles and labels from multiple axes with span parameters."""
490+
491+
fig, axs = uplt.subplots(nrows=2, ncols=2)
492+
493+
# Create different plots in each subplot with labels
494+
axs[0, 0].plot([0, 1], [0, 1], label="line1")
495+
axs[0, 1].plot([0, 1], [1, 0], label="line2")
496+
axs[1, 0].scatter([0.5], [0.5], label="point1")
497+
axs[1, 1].scatter([0.5], [0.5], label="point2")
498+
499+
# Test automatic collection with span parameter (no explicit handles/labels)
500+
leg = fig.legend(ax=axs[0, :], span=(1, 2), loc="bottom")
501+
502+
# Verify legend was created and contains all handles/labels from both axes
503+
assert leg is not None
504+
assert len(leg.get_texts()) == 2 # Should have 2 labels (line1, line2)
505+
506+
# Test with rows parameter
507+
leg2 = fig.legend(ax=axs[:, 0], rows=(1, 2), loc="right")
508+
assert leg2 is not None
509+
assert len(leg2.get_texts()) == 2 # Should have 2 labels (line1, point1)
510+
511+
512+
def test_legend_explicit_handles_labels_override_auto_collection():
513+
"""Test that explicit handles/labels override auto-collection."""
514+
515+
fig, axs = uplt.subplots(nrows=1, ncols=2)
516+
517+
# Create plots with labels
518+
(h1,) = axs[0].plot([0, 1], [0, 1], label="auto_label1")
519+
(h2,) = axs[1].plot([0, 1], [1, 0], label="auto_label2")
520+
521+
# Test with explicit handles/labels (should override auto-collection)
522+
custom_handles = [h1]
523+
custom_labels = ["custom_label"]
524+
leg = fig.legend(
525+
ax=axs, span=(1, 2), loc="bottom", handles=custom_handles, labels=custom_labels
526+
)
527+
528+
# Verify legend uses explicit handles/labels, not auto-collected ones
529+
assert leg is not None
530+
assert len(leg.get_texts()) == 1
531+
assert leg.get_texts()[0].get_text() == "custom_label"

0 commit comments

Comments
 (0)