Skip to content

Commit 1d6d7d8

Browse files
authored
Add visualization utilities (#6)
* Add ability to specify custom geometry engines in plot_all * Add centered_image * Default pixels and tweak variable names
1 parent 8ee8417 commit 1d6d7d8

File tree

4 files changed

+96
-10
lines changed

4 files changed

+96
-10
lines changed

celerpy/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ class ImageInput(_Model):
8989
rightward: Real3 = [1, 0, 0]
9090
"Ray trace direction which points to the right in the image"
9191

92-
vertical_pixels: NonNegativeInt
92+
vertical_pixels: NonNegativeInt = 512
9393
"Number of pixels along the y axis"
9494

9595
horizontal_divisor: Optional[PositiveInt] = None

celerpy/visualize.py

Lines changed: 70 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@
77
import json
88
import re
99
import warnings
10-
from collections.abc import Mapping, MutableSequence
10+
from collections.abc import Iterable, Mapping, MutableSequence
1111
from importlib.resources import files
1212
from pathlib import Path
1313
from subprocess import TimeoutExpired
1414
from tempfile import NamedTemporaryFile
15-
from typing import Any, Optional
15+
from typing import Any, Optional, Union
1616

1717
import matplotlib.pyplot as plt
1818
import numpy as np
@@ -307,30 +307,92 @@ def __call__(
307307

308308

309309
def plot_all_geometry(
310-
trace_image: Imager, *, colorbar=True, figsize=None
310+
trace_image: Imager,
311+
*,
312+
colorbar: bool = True,
313+
figsize: Optional[tuple] = None,
314+
engines: Optional[Iterable] = None,
311315
) -> Mapping[model.GeometryEngine, Any]:
312316
"""Convenience function for plotting all available geometry types."""
313-
width_ratios = [1.0] * len(model.GeometryEngine)
317+
if engines is None:
318+
engines = model.GeometryEngine
319+
engines = list(engines)
320+
width_ratios = [1.0] * len(engines)
314321
if colorbar:
315322
width_ratios.append(0.1)
316323

317-
(fig, axx) = plt.subplots(
324+
(fig, all_ax) = plt.subplots(
318325
ncols=len(width_ratios),
319326
layout="constrained",
320327
figsize=figsize,
321328
gridspec_kw=dict(width_ratios=width_ratios),
322329
)
323330
result = {}
324-
cbar: list[Any] = [False] * len(model.GeometryEngine)
331+
all_cbar: list[Any] = [False] * len(engines)
325332
if colorbar:
326-
cbar[:0] = [axx[-1]]
333+
all_cbar[:0] = [all_ax[-1]]
327334

328-
for g, ax, cb in zip(model.GeometryEngine, axx, cbar):
335+
for ax, g, cb in zip(all_ax, engines, all_cbar):
329336
try:
330337
result[g] = trace_image(ax, geometry=g, colorbar=cb)
331338
except Exception as e:
332339
warnings.warn(f"Failed to trace {g} geometry: {e!s}", stacklevel=1)
333340
return result
334341

335342

343+
def centered_image(
344+
center,
345+
xdir,
346+
outdir,
347+
width: Union[float, tuple[float, float]],
348+
**kwargs: Any,
349+
) -> model.ImageInput:
350+
"""
351+
Create an ImageInput with a centered view based on the given parameters.
352+
353+
Parameters
354+
----------
355+
center : array_like
356+
The center coordinate (real space) of the image.
357+
xdir : array_like
358+
The direction along the rendered x-axis.
359+
outdir : array_like
360+
The direction out of the page in the result.
361+
width : float or tuple of two floats or array_like with shape (2,)
362+
If a single float is provided, the image is square and that value is
363+
used for both the x (horizontal) and y (vertical) dimensions. If a
364+
tuple or array-like with two elements is
365+
provided, the first element specifies the width along the x-axis and
366+
the second element specifies the width along the y-axis.
367+
**kwargs
368+
Additional keyword arguments passed to the ImageInput constructor.
369+
370+
Returns
371+
-------
372+
model.ImageInput
373+
The input to ``visualize`` to generate the centered image.
374+
"""
375+
center = np.asarray(center)
376+
xdir = np.asarray(xdir)
377+
ydir = np.cross(outdir, xdir)
378+
379+
if isinstance(width, float):
380+
wx, wy = width, width
381+
elif len(width) == 2:
382+
wx, wy = width
383+
else:
384+
raise ValueError("width must be a float or a length-2 tuple")
385+
386+
offset = xdir * (wx / 2) + ydir * (wy / 2)
387+
lower_left = (center - offset).tolist()
388+
upper_right = (center + offset).tolist()
389+
390+
return model.ImageInput(
391+
lower_left=lower_left,
392+
upper_right=upper_right,
393+
rightward=xdir.tolist(),
394+
**kwargs,
395+
)
396+
397+
336398
_register_cmaps()

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ dependencies = [
1313
"pydantic~=2.0",
1414
"pydantic-settings",
1515
"matplotlib>=3.7",
16-
"numpy",
16+
"numpy>=1.20",
1717
"typer"
1818
]
1919

test/test_visualize.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,3 +66,27 @@ def test_IdMapper():
6666
assert_array_equal(img, np.array([2, 2, 2]))
6767
assert_array_equal(img.mask, [True, False, True])
6868
assert vol == ["bar", "baz", "foo"]
69+
70+
71+
def test_centered_image():
72+
centered_image = visualize.centered_image
73+
kwargs = dict()
74+
75+
# Test case 1: Square image with x projection along -y
76+
result = centered_image(
77+
center=[0, 0, 0], xdir=[0, -1, 0], outdir=[0, 0, 1], width=2.0, **kwargs
78+
)
79+
assert result.lower_left == [-1.0, 1.0, 0.0]
80+
assert result.upper_right == [1.0, -1.0, 0.0]
81+
assert result.rightward == [0.0, -1.0, 0.0]
82+
83+
# Test case 2: Rectangle image with xdir = [1, 0, 0]
84+
result = centered_image(
85+
center=[1, 0, 0],
86+
xdir=[1, 0, 0],
87+
outdir=[0, 0, -1],
88+
width=(2.0, 4.0),
89+
**kwargs,
90+
)
91+
assert result.lower_left == [0.0, 2.0, 0.0]
92+
assert result.upper_right == [2.0, -2.0, 0.0]

0 commit comments

Comments
 (0)