Skip to content

Commit c6a3a52

Browse files
committed
Fix legend span inference with panels (#469)
* Fix legend span inference with panels Legend span inference used panel-inflated indices after prior legends added panel rows/cols, yielding invalid gridspec indices for list refs. Decode subplot indices to non-panel grid before computing span and add regression tests for multi-legend ordering. * Restore tests * Document legend span decode fallback Add a brief note that decoding panel indices can fail for panel or nested subplot specs, so we fall back to raw indices. * Add legend span/selection regression tests Cover best-axis selection for left/right/top/bottom and the decode-index fallback path to raise coverage around Figure.legend panel inference. * Extend legend coverage for edge ref handling Add tests that cover span inference with invalid ref entries, best-axis fallback on inset locations, and the empty-iterable ref fallback path.
1 parent eaa88b3 commit c6a3a52

File tree

2 files changed

+162
-1
lines changed

2 files changed

+162
-1
lines changed

ultraplot/figure.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2644,6 +2644,14 @@ def colorbar(
26442644
continue
26452645
ss = ss.get_topmost_subplotspec()
26462646
r1, r2, c1, c2 = ss._get_rows_columns()
2647+
gs = ss.get_gridspec()
2648+
if gs is not None:
2649+
try:
2650+
r1, r2 = gs._decode_indices(r1, r2, which="h")
2651+
c1, c2 = gs._decode_indices(c1, c2, which="w")
2652+
except ValueError:
2653+
# Non-panel decode can fail for panel or nested specs.
2654+
pass
26472655
r_min = min(r_min, r1)
26482656
r_max = max(r_max, r2)
26492657
c_min = min(c_min, c1)
@@ -2685,6 +2693,14 @@ def colorbar(
26852693
continue
26862694
ss = ss.get_topmost_subplotspec()
26872695
r1, r2, c1, c2 = ss._get_rows_columns()
2696+
gs = ss.get_gridspec()
2697+
if gs is not None:
2698+
try:
2699+
r1, r2 = gs._decode_indices(r1, r2, which="h")
2700+
c1, c2 = gs._decode_indices(c1, c2, which="w")
2701+
except ValueError:
2702+
# Non-panel decode can fail for panel or nested specs.
2703+
pass
26882704

26892705
if side == "right":
26902706
val = c2 # Maximize column index
@@ -2840,6 +2856,13 @@ def legend(
28402856
continue
28412857
ss = ss.get_topmost_subplotspec()
28422858
r1, r2, c1, c2 = ss._get_rows_columns()
2859+
gs = ss.get_gridspec()
2860+
if gs is not None:
2861+
try:
2862+
r1, r2 = gs._decode_indices(r1, r2, which="h")
2863+
c1, c2 = gs._decode_indices(c1, c2, which="w")
2864+
except ValueError:
2865+
pass
28432866
r_min = min(r_min, r1)
28442867
r_max = max(r_max, r2)
28452868
c_min = min(c_min, c1)
@@ -2881,6 +2904,13 @@ def legend(
28812904
continue
28822905
ss = ss.get_topmost_subplotspec()
28832906
r1, r2, c1, c2 = ss._get_rows_columns()
2907+
gs = ss.get_gridspec()
2908+
if gs is not None:
2909+
try:
2910+
r1, r2 = gs._decode_indices(r1, r2, which="h")
2911+
c1, c2 = gs._decode_indices(c1, c2, which="w")
2912+
except ValueError:
2913+
pass
28842914

28852915
if side == "right":
28862916
val = c2 # Maximize column index

ultraplot/tests/test_legend.py

Lines changed: 132 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import pytest
88

99
import ultraplot as uplt
10+
from ultraplot.axes import Axes as UAxes
1011

1112

1213
@pytest.mark.mpl_image_compare
@@ -613,7 +614,137 @@ def test_ref_with_manual_axes_no_subplotspec():
613614
ax1 = fig.add_axes([0.1, 0.1, 0.4, 0.4])
614615
ax2 = fig.add_axes([0.5, 0.1, 0.4, 0.4])
615616
ax1.plot([0, 1], [0, 1], label="line")
616-
617617
# ref=[ax1, ax2]. loc='upper right' (inset).
618618
leg = fig.legend(ref=[ax1, ax2], loc="upper right")
619619
assert leg is not None
620+
621+
622+
def _decode_panel_span(panel_ax, axis):
623+
ss = panel_ax.get_subplotspec().get_topmost_subplotspec()
624+
r1, r2, c1, c2 = ss._get_rows_columns()
625+
gs = ss.get_gridspec()
626+
if axis == "rows":
627+
r1, r2 = gs._decode_indices(r1, r2, which="h")
628+
return int(r1), int(r2)
629+
if axis == "cols":
630+
c1, c2 = gs._decode_indices(c1, c2, which="w")
631+
return int(c1), int(c2)
632+
raise ValueError(f"Unknown axis {axis!r}.")
633+
634+
635+
def _anchor_axis(ref):
636+
if np.iterable(ref) and not isinstance(ref, (str, UAxes)):
637+
return next(iter(ref))
638+
return ref
639+
640+
641+
@pytest.mark.parametrize(
642+
"first_loc, first_ref, second_loc, second_ref, span_axis",
643+
[
644+
("b", lambda axs: axs[0], "r", lambda axs: axs[:, 1], "rows"),
645+
("r", lambda axs: axs[:, 2], "b", lambda axs: axs[1, :], "cols"),
646+
("t", lambda axs: axs[2], "l", lambda axs: axs[:, 0], "rows"),
647+
("l", lambda axs: axs[:, 0], "t", lambda axs: axs[1, :], "cols"),
648+
],
649+
)
650+
def test_legend_span_inference_with_multi_panels(
651+
first_loc, first_ref, second_loc, second_ref, span_axis
652+
):
653+
fig, axs = uplt.subplots(nrows=3, ncols=3)
654+
axs.plot([0, 1], [0, 1], label="line")
655+
656+
fig.legend(ref=first_ref(axs), loc=first_loc)
657+
fig.legend(ref=second_ref(axs), loc=second_loc)
658+
659+
side_map = {"l": "left", "r": "right", "t": "top", "b": "bottom"}
660+
anchor = _anchor_axis(second_ref(axs))
661+
panel_ax = anchor._panel_dict[side_map[second_loc]][-1]
662+
span = _decode_panel_span(panel_ax, span_axis)
663+
assert span == (0, 2)
664+
665+
666+
def test_legend_best_axis_selection_right_left():
667+
fig, axs = uplt.subplots(nrows=1, ncols=3)
668+
axs.plot([0, 1], [0, 1], label="line")
669+
ref = [axs[0, 0], axs[0, 2]]
670+
671+
fig.legend(ref=ref, loc="r", rows=1)
672+
assert len(axs[0, 2]._panel_dict["right"]) == 1
673+
assert len(axs[0, 0]._panel_dict["right"]) == 0
674+
675+
fig.legend(ref=ref, loc="l", rows=1)
676+
assert len(axs[0, 0]._panel_dict["left"]) == 1
677+
assert len(axs[0, 2]._panel_dict["left"]) == 0
678+
679+
680+
def test_legend_best_axis_selection_top_bottom():
681+
fig, axs = uplt.subplots(nrows=2, ncols=1)
682+
axs.plot([0, 1], [0, 1], label="line")
683+
ref = [axs[0, 0], axs[1, 0]]
684+
685+
fig.legend(ref=ref, loc="t", cols=1)
686+
assert len(axs[0, 0]._panel_dict["top"]) == 1
687+
assert len(axs[1, 0]._panel_dict["top"]) == 0
688+
689+
fig.legend(ref=ref, loc="b", cols=1)
690+
assert len(axs[1, 0]._panel_dict["bottom"]) == 1
691+
assert len(axs[0, 0]._panel_dict["bottom"]) == 0
692+
693+
694+
def test_legend_span_decode_fallback(monkeypatch):
695+
fig, axs = uplt.subplots(nrows=2, ncols=2)
696+
axs.plot([0, 1], [0, 1], label="line")
697+
ref = axs[:, 0]
698+
699+
gs = axs[0, 0].get_subplotspec().get_topmost_subplotspec().get_gridspec()
700+
701+
def _raise_decode(*args, **kwargs):
702+
raise ValueError("forced")
703+
704+
monkeypatch.setattr(gs, "_decode_indices", _raise_decode)
705+
leg = fig.legend(ref=ref, loc="r")
706+
assert leg is not None
707+
708+
709+
def test_legend_span_inference_skips_invalid_ref_axes():
710+
class DummyNoSpec:
711+
pass
712+
713+
class DummyNullSpec:
714+
def get_subplotspec(self):
715+
return None
716+
717+
fig, axs = uplt.subplots(nrows=1, ncols=2)
718+
axs[0].plot([0, 1], [0, 1], label="line")
719+
ref = [DummyNoSpec(), DummyNullSpec(), axs[0]]
720+
721+
leg = fig.legend(ax=axs[0], ref=ref, loc="r")
722+
assert leg is not None
723+
assert len(axs[0]._panel_dict["right"]) == 1
724+
725+
726+
def test_legend_best_axis_fallback_with_inset_loc():
727+
fig, axs = uplt.subplots(nrows=1, ncols=2)
728+
axs.plot([0, 1], [0, 1], label="line")
729+
730+
leg = fig.legend(ref=axs, loc="upper left", rows=1)
731+
assert leg is not None
732+
733+
734+
def test_legend_best_axis_fallback_empty_iterable_ref():
735+
class LegendProxy:
736+
def __init__(self, ax):
737+
self._ax = ax
738+
739+
def __iter__(self):
740+
return iter(())
741+
742+
def legend(self, *args, **kwargs):
743+
return self._ax.legend(*args, **kwargs)
744+
745+
fig, ax = uplt.subplots()
746+
ax.plot([0, 1], [0, 1], label="line")
747+
proxy = LegendProxy(ax)
748+
749+
leg = fig.legend(ref=proxy, loc="upper left", rows=1)
750+
assert leg is not None

0 commit comments

Comments
 (0)