Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion celerpy/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ class ImageInput(_Model):
rightward: Real3 = [1, 0, 0]
"Ray trace direction which points to the right in the image"

vertical_pixels: NonNegativeInt
vertical_pixels: NonNegativeInt = 512
"Number of pixels along the y axis"

horizontal_divisor: Optional[PositiveInt] = None
Expand Down
78 changes: 70 additions & 8 deletions celerpy/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
import json
import re
import warnings
from collections.abc import Mapping, MutableSequence
from collections.abc import Iterable, Mapping, MutableSequence
from importlib.resources import files
from pathlib import Path
from subprocess import TimeoutExpired
from tempfile import NamedTemporaryFile
from typing import Any, Optional
from typing import Any, Optional, Union

import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -307,30 +307,92 @@ def __call__(


def plot_all_geometry(
trace_image: Imager, *, colorbar=True, figsize=None
trace_image: Imager,
*,
colorbar: bool = True,
figsize: Optional[tuple] = None,
engines: Optional[Iterable] = None,
) -> Mapping[model.GeometryEngine, Any]:
"""Convenience function for plotting all available geometry types."""
width_ratios = [1.0] * len(model.GeometryEngine)
if engines is None:
engines = model.GeometryEngine
engines = list(engines)
width_ratios = [1.0] * len(engines)
if colorbar:
width_ratios.append(0.1)

(fig, axx) = plt.subplots(
(fig, all_ax) = plt.subplots(
ncols=len(width_ratios),
layout="constrained",
figsize=figsize,
gridspec_kw=dict(width_ratios=width_ratios),
)
result = {}
cbar: list[Any] = [False] * len(model.GeometryEngine)
all_cbar: list[Any] = [False] * len(engines)
if colorbar:
cbar[:0] = [axx[-1]]
all_cbar[:0] = [all_ax[-1]]

for g, ax, cb in zip(model.GeometryEngine, axx, cbar):
for ax, g, cb in zip(all_ax, engines, all_cbar):
try:
result[g] = trace_image(ax, geometry=g, colorbar=cb)
except Exception as e:
warnings.warn(f"Failed to trace {g} geometry: {e!s}", stacklevel=1)
return result


def centered_image(
center,
xdir,
outdir,
width: Union[float, tuple[float, float]],
**kwargs: Any,
) -> model.ImageInput:
"""
Create an ImageInput with a centered view based on the given parameters.

Parameters
----------
center : array_like
The center coordinate (real space) of the image.
xdir : array_like
The direction along the rendered x-axis.
outdir : array_like
The direction out of the page in the result.
width : float or tuple of two floats or array_like with shape (2,)
If a single float is provided, the image is square and that value is
used for both the x (horizontal) and y (vertical) dimensions. If a
tuple or array-like with two elements is
provided, the first element specifies the width along the x-axis and
the second element specifies the width along the y-axis.
**kwargs
Additional keyword arguments passed to the ImageInput constructor.

Returns
-------
model.ImageInput
The input to ``visualize`` to generate the centered image.
"""
center = np.asarray(center)
xdir = np.asarray(xdir)
ydir = np.cross(outdir, xdir)

if isinstance(width, float):
wx, wy = width, width
elif len(width) == 2:
wx, wy = width
else:
raise ValueError("width must be a float or a length-2 tuple")

offset = xdir * (wx / 2) + ydir * (wy / 2)
lower_left = (center - offset).tolist()
upper_right = (center + offset).tolist()

return model.ImageInput(
lower_left=lower_left,
upper_right=upper_right,
rightward=xdir.tolist(),
**kwargs,
)


_register_cmaps()
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ dependencies = [
"pydantic~=2.0",
"pydantic-settings",
"matplotlib>=3.7",
"numpy",
"numpy>=1.20",
"typer"
]

Expand Down
24 changes: 24 additions & 0 deletions test/test_visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,27 @@ def test_IdMapper():
assert_array_equal(img, np.array([2, 2, 2]))
assert_array_equal(img.mask, [True, False, True])
assert vol == ["bar", "baz", "foo"]


def test_centered_image():
centered_image = visualize.centered_image
kwargs = dict()

# Test case 1: Square image with x projection along -y
result = centered_image(
center=[0, 0, 0], xdir=[0, -1, 0], outdir=[0, 0, 1], width=2.0, **kwargs
)
assert result.lower_left == [-1.0, 1.0, 0.0]
assert result.upper_right == [1.0, -1.0, 0.0]
assert result.rightward == [0.0, -1.0, 0.0]

# Test case 2: Rectangle image with xdir = [1, 0, 0]
result = centered_image(
center=[1, 0, 0],
xdir=[1, 0, 0],
outdir=[0, 0, -1],
width=(2.0, 4.0),
**kwargs,
)
assert result.lower_left == [0.0, 2.0, 0.0]
assert result.upper_right == [2.0, -2.0, 0.0]