Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 57 additions & 14 deletions celerpy/visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
from pathlib import Path
from subprocess import TimeoutExpired
from tempfile import NamedTemporaryFile
from typing import Any, Optional, Union
from typing import Any, NamedTuple, Optional, Union

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import colormaps
from matplotlib.axes import Axes as mpl_Axes
from matplotlib.colors import BoundaryNorm, ListedColormap

from . import model, process
Expand All @@ -27,12 +28,39 @@
_re_ptr = re.compile(r"0x[0-9a-f]+")


class WrappingListedColormap(ListedColormap):
"""A ListedColormap that wraps around when the number of colors is exceeded.

When more colors are requested than available, this colormap will cycle
through the available colors and emit a warning.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._warned: bool = False

def __call__(self, X, *args, **kwargs):
X = np.asarray(X)
if not self._warned and (max_val := np.max(X)) >= self.N:
warnings.warn(
f"Color index {max_val} exceeds colormap size {self.N}. "
"Colors will be reused cyclically.",
stacklevel=1,
)
self._warned = True

# Wrap indices using modulo
X_wrapped = np.mod(X, self.N)
return super().__call__(X_wrapped, *args, **kwargs)


def _register_cmaps():
resources = files("celerpy._resources")
cmap = ListedColormap(
np.loadtxt(resources.joinpath("glasbey-light.txt")),
name="glasbey_light",
)
with resources.joinpath("glasbey-light.txt").open("r") as f:
cmap = WrappingListedColormap(
np.loadtxt(f),
name="glasbey_light",
)
try:
colormaps.register(cmap)
except ValueError as e:
Expand Down Expand Up @@ -97,10 +125,15 @@ class CelerGeo:
image: Optional[model.ImageParams]
volumes: dict[model.GeometryEngine, list[str]]

@classmethod
def with_setup(cls, *args, **kwargs):
"""Construct, forwarding args to ModelSetup."""
return cls(setup=model.ModelSetup(*args, **kwargs))

@classmethod
def from_filename(cls, path: Path):
"""Construct from a geometry filename and default other setup."""
return cls(model.ModelSetup(geometry_file=path))
return cls.with_setup(geometry_file=path)

def __init__(self, setup: model.ModelSetup):
# Create the process and attach stdin/stdout pipes
Expand Down Expand Up @@ -210,8 +243,15 @@ def __missing__(self, key: str):
return result


LabeledAxis = collections.namedtuple("LabeledAxis", ["label", "lo", "hi"])
LabeledAxes = collections.namedtuple("LabeledAxes", ["x", "y"])
class LabeledAxis(NamedTuple):
label: str
lo: float
hi: float


class LabeledAxes(NamedTuple):
x: LabeledAxis
y: LabeledAxis


def calc_image_axes(image: model.ImageParams) -> LabeledAxes:
Expand Down Expand Up @@ -254,10 +294,10 @@ def __init__(self, celer_geo, image: model.ImageInput):

def __call__(
self,
ax,
ax: mpl_Axes,
geometry: Optional[model.GeometryEngine] = None,
memspace: Optional[model.MemSpace] = None,
colorbar=None,
colorbar: Union[bool, None, mpl_Axes] = None,
) -> dict[str, Any]:
(trace_output, img) = self.celer_geo.trace(
self.image, geometry=geometry, memspace=memspace
Expand All @@ -268,9 +308,9 @@ def __call__(
(x, y) = self.axes

ax.set_xlabel(x.label)
ax.set_xlim([x.lo, x.hi])
ax.set_xlim((x.lo, x.hi))
ax.set_ylabel(y.label)
ax.set_ylim([y.lo, y.hi])
ax.set_ylim((y.lo, y.hi))
tr = trace_output.trace
ax.set_title(f"{tr.geometry.name} ({tr.memspace.name})")

Expand All @@ -279,7 +319,7 @@ def __call__(
norm = BoundaryNorm(np.arange(len(volumes) + 1), len(volumes) + 1)
im = ax.imshow(
img,
extent=[x.lo, x.hi, y.lo, y.hi],
extent=(x.lo, x.hi, y.lo, y.hi),
interpolation="none",
norm=norm,
cmap="glasbey_light",
Expand All @@ -292,11 +332,14 @@ def __call__(
if colorbar:
# Create colorbar
bounds = norm.boundaries
kwargs = {"ticks": bounds[:-1] + np.diff(bounds) / 2}
kwargs: dict[str, Any] = {
"ticks": bounds[:-1] + np.diff(bounds) / 2
}
if not isinstance(colorbar, bool):
# User can specify a new axis to place the colorbar
kwargs["cax"] = colorbar
fig = ax.get_figure()
assert fig is not None
cbar = fig.colorbar(im, **kwargs)
result["colorbar"] = cbar

Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ celerpy = "celerpy.cli:app"

[tool.mypy]
plugins = [
"numpy.typing.mypy_plugin",
"pydantic.mypy"
]

Expand Down
94 changes: 52 additions & 42 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,130 +8,140 @@ boltons==25.0.0
# via
# face
# glom
build==1.2.2.post1
build==1.3.0
# via celerpy (pyproject.toml)
cfgv==3.4.0
# via pre-commit
click==8.1.8
click==8.3.0
# via typer
contourpy==1.3.1
contourpy==1.3.3
# via matplotlib
coverage==7.8.0
coverage==7.10.7
# via pytest-cov
cycler==0.12.1
# via matplotlib
dapperdata==0.4.0
# via celerpy (pyproject.toml)
distlib==0.3.9
distlib==0.4.0
# via virtualenv
face==24.0.0
# via glom
filelock==3.18.0
filelock==3.19.1
# via virtualenv
fonttools==4.57.0
fonttools==4.60.1
# via matplotlib
glom==24.11.0
# via celerpy (pyproject.toml)
identify==2.6.9
identify==2.6.15
# via pre-commit
iniconfig==2.1.0
# via pytest
kiwisolver==1.4.8
kiwisolver==1.4.9
# via matplotlib
markdown-it-py==3.0.0
markdown-it-py==4.0.0
# via rich
matplotlib==3.10.1
matplotlib==3.10.6
# via celerpy (pyproject.toml)
mdurl==0.1.2
# via markdown-it-py
mypy==1.15.0
mypy==1.18.2
# via celerpy (pyproject.toml)
mypy-extensions==1.0.0
mypy-extensions==1.1.0
# via mypy
nodeenv==1.9.1
# via pre-commit
numpy==2.2.4
numpy==2.3.3
# via
# celerpy (pyproject.toml)
# contourpy
# matplotlib
packaging==24.2
packaging==25.0
# via
# build
# matplotlib
# pytest
pillow==11.1.0
pathspec==0.12.1
# via mypy
pillow==11.3.0
# via matplotlib
platformdirs==4.3.7
platformdirs==4.4.0
# via virtualenv
pluggy==1.5.0
# via pytest
pre-commit==4.2.0
pluggy==1.6.0
# via
# pytest
# pytest-cov
pre-commit==4.3.0
# via celerpy (pyproject.toml)
pydantic==2.11.2
pydantic==2.11.10
# via
# celerpy (pyproject.toml)
# dapperdata
# pydantic-settings
pydantic-core==2.33.1
pydantic-core==2.33.2
# via pydantic
pydantic-settings==2.8.1
pydantic-settings==2.11.0
# via
# celerpy (pyproject.toml)
# dapperdata
pygments==2.19.1
# via rich
pyparsing==3.2.3
pygments==2.19.2
# via
# pytest
# rich
pyparsing==3.2.5
# via matplotlib
pyproject-hooks==1.2.0
# via build
pytest==8.3.5
pytest==8.4.2
# via
# celerpy (pyproject.toml)
# pytest-cov
# pytest-pretty
pytest-cov==6.1.0
pytest-cov==7.0.0
# via celerpy (pyproject.toml)
pytest-pretty==1.2.0
pytest-pretty==1.3.0
# via celerpy (pyproject.toml)
python-dateutil==2.9.0.post0
# via matplotlib
python-dotenv==1.1.0
python-dotenv==1.1.1
# via pydantic-settings
pyyaml==6.0.2
pyyaml==6.0.3
# via pre-commit
rich==14.0.0
rich==14.1.0
# via
# pytest-pretty
# typer
ruamel-yaml==0.18.10
ruamel-yaml==0.18.15
# via
# celerpy (pyproject.toml)
# dapperdata
ruff==0.11.4
ruamel-yaml-clib==0.2.14
# via ruamel-yaml
ruff==0.13.3
# via celerpy (pyproject.toml)
shellingham==1.5.4
# via typer
six==1.17.0
# via python-dateutil
toml-sort==0.24.2
toml-sort==0.24.3
# via celerpy (pyproject.toml)
tomlkit==0.13.2
tomlkit==0.13.3
# via toml-sort
typer==0.15.2
typer==0.19.2
# via
# celerpy (pyproject.toml)
# dapperdata
typing-extensions==4.13.1
typing-extensions==4.15.0
# via
# mypy
# pydantic
# pydantic-core
# typer
# typing-inspection
typing-inspection==0.4.0
# via pydantic
uv==0.6.12
typing-inspection==0.4.2
# via
# pydantic
# pydantic-settings
uv==0.8.23
# via celerpy (pyproject.toml)
virtualenv==20.30.0
virtualenv==20.34.0
# via pre-commit
Loading