diff --git a/ultraplot/figure.py b/ultraplot/figure.py index d7f33f8e..e7887088 100644 --- a/ultraplot/figure.py +++ b/ultraplot/figure.py @@ -2644,6 +2644,14 @@ def colorbar( continue ss = ss.get_topmost_subplotspec() r1, r2, c1, c2 = ss._get_rows_columns() + gs = ss.get_gridspec() + if gs is not None: + try: + r1, r2 = gs._decode_indices(r1, r2, which="h") + c1, c2 = gs._decode_indices(c1, c2, which="w") + except ValueError: + # Non-panel decode can fail for panel or nested specs. + pass r_min = min(r_min, r1) r_max = max(r_max, r2) c_min = min(c_min, c1) @@ -2685,6 +2693,14 @@ def colorbar( continue ss = ss.get_topmost_subplotspec() r1, r2, c1, c2 = ss._get_rows_columns() + gs = ss.get_gridspec() + if gs is not None: + try: + r1, r2 = gs._decode_indices(r1, r2, which="h") + c1, c2 = gs._decode_indices(c1, c2, which="w") + except ValueError: + # Non-panel decode can fail for panel or nested specs. + pass if side == "right": val = c2 # Maximize column index @@ -2840,6 +2856,13 @@ def legend( continue ss = ss.get_topmost_subplotspec() r1, r2, c1, c2 = ss._get_rows_columns() + gs = ss.get_gridspec() + if gs is not None: + try: + r1, r2 = gs._decode_indices(r1, r2, which="h") + c1, c2 = gs._decode_indices(c1, c2, which="w") + except ValueError: + pass r_min = min(r_min, r1) r_max = max(r_max, r2) c_min = min(c_min, c1) @@ -2881,6 +2904,13 @@ def legend( continue ss = ss.get_topmost_subplotspec() r1, r2, c1, c2 = ss._get_rows_columns() + gs = ss.get_gridspec() + if gs is not None: + try: + r1, r2 = gs._decode_indices(r1, r2, which="h") + c1, c2 = gs._decode_indices(c1, c2, which="w") + except ValueError: + pass if side == "right": val = c2 # Maximize column index diff --git a/ultraplot/tests/test_legend.py b/ultraplot/tests/test_legend.py index a37f2ff0..f9157ddd 100644 --- a/ultraplot/tests/test_legend.py +++ b/ultraplot/tests/test_legend.py @@ -7,6 +7,7 @@ import pytest import ultraplot as uplt +from ultraplot.axes import Axes as UAxes @pytest.mark.mpl_image_compare @@ -613,7 +614,137 @@ def test_ref_with_manual_axes_no_subplotspec(): ax1 = fig.add_axes([0.1, 0.1, 0.4, 0.4]) ax2 = fig.add_axes([0.5, 0.1, 0.4, 0.4]) ax1.plot([0, 1], [0, 1], label="line") - # ref=[ax1, ax2]. loc='upper right' (inset). leg = fig.legend(ref=[ax1, ax2], loc="upper right") assert leg is not None + + +def _decode_panel_span(panel_ax, axis): + ss = panel_ax.get_subplotspec().get_topmost_subplotspec() + r1, r2, c1, c2 = ss._get_rows_columns() + gs = ss.get_gridspec() + if axis == "rows": + r1, r2 = gs._decode_indices(r1, r2, which="h") + return int(r1), int(r2) + if axis == "cols": + c1, c2 = gs._decode_indices(c1, c2, which="w") + return int(c1), int(c2) + raise ValueError(f"Unknown axis {axis!r}.") + + +def _anchor_axis(ref): + if np.iterable(ref) and not isinstance(ref, (str, UAxes)): + return next(iter(ref)) + return ref + + +@pytest.mark.parametrize( + "first_loc, first_ref, second_loc, second_ref, span_axis", + [ + ("b", lambda axs: axs[0], "r", lambda axs: axs[:, 1], "rows"), + ("r", lambda axs: axs[:, 2], "b", lambda axs: axs[1, :], "cols"), + ("t", lambda axs: axs[2], "l", lambda axs: axs[:, 0], "rows"), + ("l", lambda axs: axs[:, 0], "t", lambda axs: axs[1, :], "cols"), + ], +) +def test_legend_span_inference_with_multi_panels( + first_loc, first_ref, second_loc, second_ref, span_axis +): + fig, axs = uplt.subplots(nrows=3, ncols=3) + axs.plot([0, 1], [0, 1], label="line") + + fig.legend(ref=first_ref(axs), loc=first_loc) + fig.legend(ref=second_ref(axs), loc=second_loc) + + side_map = {"l": "left", "r": "right", "t": "top", "b": "bottom"} + anchor = _anchor_axis(second_ref(axs)) + panel_ax = anchor._panel_dict[side_map[second_loc]][-1] + span = _decode_panel_span(panel_ax, span_axis) + assert span == (0, 2) + + +def test_legend_best_axis_selection_right_left(): + fig, axs = uplt.subplots(nrows=1, ncols=3) + axs.plot([0, 1], [0, 1], label="line") + ref = [axs[0, 0], axs[0, 2]] + + fig.legend(ref=ref, loc="r", rows=1) + assert len(axs[0, 2]._panel_dict["right"]) == 1 + assert len(axs[0, 0]._panel_dict["right"]) == 0 + + fig.legend(ref=ref, loc="l", rows=1) + assert len(axs[0, 0]._panel_dict["left"]) == 1 + assert len(axs[0, 2]._panel_dict["left"]) == 0 + + +def test_legend_best_axis_selection_top_bottom(): + fig, axs = uplt.subplots(nrows=2, ncols=1) + axs.plot([0, 1], [0, 1], label="line") + ref = [axs[0, 0], axs[1, 0]] + + fig.legend(ref=ref, loc="t", cols=1) + assert len(axs[0, 0]._panel_dict["top"]) == 1 + assert len(axs[1, 0]._panel_dict["top"]) == 0 + + fig.legend(ref=ref, loc="b", cols=1) + assert len(axs[1, 0]._panel_dict["bottom"]) == 1 + assert len(axs[0, 0]._panel_dict["bottom"]) == 0 + + +def test_legend_span_decode_fallback(monkeypatch): + fig, axs = uplt.subplots(nrows=2, ncols=2) + axs.plot([0, 1], [0, 1], label="line") + ref = axs[:, 0] + + gs = axs[0, 0].get_subplotspec().get_topmost_subplotspec().get_gridspec() + + def _raise_decode(*args, **kwargs): + raise ValueError("forced") + + monkeypatch.setattr(gs, "_decode_indices", _raise_decode) + leg = fig.legend(ref=ref, loc="r") + assert leg is not None + + +def test_legend_span_inference_skips_invalid_ref_axes(): + class DummyNoSpec: + pass + + class DummyNullSpec: + def get_subplotspec(self): + return None + + fig, axs = uplt.subplots(nrows=1, ncols=2) + axs[0].plot([0, 1], [0, 1], label="line") + ref = [DummyNoSpec(), DummyNullSpec(), axs[0]] + + leg = fig.legend(ax=axs[0], ref=ref, loc="r") + assert leg is not None + assert len(axs[0]._panel_dict["right"]) == 1 + + +def test_legend_best_axis_fallback_with_inset_loc(): + fig, axs = uplt.subplots(nrows=1, ncols=2) + axs.plot([0, 1], [0, 1], label="line") + + leg = fig.legend(ref=axs, loc="upper left", rows=1) + assert leg is not None + + +def test_legend_best_axis_fallback_empty_iterable_ref(): + class LegendProxy: + def __init__(self, ax): + self._ax = ax + + def __iter__(self): + return iter(()) + + def legend(self, *args, **kwargs): + return self._ax.legend(*args, **kwargs) + + fig, ax = uplt.subplots() + ax.plot([0, 1], [0, 1], label="line") + proxy = LegendProxy(ax) + + leg = fig.legend(ref=proxy, loc="upper left", rows=1) + assert leg is not None