diff --git a/ultraplot/axes/base.py b/ultraplot/axes/base.py index dba968bff..5354af4f7 100644 --- a/ultraplot/axes/base.py +++ b/ultraplot/axes/base.py @@ -2792,6 +2792,79 @@ def _reposition_subplot(self): self.update_params() setter(self.figbox) # equivalent to above + # In UltraLayout, place panels relative to their parent axes, not the grid. + if ( + self._panel_parent + and self._panel_side + and self.figure.gridspec._use_ultra_layout + ): + gs = self.get_subplotspec().get_gridspec() + figwidth, figheight = self.figure.get_size_inches() + ss = self.get_subplotspec().get_topmost_subplotspec() + row1, row2, col1, col2 = ss._get_rows_columns(ncols=gs.ncols_total) + side = self._panel_side + parent_bbox = self._panel_parent.get_position() + panels = list(self._panel_parent._panel_dict.get(side, ())) + anchor_ax = self._panel_parent + if self in panels: + idx = panels.index(self) + if idx > 0: + anchor_ax = panels[idx - 1] + elif panels: + anchor_ax = panels[-1] + anchor_bbox = anchor_ax.get_position() + anchor_ss = anchor_ax.get_subplotspec().get_topmost_subplotspec() + a_row1, a_row2, a_col1, a_col2 = anchor_ss._get_rows_columns( + ncols=gs.ncols_total + ) + + if side in ("right", "left"): + boundary = None + width = sum(gs._wratios_total[col1 : col2 + 1]) / figwidth + if a_col2 < col1: + boundary = a_col2 + elif col2 < a_col1: + boundary = col2 + # Fall back to an interface adjacent to this panel + boundary = min( + max( + _not_none(boundary, a_col2 if side == "right" else col2), + 0, + ), + len(gs.wspace_total) - 1, + ) + pad = gs.wspace_total[boundary] / figwidth + if side == "right": + x0 = anchor_bbox.x1 + pad + else: + x0 = anchor_bbox.x0 - pad - width + bbox = mtransforms.Bbox.from_bounds( + x0, parent_bbox.y0, width, parent_bbox.height + ) + else: + boundary = None + height = sum(gs._hratios_total[row1 : row2 + 1]) / figheight + if a_row2 < row1: + boundary = a_row2 + elif row2 < a_row1: + boundary = row2 + boundary = min( + max( + _not_none(boundary, a_row2 if side == "top" else row2), + 0, + ), + len(gs.hspace_total) - 1, + ) + pad = gs.hspace_total[boundary] / figheight + if side == "top": + y0 = anchor_bbox.y1 + pad + else: + y0 = anchor_bbox.y0 - pad - height + bbox = mtransforms.Bbox.from_bounds( + parent_bbox.x0, y0, parent_bbox.width, height + ) + setter(bbox) + def _update_abc(self, **kwargs): """ Update the a-b-c label. diff --git a/ultraplot/axes/plot.py b/ultraplot/axes/plot.py index 526e6ffac..6a16d4ee1 100644 --- a/ultraplot/axes/plot.py +++ b/ultraplot/axes/plot.py @@ -3187,9 +3187,9 @@ def _parse_level_lim( for z in zs: if z is None: # e.g. empty scatter color continue + z = inputs._to_numpy_array(z) if z.ndim > 2: # e.g. imshow data continue - z = inputs._to_numpy_array(z) if inbounds and x is not None and y is not None: # ignore if None coords z = self._inbounds_vlim(x, y, z, to_centers=to_centers) imin, imax = inputs._safe_range(z, pmin, pmax) diff --git a/ultraplot/figure.py b/ultraplot/figure.py index 5d302f318..aa2cd7952 100644 --- a/ultraplot/figure.py +++ b/ultraplot/figure.py @@ -661,7 +661,9 @@ def __init__( ) self._refnum = refnum self._refaspect = refaspect - self._refaspect_default = 1 # updated for imshow and geographic plots + # Default to a square reference aspect; auto_layout will update this when an + # explicit aspect is detected (e.g., imshow, geographic plots). + self._refaspect_default = 1 self._refwidth = units(refwidth, "in") self._refheight = units(refheight, "in") self._figwidth = figwidth = units(figwidth, "in") @@ -1841,7 +1843,7 @@ def _axes_dict(naxs, input, kw=False, default=None): # Create or update the gridspec and add subplots with subplotspecs # NOTE: The gridspec is added to the figure when we pass the subplotspec if gs is None: - gs = pgridspec.GridSpec(*array.shape, **gridspec_kw) + gs = pgridspec.GridSpec(*array.shape, layout_array=array, **gridspec_kw) else: gs.update(**gridspec_kw) axs = naxs * [None] # list of axes @@ -2399,6 +2401,16 @@ def _align_content(): # noqa: E306 gs._auto_layout_tight(renderer) _align_content() + # Finalize figure size using the latest spaces. If the size changes, update + # layout one more time to minimize surrounding whitespace with the new bounds. + figsize = gs._update_figsize() + eps = 0.01 + if self._refwidth is not None or self._refheight is not None: + eps = 0 + if not self._is_same_size(figsize, eps=eps): + # Use zero tolerance so sub-inch adjustments apply when ref sizes are set. + self.set_size_inches(figsize, internal=True, eps=0) + @warnings._rename_kwargs( "0.10.0", mathtext_fallback="uplt.rc.mathtext_fallback = {}" ) diff --git a/ultraplot/gridspec.py b/ultraplot/gridspec.py index 288f1abc4..ca46e26dd 100644 --- a/ultraplot/gridspec.py +++ b/ultraplot/gridspec.py @@ -25,6 +25,14 @@ ) from .utils import _fontsize_to_pt, units +try: + from . import ultralayout + + ULTRA_AVAILABLE = True +except ImportError: + ultralayout = None + ULTRA_AVAILABLE = False + __all__ = ["GridSpec", "SubplotGrid"] @@ -228,6 +236,20 @@ def get_position(self, figure, return_all=False): nrows, ncols = gs.get_total_geometry() else: nrows, ncols = gs.get_geometry() + + # Check if we should use UltraLayout for this subplot + if isinstance(gs, GridSpec) and gs._use_ultra_layout: + bbox = gs._get_ultra_position(self.num1, figure) + if bbox is not None: + if return_all: + rows, cols = np.unravel_index( + [self.num1, self.num2], (nrows, ncols) + ) + return bbox, rows[0], cols[0], nrows, ncols + else: + return bbox + + # Default behavior: use grid positions rows, cols = np.unravel_index([self.num1, self.num2], (nrows, ncols)) bottoms, tops, lefts, rights = gs.get_grid_positions(figure) bottom = bottoms[rows].min() @@ -267,7 +289,14 @@ def __getattr__(self, attr): super().__getattribute__(attr) # native error message @docstring._snippet_manager - def __init__(self, nrows=1, ncols=1, **kwargs): + def __init__( + self, + nrows=1, + ncols=1, + layout_array=None, + ultra_layout: Optional[bool] = None, + **kwargs, + ): """ Parameters ---------- @@ -275,6 +304,14 @@ def __init__(self, nrows=1, ncols=1, **kwargs): The number of rows in the subplot grid. ncols : int, optional The number of columns in the subplot grid. + layout_array : array-like, optional + 2D array specifying the subplot layout, where each unique integer + represents a subplot and 0 represents empty space. When provided, + enables UltraLayout constraint-based positioning (requires + kiwisolver package). + ultra_layout : bool, optional + Whether to use the UltraLayout constraint solver. Defaults to True + when kiwisolver is available. Set to False to use the legacy solver. Other parameters ---------------- @@ -304,6 +341,27 @@ def __init__(self, nrows=1, ncols=1, **kwargs): manually and want the same geometry for multiple figures, you must create a copy with `GridSpec.copy` before working on the subsequent figure). """ + # Layout array for UltraLayout + self._layout_array = ( + np.array(layout_array) if layout_array is not None else None + ) + self._ultra_positions = None # Cache for UltraLayout-computed positions + self._ultra_layout_array = None # Cache for expanded UltraLayout array + self._use_ultra_layout = False # Flag to enable UltraLayout + + # Check if we should use UltraLayout + if ultra_layout is not None: + self._use_ultra_layout = bool(ultra_layout) and ULTRA_AVAILABLE + elif ULTRA_AVAILABLE: + self._use_ultra_layout = True + if ultra_layout and not ULTRA_AVAILABLE: + warnings._warn_ultraplot( + "ultra_layout=True requested but kiwisolver is not available. " + "Falling back to the legacy layout solver." + ) + if self._use_ultra_layout and self._layout_array is None: + self._layout_array = np.arange(1, nrows * ncols + 1).reshape(nrows, ncols) + # Fundamental GridSpec properties self._nrows_total = nrows self._ncols_total = ncols @@ -366,6 +424,162 @@ def __init__(self, nrows=1, ncols=1, **kwargs): } self._update_params(pad=pad, **kwargs) + def _get_ultra_position(self, subplot_num, figure): + """ + Get the position of a subplot using UltraLayout constraint-based positioning. + + Parameters + ---------- + subplot_num : int + The subplot number (in total geometry indexing) + figure : Figure + The matplotlib figure instance + + Returns + ------- + bbox : Bbox or None + The bounding box for the subplot, or None if kiwi layout fails + """ + if not self._use_ultra_layout or self._layout_array is None: + return None + + # Ensure figure is set + if not self.figure: + self._figure = figure + if not self.figure: + return None + + # Compute or retrieve cached UltraLayout positions + if self._ultra_positions is None: + self._compute_ultra_positions() + if self._ultra_positions is None: + return None + layout_array = self._get_ultra_layout_array() + if layout_array is None: + return None + + # Find which subplot number in the layout array corresponds to this subplot_num + # We need to map from the gridspec cell index to the layout array subplot number + nrows, ncols = layout_array.shape + + # Decode the subplot_num to find which layout number it corresponds to + # This is a bit tricky because subplot_num is in total geometry space + # We need to find which unique number in the layout_array this corresponds to + + # Get the cell position from subplot_num + if (nrows, ncols) == self.get_total_geometry(): + row, col = divmod(subplot_num, self.ncols_total) + else: + decoded = self._decode_indices(subplot_num) + row, col = divmod(decoded, ncols) + + # Check if this is within the layout array bounds + if row >= nrows or col >= ncols: + return None + + # Get the layout number at this position + layout_num = layout_array[row, col] + + if layout_num == 0 or layout_num not in self._ultra_positions: + return None + + # Return the cached position + left, bottom, width, height = self._ultra_positions[layout_num] + bbox = mtransforms.Bbox.from_bounds(left, bottom, width, height) + return bbox + + def _compute_ultra_positions(self): + """ + Compute subplot positions using UltraLayout and cache them. + """ + if not ULTRA_AVAILABLE or self._layout_array is None: + return + layout_array = self._get_ultra_layout_array() + if layout_array is None: + return + + # Get figure size + if not self.figure: + return + + figwidth, figheight = self.figure.get_size_inches() + + # Convert spacing to inches (including default ticklabel sizes). + wspace_inches = list(self.wspace_total) + hspace_inches = list(self.hspace_total) + + # Get margins + left = self.left + right = self.right + top = self.top + bottom = self.bottom + + # Compute positions using UltraLayout + try: + self._ultra_positions = ultralayout.compute_ultra_positions( + layout_array, + figwidth=figwidth, + figheight=figheight, + wspace=wspace_inches, + hspace=hspace_inches, + left=left, + right=right, + top=top, + bottom=bottom, + wratios=self._wratios_total, + hratios=self._hratios_total, + wpanels=[bool(val) for val in self._wpanels], + hpanels=[bool(val) for val in self._hpanels], + ) + except Exception as e: + warnings._warn_ultraplot( + f"Failed to compute UltraLayout: {e}. " + "Falling back to default grid layout." + ) + self._use_ultra_layout = False + self._ultra_positions = None + + def _get_ultra_layout_array(self): + """ + Return the layout array expanded to total geometry to include panels. + """ + if self._layout_array is None: + return None + if self._ultra_layout_array is not None: + return self._ultra_layout_array + + nrows_total, ncols_total = self.get_total_geometry() + layout = self._layout_array + if layout.shape == (nrows_total, ncols_total): + self._ultra_layout_array = layout + return layout + + nrows, ncols = self.get_geometry() + if layout.shape != (nrows, ncols): + warnings._warn_ultraplot( + "Layout array shape does not match gridspec geometry; " + "using the original layout array for UltraLayout." + ) + self._ultra_layout_array = layout + return layout + + row_idxs = self._get_indices("h", panel=False) + col_idxs = self._get_indices("w", panel=False) + if len(row_idxs) != nrows or len(col_idxs) != ncols: + warnings._warn_ultraplot( + "Layout array shape does not match non-panel gridspec geometry; " + "using the original layout array for UltraLayout." + ) + self._ultra_layout_array = layout + return layout + + expanded = np.zeros((nrows_total, ncols_total), dtype=layout.dtype) + for i, row_idx in enumerate(row_idxs): + for j, col_idx in enumerate(col_idxs): + expanded[row_idx, col_idx] = layout[i, j] + self._ultra_layout_array = expanded + return expanded + def __getitem__(self, key): """ Get a `~matplotlib.gridspec.SubplotSpec`. "Hidden" slots allocated for axes @@ -489,6 +703,9 @@ def _modify_subplot_geometry(self, newrow=None, newcol=None): """ Update the axes subplot specs by inserting rows and columns as specified. """ + if self._use_ultra_layout: + self._ultra_positions = None + self._ultra_layout_array = None fig = self.figure ncols = self._ncols_total - int(newcol is not None) # previous columns inserts = (newrow, newrow, newcol, newcol) @@ -964,8 +1181,11 @@ def _auto_layout_aspect(self): # Update the layout figsize = self._update_figsize() - if not fig._is_same_size(figsize): - fig.set_size_inches(figsize, internal=True) + eps = 0.01 + if fig._refwidth is not None or fig._refheight is not None: + eps = 0 + if not fig._is_same_size(figsize, eps=eps): + fig.set_size_inches(figsize, internal=True, eps=0) def _auto_layout_tight(self, renderer): """ @@ -1023,8 +1243,11 @@ def _auto_layout_tight(self, renderer): # spaces (necessary since native position coordinates are figure-relative) # and to enforce fixed panel ratios. So only self.update() if we skip resize. figsize = self._update_figsize() - if not fig._is_same_size(figsize): - fig.set_size_inches(figsize, internal=True) + eps = 0.01 + if fig._refwidth is not None or fig._refheight is not None: + eps = 0 # force resize when explicit reference sizing is requested + if not fig._is_same_size(figsize, eps=eps): + fig.set_size_inches(figsize, internal=True, eps=0) else: self.update() @@ -1041,14 +1264,14 @@ def _update_figsize(self): return ss = ax.get_subplotspec().get_topmost_subplotspec() y1, y2, x1, x2 = ss._get_rows_columns() - refhspace = sum(self.hspace_total[y1:y2]) - refwspace = sum(self.wspace_total[x1:x2]) - refhpanel = sum( - self.hratios_total[i] for i in range(y1, y2 + 1) if self._hpanels[i] - ) # noqa: E501 - refwpanel = sum( - self.wratios_total[i] for i in range(x1, x2 + 1) if self._wpanels[i] - ) # noqa: E501 + # NOTE: Reference width/height should correspond to the span of the *axes* + # themselves. Spaces between rows/columns and adjacent panel slots should + # not reduce the target size; those are accounted for separately when the + # full figure size is rebuilt below. + refhspace = 0 + refwspace = 0 + refhpanel = 0 + refwpanel = 0 refhsubplot = sum( self.hratios_total[i] for i in range(y1, y2 + 1) if not self._hpanels[i] ) # noqa: E501 @@ -1060,6 +1283,10 @@ def _update_figsize(self): # NOTE: The sizing arguments should have been normalized already figwidth, figheight = fig._figwidth, fig._figheight refwidth, refheight = fig._refwidth, fig._refheight + if refwidth is not None: + figwidth = None # prefer explicit reference sizing over preset fig size + if refheight is not None: + figheight = None refaspect = _not_none(fig._refaspect, fig._refaspect_default) if refheight is None and figheight is None: if figwidth is not None: @@ -1090,6 +1317,15 @@ def _update_figsize(self): gridwidth = refwidth * self.gridwidth / refwsubplot figwidth = gridwidth + self.spacewidth + self.panelwidth + # Snap explicit reference-driven sizes to the pixel grid to avoid + # rounding the axes width below the requested reference size. + if fig and (fig._refwidth is not None or fig._refheight is not None): + dpi = _not_none(getattr(fig, "dpi", None), 72) + if figwidth is not None: + figwidth = round(figwidth * dpi) / dpi + if figheight is not None: + figheight = round(figheight * dpi) / dpi + # Return the figure size figsize = (figwidth, figheight) if all(np.isfinite(figsize)): @@ -1100,6 +1336,7 @@ def _update_figsize(self): def _update_params( self, *, + ultra_layout=None, left=None, bottom=None, right=None, @@ -1127,6 +1364,20 @@ def _update_params( """ Update the user-specified properties. """ + if ultra_layout is not None: + self._use_ultra_layout = bool(ultra_layout) and ULTRA_AVAILABLE + if ultra_layout and not ULTRA_AVAILABLE: + warnings._warn_ultraplot( + "ultra_layout=True requested but kiwisolver is not available. " + "Falling back to the legacy layout solver." + ) + if self._use_ultra_layout and self._layout_array is None: + nrows, ncols = self.get_geometry() + self._layout_array = np.arange(1, nrows * ncols + 1).reshape( + nrows, ncols + ) + self._ultra_positions = None + self._ultra_layout_array = None # Assign scalar args # WARNING: The key signature here is critical! Used in ui.py to @@ -1219,7 +1470,12 @@ def copy(self, **kwargs): # WARNING: For some reason copy.copy() fails. Updating e.g. wpanels # and hpanels on the copy also updates this object. No idea why. nrows, ncols = self.get_geometry() - gs = GridSpec(nrows, ncols) + gs = GridSpec( + nrows, + ncols, + layout_array=self._layout_array, + ultra_layout=self._use_ultra_layout, + ) hidxs = self._get_indices("h") widxs = self._get_indices("w") gs._hratios_total = [self._hratios_total[i] for i in hidxs] @@ -1384,6 +1640,9 @@ def update(self, **kwargs): # Apply positions to all axes # NOTE: This uses the current figure size to fix panel widths # and determine physical grid spacing. + if self._use_ultra_layout: + self._ultra_positions = None + self._ultra_layout_array = None self._update_params(**kwargs) fig = self.figure if fig is None: @@ -1439,8 +1698,30 @@ def figure(self, fig): get_height_ratios = _disable_method("get_height_ratios") set_width_ratios = _disable_method("set_width_ratios") set_height_ratios = _disable_method("set_height_ratios") - get_subplot_params = _disable_method("get_subplot_params") - locally_modified_subplot_params = _disable_method("locally_modified_subplot_params") + + # Compat: some backends (e.g., Positron) call these for read-only checks. + # We return current margins/spaces without permitting mutation. + def get_subplot_params(self, figure=None): + from matplotlib.figure import SubplotParams + + fig = figure or self.figure + if fig is None: + raise RuntimeError("Figure must be assigned to gridspec.") + # Convert absolute margins to figure-relative floats + width, height = fig.get_size_inches() + left = self.left / width + right = 1 - self.right / width + bottom = self.bottom / height + top = 1 - self.top / height + wspace = sum(self.wspace_total) / width + hspace = sum(self.hspace_total) / height + return SubplotParams( + left=left, right=right, bottom=bottom, top=top, wspace=wspace, hspace=hspace + ) + + def locally_modified_subplot_params(self): + # Backend probe: report False/None semantics (no local mods to MPL params). + return False # Immutable helper properties used to calculate figure size and subplot positions # NOTE: The spaces are auto-filled with defaults wherever user left them unset diff --git a/ultraplot/tests/test_imshow.py b/ultraplot/tests/test_imshow.py index 882deb2de..0f74029e2 100644 --- a/ultraplot/tests/test_imshow.py +++ b/ultraplot/tests/test_imshow.py @@ -1,8 +1,9 @@ +import numpy as np import pytest - -import ultraplot as plt, numpy as np from matplotlib.testing import setup +import ultraplot as plt + @pytest.fixture() def setup_mpl(): diff --git a/ultraplot/tests/test_ultralayout.py b/ultraplot/tests/test_ultralayout.py new file mode 100644 index 000000000..b9b763d53 --- /dev/null +++ b/ultraplot/tests/test_ultralayout.py @@ -0,0 +1,320 @@ +import numpy as np +import pytest + +import ultraplot as uplt +from ultraplot import ultralayout +from ultraplot.gridspec import GridSpec + + +def test_is_orthogonal_layout_simple_grid(): + """Test orthogonal layout detection for simple grids.""" + # Simple 2x2 grid should be orthogonal + array = np.array([[1, 2], [3, 4]]) + assert ultralayout.is_orthogonal_layout(array) is True + + +def test_is_orthogonal_layout_non_orthogonal(): + """Test orthogonal layout detection for non-orthogonal layouts.""" + # Centered subplot with empty cells should be non-orthogonal + array = np.array([[1, 1, 2, 2], [0, 3, 3, 0]]) + assert ultralayout.is_orthogonal_layout(array) is False + + +def test_is_orthogonal_layout_spanning(): + """Test orthogonal layout with spanning subplots that is still orthogonal.""" + # L-shape that maintains grid alignment + array = np.array([[1, 1], [1, 2]]) + assert ultralayout.is_orthogonal_layout(array) is True + + +def test_is_orthogonal_layout_with_gaps(): + """Test non-orthogonal layout with gaps.""" + array = np.array([[1, 1, 1], [2, 0, 3]]) + assert ultralayout.is_orthogonal_layout(array) is False + + +def test_is_orthogonal_layout_empty(): + """Test empty layout.""" + array = np.array([[0, 0], [0, 0]]) + assert ultralayout.is_orthogonal_layout(array) is True + + +def test_gridspec_with_orthogonal_layout(): + """Test that GridSpec activates UltraLayout for orthogonal layouts.""" + pytest.importorskip("kiwisolver") + layout = np.array([[1, 2], [3, 4]]) + gs = GridSpec(2, 2, layout_array=layout) + assert gs._layout_array is not None + # Should use UltraLayout for orthogonal layouts + assert gs._use_ultra_layout is True + + +def test_gridspec_with_non_orthogonal_layout(): + """Test that GridSpec activates UltraLayout for non-orthogonal layouts.""" + pytest.importorskip("kiwisolver") + layout = np.array([[1, 1, 2, 2], [0, 3, 3, 0]]) + gs = GridSpec(2, 4, layout_array=layout) + assert gs._layout_array is not None + # Should use UltraLayout for non-orthogonal layouts + assert gs._use_ultra_layout is True + + +def test_gridspec_without_kiwisolver(monkeypatch): + """Test graceful fallback when kiwisolver is not available.""" + # Mock the ULTRA_AVAILABLE flag + import ultraplot.gridspec as gs_module + + monkeypatch.setattr(gs_module, "ULTRA_AVAILABLE", False) + + layout = np.array([[1, 1, 2, 2], [0, 3, 3, 0]]) + gs = GridSpec(2, 4, layout_array=layout) + # Should not activate UltraLayout if kiwisolver not available + assert gs._use_ultra_layout is False + + +def test_gridspec_ultralayout_opt_out(): + """Test that UltraLayout can be disabled explicitly.""" + pytest.importorskip("kiwisolver") + layout = np.array([[1, 2], [3, 4]]) + gs = GridSpec(2, 2, layout_array=layout, ultra_layout=False) + assert gs._use_ultra_layout is False + + +def test_gridspec_default_layout_array_with_ultralayout(): + """Test that UltraLayout initializes a default layout array.""" + pytest.importorskip("kiwisolver") + gs = GridSpec(2, 3) + assert gs._layout_array is not None + assert gs._layout_array.shape == (2, 3) + assert gs._use_ultra_layout is True + + +def test_ultralayout_solver_initialization(): + """Test UltraLayoutSolver can be initialized.""" + pytest.importorskip("kiwisolver") + layout = np.array([[1, 1, 2, 2], [0, 3, 3, 0]]) + solver = ultralayout.UltraLayoutSolver(layout, figwidth=10.0, figheight=6.0) + assert solver.array is not None + assert solver.nrows == 2 + assert solver.ncols == 4 + + +def test_compute_ultra_positions(): + """Test computing positions with UltraLayout.""" + pytest.importorskip("kiwisolver") + layout = np.array([[1, 1, 2, 2], [0, 3, 3, 0]]) + positions = ultralayout.compute_ultra_positions( + layout, + figwidth=10.0, + figheight=6.0, + wspace=[0.2, 0.2, 0.2], + hspace=[0.2], + ) + + # Should return positions for 3 subplots + assert len(positions) == 3 + assert 1 in positions + assert 2 in positions + assert 3 in positions + + # Each position should be (left, bottom, width, height) + for num, pos in positions.items(): + assert len(pos) == 4 + left, bottom, width, height = pos + assert 0 <= left <= 1 + assert 0 <= bottom <= 1 + assert width > 0 + assert height > 0 + assert left + width <= 1.01 # Allow small numerical error + assert bottom + height <= 1.01 + + +def test_subplots_with_non_orthogonal_layout(): + """Test creating subplots with non-orthogonal layout.""" + pytest.importorskip("kiwisolver") + layout = [[1, 1, 2, 2], [0, 3, 3, 0]] + fig, axs = uplt.subplots(array=layout, figsize=(10, 6)) + + # Should create 3 subplots + assert len(axs) == 3 + + # Check that positions are valid + for ax in axs: + pos = ax.get_position() + assert pos.width > 0 + assert pos.height > 0 + assert 0 <= pos.x0 <= 1 + assert 0 <= pos.y0 <= 1 + + +def test_subplots_with_orthogonal_layout(): + """Test creating subplots with orthogonal layout (should work as before).""" + layout = [[1, 2], [3, 4]] + fig, axs = uplt.subplots(array=layout, figsize=(8, 6)) + + # Should create 4 subplots + assert len(axs) == 4 + + # Check that positions are valid + for ax in axs: + pos = ax.get_position() + assert pos.width > 0 + assert pos.height > 0 + + +def test_ultralayout_respects_spacing(): + """Test that UltraLayout respects spacing parameters.""" + pytest.importorskip("kiwisolver") + layout = np.array([[1, 1, 2, 2], [0, 3, 3, 0]]) + + # Compute with different spacing + positions1 = ultralayout.compute_ultra_positions( + layout, figwidth=10.0, figheight=6.0, wspace=[0.1, 0.1, 0.1], hspace=[0.1] + ) + positions2 = ultralayout.compute_ultra_positions( + layout, figwidth=10.0, figheight=6.0, wspace=[0.5, 0.5, 0.5], hspace=[0.5] + ) + + # Subplots should be smaller with more spacing + for num in [1, 2, 3]: + _, _, width1, height1 = positions1[num] + _, _, width2, height2 = positions2[num] + # With more spacing, subplots should be smaller + assert width2 < width1 or height2 < height1 + + +def test_ultralayout_respects_ratios(): + """Test that UltraLayout respects width/height ratios.""" + pytest.importorskip("kiwisolver") + layout = np.array([[1, 2], [3, 4]]) + + # Equal ratios + positions1 = ultralayout.compute_ultra_positions( + layout, figwidth=10.0, figheight=6.0, wratios=[1, 1], hratios=[1, 1] + ) + + # Unequal ratios + positions2 = ultralayout.compute_ultra_positions( + layout, figwidth=10.0, figheight=6.0, wratios=[1, 2], hratios=[1, 1] + ) + + # Subplot 2 should be wider than subplot 1 with unequal ratios + _, _, width1_1, _ = positions1[1] + _, _, width1_2, _ = positions1[2] + _, _, width2_1, _ = positions2[1] + _, _, width2_2, _ = positions2[2] + + # With equal ratios, widths should be similar + assert abs(width1_1 - width1_2) < 0.01 + # With 1:2 ratio, second should be roughly twice as wide + assert width2_2 > width2_1 + + +def test_ultralayout_with_panels_uses_total_geometry(): + """Test UltraLayout accounts for panel slots in total geometry.""" + pytest.importorskip("kiwisolver") + layout = [[1, 1, 2, 2], [0, 3, 3, 0]] + fig, axs = uplt.subplots(array=layout, figsize=(8, 6)) + + # Add a colorbar to introduce panel slots + mappable = axs[0].imshow([[0, 1], [2, 3]]) + fig.colorbar(mappable, loc="r") + + gs = fig.gridspec + gs._compute_ultra_positions() + assert gs._ultra_layout_array.shape == gs.get_total_geometry() + + row_idxs = gs._get_indices("h", panel=False) + col_idxs = gs._get_indices("w", panel=False) + for i, row_idx in enumerate(row_idxs): + for j, col_idx in enumerate(col_idxs): + assert gs._ultra_layout_array[row_idx, col_idx] == gs._layout_array[i, j] + + ss = axs[0].get_subplotspec() + assert gs._get_ultra_position(ss.num1, fig) is not None + + +def test_ultralayout_cached_positions(): + """Test that UltraLayout positions are cached in GridSpec.""" + pytest.importorskip("kiwisolver") + layout = np.array([[1, 1, 2, 2], [0, 3, 3, 0]]) + gs = GridSpec(2, 4, layout_array=layout) + + # Positions should not be computed yet + assert gs._ultra_positions is None + + # Create a figure to trigger position computation + fig = uplt.figure() + gs._figure = fig + + # Access a position (this should trigger computation) + ss = gs[0, 0] + pos = ss.get_position(fig) + + # Positions should now be cached + assert gs._ultra_positions is not None + assert len(gs._ultra_positions) == 3 + + +def test_ultralayout_with_margins(): + """Test that UltraLayout respects margin parameters.""" + pytest.importorskip("kiwisolver") + layout = np.array([[1, 2]]) + + # Small margins + positions1 = ultralayout.compute_ultra_positions( + layout, figwidth=10.0, figheight=6.0, left=0.1, right=0.1, top=0.1, bottom=0.1 + ) + + # Large margins + positions2 = ultralayout.compute_ultra_positions( + layout, figwidth=10.0, figheight=6.0, left=1.0, right=1.0, top=1.0, bottom=1.0 + ) + + # With larger margins, subplots should be smaller + for num in [1, 2]: + _, _, width1, height1 = positions1[num] + _, _, width2, height2 = positions2[num] + assert width2 < width1 + assert height2 < height1 + + +def test_complex_non_orthogonal_layout(): + """Test a more complex non-orthogonal layout.""" + pytest.importorskip("kiwisolver") + layout = np.array([[1, 1, 1, 2], [3, 3, 0, 2], [4, 5, 5, 5]]) + + positions = ultralayout.compute_ultra_positions( + layout, figwidth=12.0, figheight=9.0 + ) + + # Should have 5 subplots + assert len(positions) == 5 + + # All positions should be valid + for num in range(1, 6): + assert num in positions + left, bottom, width, height = positions[num] + assert 0 <= left <= 1 + assert 0 <= bottom <= 1 + assert width > 0 + assert height > 0 + + +def test_ultralayout_module_exports(): + """Test that ultralayout module exports expected symbols.""" + assert hasattr(ultralayout, "UltraLayoutSolver") + assert hasattr(ultralayout, "compute_ultra_positions") + assert hasattr(ultralayout, "is_orthogonal_layout") + assert hasattr(ultralayout, "get_grid_positions_ultra") + + +def test_gridspec_copy_preserves_layout_array(): + """Test that copying a GridSpec preserves the layout array.""" + layout = np.array([[1, 1, 2, 2], [0, 3, 3, 0]]) + gs1 = GridSpec(2, 4, layout_array=layout) + gs2 = gs1.copy() + + assert gs2._layout_array is not None + assert np.array_equal(gs1._layout_array, gs2._layout_array) + assert gs1._use_ultra_layout == gs2._use_ultra_layout diff --git a/ultraplot/ultralayout.py b/ultraplot/ultralayout.py new file mode 100644 index 000000000..be945d6ab --- /dev/null +++ b/ultraplot/ultralayout.py @@ -0,0 +1,555 @@ +#!/usr/bin/env python3 +""" +UltraLayout: Advanced constraint-based layout system for non-orthogonal subplot arrangements. + +This module provides UltraPlot's constraint-based layout computation for subplot grids +that don't follow simple orthogonal patterns, such as [[1, 1, 2, 2], [0, 3, 3, 0]] +where subplot 3 should be nicely centered between subplots 1 and 2. +""" + +from typing import Dict, List, Optional, Tuple + +import numpy as np + +try: + from kiwisolver import Solver, Variable + + KIWI_AVAILABLE = True +except ImportError: + KIWI_AVAILABLE = False + Variable = None + Solver = None + + +__all__ = [ + "UltraLayoutSolver", + "compute_ultra_positions", + "get_grid_positions_ultra", + "is_orthogonal_layout", +] + + +def is_orthogonal_layout(array: np.ndarray) -> bool: + """ + Check if a subplot array follows an orthogonal (grid-aligned) layout. + + An orthogonal layout is one where every subplot's edges align with + other subplots' edges, forming a simple grid. + + Parameters + ---------- + array : np.ndarray + 2D array of subplot numbers (with 0 for empty cells) + + Returns + ------- + bool + True if layout is orthogonal, False otherwise + """ + if array.size == 0: + return True + + # Get unique subplot numbers (excluding 0) + subplot_nums = np.unique(array[array != 0]) + + if len(subplot_nums) == 0: + return True + + # For each subplot, get its bounding box + bboxes = {} + for num in subplot_nums: + rows, cols = np.where(array == num) + bboxes[num] = { + "row_min": rows.min(), + "row_max": rows.max(), + "col_min": cols.min(), + "col_max": cols.max(), + } + + # Check if layout is orthogonal by verifying that all vertical and + # horizontal edges align with cell boundaries + # A more sophisticated check: for each row/col boundary, check if + # all subplots either cross it or are completely on one side + + # Collect all unique row and column boundaries + row_boundaries = set() + col_boundaries = set() + + for bbox in bboxes.values(): + row_boundaries.add(bbox["row_min"]) + row_boundaries.add(bbox["row_max"] + 1) + col_boundaries.add(bbox["col_min"]) + col_boundaries.add(bbox["col_max"] + 1) + + # Check if these boundaries create a consistent grid + # For orthogonal layout, we should be able to split the grid + # using these boundaries such that each subplot is a union of cells + + row_boundaries = sorted(row_boundaries) + col_boundaries = sorted(col_boundaries) + + # Create a refined grid + refined_rows = len(row_boundaries) - 1 + refined_cols = len(col_boundaries) - 1 + + if refined_rows == 0 or refined_cols == 0: + return True + + # Map each subplot to refined grid cells + for num in subplot_nums: + rows, cols = np.where(array == num) + + # Check if this subplot occupies a rectangular region in the refined grid + refined_row_indices = set() + refined_col_indices = set() + + for r in rows: + for i, (r_start, r_end) in enumerate( + zip(row_boundaries[:-1], row_boundaries[1:]) + ): + if r_start <= r < r_end: + refined_row_indices.add(i) + + for c in cols: + for i, (c_start, c_end) in enumerate( + zip(col_boundaries[:-1], col_boundaries[1:]) + ): + if c_start <= c < c_end: + refined_col_indices.add(i) + + # Check if indices form a rectangle + if refined_row_indices and refined_col_indices: + r_min, r_max = min(refined_row_indices), max(refined_row_indices) + c_min, c_max = min(refined_col_indices), max(refined_col_indices) + + expected_cells = (r_max - r_min + 1) * (c_max - c_min + 1) + actual_cells = len(refined_row_indices) * len(refined_col_indices) + + if expected_cells != actual_cells: + return False + + return True + + +class UltraLayoutSolver: + """ + UltraLayout: Constraint-based layout solver using kiwisolver for subplot positioning. + + This solver computes aesthetically pleasing positions for subplots in + non-orthogonal arrangements by using constraint satisfaction, providing + a superior layout experience for complex subplot arrangements. + """ + + def __init__( + self, + array: np.ndarray, + figwidth: float = 10.0, + figheight: float = 8.0, + wspace: Optional[List[float]] = None, + hspace: Optional[List[float]] = None, + left: float = 0.125, + right: float = 0.125, + top: float = 0.125, + bottom: float = 0.125, + wratios: Optional[List[float]] = None, + hratios: Optional[List[float]] = None, + wpanels: Optional[List[bool]] = None, + hpanels: Optional[List[bool]] = None, + ): + """ + Initialize the UltraLayout solver. + + Parameters + ---------- + array : np.ndarray + 2D array of subplot numbers (with 0 for empty cells) + figwidth, figheight : float + Figure dimensions in inches + wspace, hspace : list of float, optional + Spacing between columns and rows in inches + left, right, top, bottom : float + Margins in inches + wratios, hratios : list of float, optional + Width and height ratios for columns and rows + wpanels, hpanels : list of bool, optional + Flags indicating panel columns or rows with fixed widths/heights. + """ + if not KIWI_AVAILABLE: + raise ImportError( + "kiwisolver is required for non-orthogonal layouts. " + "Install it with: pip install kiwisolver" + ) + + self.array = array + self.nrows, self.ncols = array.shape + self.figwidth = figwidth + self.figheight = figheight + self.left_margin = left + self.right_margin = right + self.top_margin = top + self.bottom_margin = bottom + + # Get subplot numbers + self.subplot_nums = sorted(np.unique(array[array != 0])) + + # Set up spacing + if wspace is None: + self.wspace = [0.2] * (self.ncols - 1) if self.ncols > 1 else [] + else: + self.wspace = list(wspace) + + if hspace is None: + self.hspace = [0.2] * (self.nrows - 1) if self.nrows > 1 else [] + else: + self.hspace = list(hspace) + + # Set up ratios + if wratios is None: + self.wratios = [1.0] * self.ncols + else: + self.wratios = list(wratios) + + if hratios is None: + self.hratios = [1.0] * self.nrows + else: + self.hratios = list(hratios) + + # Set up panel flags (True for fixed-width panel slots). + if wpanels is None: + self.wpanels = [False] * self.ncols + else: + if len(wpanels) != self.ncols: + raise ValueError("wpanels length must match number of columns.") + self.wpanels = [bool(val) for val in wpanels] + if hpanels is None: + self.hpanels = [False] * self.nrows + else: + if len(hpanels) != self.nrows: + raise ValueError("hpanels length must match number of rows.") + self.hpanels = [bool(val) for val in hpanels] + + # Initialize solver + self.solver = Solver() + self._setup_variables() + self._setup_constraints() + + def _setup_variables(self): + """Create kiwisolver variables for all grid lines.""" + # Vertical lines (left edges of columns + right edge of last column) + self.col_lefts = [Variable(f"col_{i}_left") for i in range(self.ncols)] + self.col_rights = [Variable(f"col_{i}_right") for i in range(self.ncols)] + + # Horizontal lines (top edges of rows + bottom edge of last row) + # Note: in figure coordinates, top is higher value + self.row_tops = [Variable(f"row_{i}_top") for i in range(self.nrows)] + self.row_bottoms = [Variable(f"row_{i}_bottom") for i in range(self.nrows)] + + def _setup_constraints(self): + """Set up all constraints for the layout.""" + # 1. Figure boundary constraints + self.solver.addConstraint(self.col_lefts[0] == self.left_margin / self.figwidth) + self.solver.addConstraint( + self.col_rights[-1] == 1.0 - self.right_margin / self.figwidth + ) + self.solver.addConstraint( + self.row_bottoms[-1] == self.bottom_margin / self.figheight + ) + self.solver.addConstraint( + self.row_tops[0] == 1.0 - self.top_margin / self.figheight + ) + + # 2. Column continuity and spacing constraints + for i in range(self.ncols - 1): + # Right edge of column i connects to left edge of column i+1 with spacing + spacing = self.wspace[i] / self.figwidth if i < len(self.wspace) else 0 + self.solver.addConstraint( + self.col_rights[i] + spacing == self.col_lefts[i + 1] + ) + + # 3. Row continuity and spacing constraints + for i in range(self.nrows - 1): + # Bottom edge of row i connects to top edge of row i+1 with spacing + spacing = self.hspace[i] / self.figheight if i < len(self.hspace) else 0 + self.solver.addConstraint( + self.row_bottoms[i] == self.row_tops[i + 1] + spacing + ) + + # 4. Width constraints (panel slots are fixed, remaining slots use ratios) + total_width = 1.0 - (self.left_margin + self.right_margin) / self.figwidth + if self.ncols > 1: + spacing_total = sum(self.wspace) / self.figwidth + else: + spacing_total = 0 + available_width = total_width - spacing_total + fixed_width = 0.0 + ratio_sum = 0.0 + for i in range(self.ncols): + if self.wpanels[i]: + fixed_width += self.wratios[i] / self.figwidth + else: + ratio_sum += self.wratios[i] + remaining_width = max(0.0, available_width - fixed_width) + if ratio_sum == 0: + ratio_sum = 1.0 + + for i in range(self.ncols): + if self.wpanels[i]: + width = self.wratios[i] / self.figwidth + else: + width = remaining_width * self.wratios[i] / ratio_sum + self.solver.addConstraint(self.col_rights[i] == self.col_lefts[i] + width) + + # 5. Height constraints (panel slots are fixed, remaining slots use ratios) + total_height = 1.0 - (self.top_margin + self.bottom_margin) / self.figheight + if self.nrows > 1: + spacing_total = sum(self.hspace) / self.figheight + else: + spacing_total = 0 + available_height = total_height - spacing_total + fixed_height = 0.0 + ratio_sum = 0.0 + for i in range(self.nrows): + if self.hpanels[i]: + fixed_height += self.hratios[i] / self.figheight + else: + ratio_sum += self.hratios[i] + remaining_height = max(0.0, available_height - fixed_height) + if ratio_sum == 0: + ratio_sum = 1.0 + + for i in range(self.nrows): + if self.hpanels[i]: + height = self.hratios[i] / self.figheight + else: + height = remaining_height * self.hratios[i] / ratio_sum + self.solver.addConstraint(self.row_tops[i] == self.row_bottoms[i] + height) + + def solve(self) -> Dict[int, Tuple[float, float, float, float]]: + """ + Solve the constraint system and return subplot positions. + + Returns + ------- + dict + Dictionary mapping subplot numbers to (left, bottom, width, height) + in figure-relative coordinates [0, 1] + """ + # Solve the constraint system + self.solver.updateVariables() + + # Extract positions for each subplot + positions = {} + col_lefts = [v.value() for v in self.col_lefts] + col_rights = [v.value() for v in self.col_rights] + row_tops = [v.value() for v in self.row_tops] + row_bottoms = [v.value() for v in self.row_bottoms] + col_widths = [right - left for left, right in zip(col_lefts, col_rights)] + row_heights = [top - bottom for top, bottom in zip(row_tops, row_bottoms)] + + base_wgap = None + for i in range(self.ncols - 1): + if not self.wpanels[i] and not self.wpanels[i + 1]: + gap = col_lefts[i + 1] - col_rights[i] + if base_wgap is None or gap < base_wgap: + base_wgap = gap + if base_wgap is None: + base_wgap = 0.0 + + base_hgap = None + for i in range(self.nrows - 1): + if not self.hpanels[i] and not self.hpanels[i + 1]: + gap = row_bottoms[i] - row_tops[i + 1] + if base_hgap is None or gap < base_hgap: + base_hgap = gap + if base_hgap is None: + base_hgap = 0.0 + + def _adjust_span( + spans: List[int], + start: float, + end: float, + sizes: List[float], + panels: List[bool], + base_gap: float, + ) -> Tuple[float, float]: + effective = [i for i in spans if not panels[i]] + if len(effective) <= 1: + return start, end + desired = sum(sizes[i] for i in effective) + # Collapse inter-column/row gaps inside spans to keep widths consistent. + # This avoids widening subplots that cross internal panel slots. + full = end - start + if desired < full: + offset = 0.5 * (full - desired) + start = start + offset + end = start + desired + return start, end + + for num in self.subplot_nums: + rows, cols = np.where(self.array == num) + row_min, row_max = rows.min(), rows.max() + col_min, col_max = cols.min(), cols.max() + + # Get the bounding box from the grid lines + left = col_lefts[col_min] + right = col_rights[col_max] + bottom = row_bottoms[row_max] + top = row_tops[row_min] + + span_cols = list(range(col_min, col_max + 1)) + span_rows = list(range(row_min, row_max + 1)) + + left, right = _adjust_span( + span_cols, + left, + right, + col_widths, + self.wpanels, + base_wgap, + ) + top, bottom = _adjust_span( + span_rows, + top, + bottom, + row_heights, + self.hpanels, + base_hgap, + ) + + width = right - left + height = top - bottom + + positions[num] = (left, bottom, width, height) + + return positions + + +def compute_ultra_positions( + array: np.ndarray, + figwidth: float = 10.0, + figheight: float = 8.0, + wspace: Optional[List[float]] = None, + hspace: Optional[List[float]] = None, + left: float = 0.125, + right: float = 0.125, + top: float = 0.125, + bottom: float = 0.125, + wratios: Optional[List[float]] = None, + hratios: Optional[List[float]] = None, + wpanels: Optional[List[bool]] = None, + hpanels: Optional[List[bool]] = None, +) -> Dict[int, Tuple[float, float, float, float]]: + """ + Compute subplot positions using UltraLayout for non-orthogonal layouts. + + Parameters + ---------- + array : np.ndarray + 2D array of subplot numbers (with 0 for empty cells) + figwidth, figheight : float + Figure dimensions in inches + wspace, hspace : list of float, optional + Spacing between columns and rows in inches + left, right, top, bottom : float + Margins in inches + wratios, hratios : list of float, optional + Width and height ratios for columns and rows + wpanels, hpanels : list of bool, optional + Flags indicating panel columns or rows with fixed widths/heights. + + Returns + ------- + dict + Dictionary mapping subplot numbers to (left, bottom, width, height) + in figure-relative coordinates [0, 1] + + Examples + -------- + >>> array = np.array([[1, 1, 2, 2], [0, 3, 3, 0]]) + >>> positions = compute_ultra_positions(array) + >>> positions[3] # Position of subplot 3 + (0.25, 0.125, 0.5, 0.35) + """ + solver = UltraLayoutSolver( + array, + figwidth, + figheight, + wspace, + hspace, + left, + right, + top, + bottom, + wratios, + hratios, + wpanels, + hpanels, + ) + return solver.solve() + + +def get_grid_positions_ultra( + array: np.ndarray, + figwidth: float, + figheight: float, + wspace: Optional[List[float]] = None, + hspace: Optional[List[float]] = None, + left: float = 0.125, + right: float = 0.125, + top: float = 0.125, + bottom: float = 0.125, + wratios: Optional[List[float]] = None, + hratios: Optional[List[float]] = None, + wpanels: Optional[List[bool]] = None, + hpanels: Optional[List[bool]] = None, +) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + """ + Get grid line positions using UltraLayout. + + This returns arrays of grid line positions similar to GridSpec.get_grid_positions(), + but computed using UltraLayout's constraint satisfaction for better handling of non-orthogonal layouts. + + Parameters + ---------- + array : np.ndarray + 2D array of subplot numbers + figwidth, figheight : float + Figure dimensions in inches + wspace, hspace : list of float, optional + Spacing between columns and rows in inches + left, right, top, bottom : float + Margins in inches + wratios, hratios : list of float, optional + Width and height ratios for columns and rows + wpanels, hpanels : list of bool, optional + Flags indicating panel columns or rows with fixed widths/heights. + + Returns + ------- + bottoms, tops, lefts, rights : np.ndarray + Arrays of grid line positions for each cell + """ + solver = UltraLayoutSolver( + array, + figwidth, + figheight, + wspace, + hspace, + left, + right, + top, + bottom, + wratios, + hratios, + wpanels, + hpanels, + ) + solver.solver.updateVariables() + + # Extract grid line positions + lefts = np.array([v.value() for v in solver.col_lefts]) + rights = np.array([v.value() for v in solver.col_rights]) + tops = np.array([v.value() for v in solver.row_tops]) + bottoms = np.array([v.value() for v in solver.row_bottoms]) + + return bottoms, tops, lefts, rights