diff --git a/celerpy/model.py b/celerpy/model.py index b37a56b..daece34 100644 --- a/celerpy/model.py +++ b/celerpy/model.py @@ -89,7 +89,7 @@ class ImageInput(_Model): rightward: Real3 = [1, 0, 0] "Ray trace direction which points to the right in the image" - vertical_pixels: NonNegativeInt + vertical_pixels: NonNegativeInt = 512 "Number of pixels along the y axis" horizontal_divisor: Optional[PositiveInt] = None diff --git a/celerpy/visualize.py b/celerpy/visualize.py index 67d84a9..0167ef9 100644 --- a/celerpy/visualize.py +++ b/celerpy/visualize.py @@ -7,12 +7,12 @@ import json import re import warnings -from collections.abc import Mapping, MutableSequence +from collections.abc import Iterable, Mapping, MutableSequence from importlib.resources import files from pathlib import Path from subprocess import TimeoutExpired from tempfile import NamedTemporaryFile -from typing import Any, Optional +from typing import Any, Optional, Union import matplotlib.pyplot as plt import numpy as np @@ -307,25 +307,32 @@ def __call__( def plot_all_geometry( - trace_image: Imager, *, colorbar=True, figsize=None + trace_image: Imager, + *, + colorbar: bool = True, + figsize: Optional[tuple] = None, + engines: Optional[Iterable] = None, ) -> Mapping[model.GeometryEngine, Any]: """Convenience function for plotting all available geometry types.""" - width_ratios = [1.0] * len(model.GeometryEngine) + if engines is None: + engines = model.GeometryEngine + engines = list(engines) + width_ratios = [1.0] * len(engines) if colorbar: width_ratios.append(0.1) - (fig, axx) = plt.subplots( + (fig, all_ax) = plt.subplots( ncols=len(width_ratios), layout="constrained", figsize=figsize, gridspec_kw=dict(width_ratios=width_ratios), ) result = {} - cbar: list[Any] = [False] * len(model.GeometryEngine) + all_cbar: list[Any] = [False] * len(engines) if colorbar: - cbar[:0] = [axx[-1]] + all_cbar[:0] = [all_ax[-1]] - for g, ax, cb in zip(model.GeometryEngine, axx, cbar): + for ax, g, cb in zip(all_ax, engines, all_cbar): try: result[g] = trace_image(ax, geometry=g, colorbar=cb) except Exception as e: @@ -333,4 +340,59 @@ def plot_all_geometry( return result +def centered_image( + center, + xdir, + outdir, + width: Union[float, tuple[float, float]], + **kwargs: Any, +) -> model.ImageInput: + """ + Create an ImageInput with a centered view based on the given parameters. + + Parameters + ---------- + center : array_like + The center coordinate (real space) of the image. + xdir : array_like + The direction along the rendered x-axis. + outdir : array_like + The direction out of the page in the result. + width : float or tuple of two floats or array_like with shape (2,) + If a single float is provided, the image is square and that value is + used for both the x (horizontal) and y (vertical) dimensions. If a + tuple or array-like with two elements is + provided, the first element specifies the width along the x-axis and + the second element specifies the width along the y-axis. + **kwargs + Additional keyword arguments passed to the ImageInput constructor. + + Returns + ------- + model.ImageInput + The input to ``visualize`` to generate the centered image. + """ + center = np.asarray(center) + xdir = np.asarray(xdir) + ydir = np.cross(outdir, xdir) + + if isinstance(width, float): + wx, wy = width, width + elif len(width) == 2: + wx, wy = width + else: + raise ValueError("width must be a float or a length-2 tuple") + + offset = xdir * (wx / 2) + ydir * (wy / 2) + lower_left = (center - offset).tolist() + upper_right = (center + offset).tolist() + + return model.ImageInput( + lower_left=lower_left, + upper_right=upper_right, + rightward=xdir.tolist(), + **kwargs, + ) + + _register_cmaps() diff --git a/pyproject.toml b/pyproject.toml index cf09ef1..1e056cf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ dependencies = [ "pydantic~=2.0", "pydantic-settings", "matplotlib>=3.7", - "numpy", + "numpy>=1.20", "typer" ] diff --git a/test/test_visualize.py b/test/test_visualize.py index 99821a6..64fcf5b 100644 --- a/test/test_visualize.py +++ b/test/test_visualize.py @@ -66,3 +66,27 @@ def test_IdMapper(): assert_array_equal(img, np.array([2, 2, 2])) assert_array_equal(img.mask, [True, False, True]) assert vol == ["bar", "baz", "foo"] + + +def test_centered_image(): + centered_image = visualize.centered_image + kwargs = dict() + + # Test case 1: Square image with x projection along -y + result = centered_image( + center=[0, 0, 0], xdir=[0, -1, 0], outdir=[0, 0, 1], width=2.0, **kwargs + ) + assert result.lower_left == [-1.0, 1.0, 0.0] + assert result.upper_right == [1.0, -1.0, 0.0] + assert result.rightward == [0.0, -1.0, 0.0] + + # Test case 2: Rectangle image with xdir = [1, 0, 0] + result = centered_image( + center=[1, 0, 0], + xdir=[1, 0, 0], + outdir=[0, 0, -1], + width=(2.0, 4.0), + **kwargs, + ) + assert result.lower_left == [0.0, 2.0, 0.0] + assert result.upper_right == [2.0, -2.0, 0.0]