Skip to content

Commit 344fde0

Browse files
committed
More implementation WIP dummy
1 parent 6d16f41 commit 344fde0

File tree

2 files changed

+171
-73
lines changed

2 files changed

+171
-73
lines changed

src/astro_image_display_api/dummy_viewer.py

Lines changed: 82 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import os
2+
from collections import defaultdict
3+
from copy import deepcopy
24
from dataclasses import dataclass, field
35
from pathlib import Path
46
from typing import Any
@@ -12,6 +14,13 @@
1214

1315
from .interface_definition import ImageViewerInterface
1416

17+
@dataclass
18+
class CatalogInfo:
19+
"""
20+
A named tuple to hold information about a catalog.
21+
"""
22+
style: dict[str, Any] = field(default_factory=dict)
23+
data: Table | None = None
1524

1625
@dataclass
1726
class ImageViewer:
@@ -28,9 +37,8 @@ class ImageViewer:
2837
stretch_options: tuple = ("linear", "log", "sqrt")
2938
autocut_options: tuple = ("minmax", "zscale", "asinh", "percentile", "histogram")
3039
_cursor: str = ImageViewerInterface.ALLOWED_CURSOR_LOCATIONS[0]
31-
_markers: dict[str, dict] = field(default_factory=dict)
32-
_catalogs: dict[str, Table] = field(default_factory=dict)
33-
_catalog_names: list[str] = field(default_factory=list)
40+
# marker style will be stored in catalog table metadata
41+
_catalogs: dict[str, CatalogInfo] = field(default_factory=dict)
3442
_default_marker_style: dict[str, Any] = field(default_factory=dict)
3543
_cuts: str | tuple[float, float] = (0, 1)
3644
_stretch: str = "linear"
@@ -47,24 +55,32 @@ def __post_init__(self):
4755
# This is a dictionary of marker sets. The keys are the names of the
4856
# marker sets, and the values are the tables containing the markers.
4957
self._default_marker_style = dict(shape="circle", color="yellow", size=10)
50-
self._markers[None] = self._default_marker_style.copy()
58+
self._catalogs = defaultdict(CatalogInfo)
59+
self._catalogs[None].data = None
60+
self._catalogs[None].style = self._default_marker_style.copy()
5161

5262
def _user_catalog_labels(self) -> list[str]:
5363
"""
5464
Get the user-defined catalog labels.
5565
"""
56-
return [label for label in self._markers if label is not None]
66+
return [label for label in self._catalogs if label is not None]
5767

58-
def _resolve_catalog_label(self, catalog_label: str | None) -> list[str]:
68+
def _resolve_catalog_label(self, catalog_label: str | None) -> str:
69+
"""
70+
Figure out the catalog label if the user did not specify one. This
71+
is needed so that the user gets what they expect in the simple case
72+
where there is only one catalog loaded. In that case the user may
73+
or may not have actually specified a catalog label.
74+
"""
5975
user_keys = self._user_catalog_labels()
6076
if catalog_label is None:
6177
match len(user_keys):
6278
case 0:
63-
# No user-defined styles, so return the default style
79+
# No user-defined catalog labels, so return the default label.
6480
catalog_label = None
6581
case 1:
66-
# The user must have set a style, so return that instead of
67-
# the default style, which live in the key None.
82+
# The user must have loaded a catalog, so return that instead of
83+
# the default label, which live in the key None.
6884
catalog_label = user_keys[0]
6985
case _:
7086
raise ValueError(
@@ -128,7 +144,7 @@ def get_catalog_style(self, catalog_label=None) -> dict[str, dict[str, Any]]:
128144
"""
129145
catalog_label = self._resolve_catalog_label(catalog_label)
130146

131-
style = self._markers[catalog_label]
147+
style = self._catalogs[catalog_label].style
132148
style["catalog_label"] = catalog_label
133149
return style
134150

@@ -160,7 +176,12 @@ def set_catalog_style(
160176
color = color if color else self._default_marker_style["color"]
161177
size = size if size else self._default_marker_style["size"]
162178

163-
self._markers[catalog_label] = {
179+
catalog_label = self._resolve_catalog_label(catalog_label)
180+
181+
if self._catalogs[catalog_label].data is None:
182+
raise ValueError("Must load a catalog before setting a catalog style.")
183+
184+
self._catalogs[catalog_label].style = {
164185
"shape": shape,
165186
"color": color,
166187
"size": size,
@@ -242,8 +263,9 @@ def save(self, filename: str | os.PathLike, overwrite: bool = False) -> None:
242263

243264
# Marker-related methods
244265
def load_catalog(self, table: Table, x_colname: str = 'x', y_colname: str = 'y',
245-
skycoord_colname: str = 'coord', use_skycoord: bool = False,
246-
catalog_label: str | None = None) -> None:
266+
skycoord_colname: str = 'coord', use_skycoord: bool = True,
267+
catalog_label: str | None = None,
268+
catalog_style: dict | None = None) -> None:
247269
"""
248270
Add markers to the image.
249271
@@ -263,7 +285,7 @@ def load_catalog(self, table: Table, x_colname: str = 'x', y_colname: str = 'y',
263285
is ``'coord'``.
264286
use_skycoord : bool, optional
265287
If `True`, the ``skycoord_colname`` column will be used to
266-
get the marker positions. Default is `False`.
288+
get the marker positions.
267289
catalog_label : str, optional
268290
The name of the marker set to use. If not given, a unique
269291
name will be generated.
@@ -273,32 +295,39 @@ def load_catalog(self, table: Table, x_colname: str = 'x', y_colname: str = 'y',
273295
except KeyError:
274296
coords = None
275297

276-
if use_skycoord:
277-
if self._wcs is not None:
298+
try:
299+
xy = (table[x_colname], table[y_colname])
300+
except KeyError:
301+
xy = None
302+
303+
to_add = deepcopy(table)
304+
if xy is None:
305+
if self._wcs is not None and coords is not None:
278306
x, y = self._wcs.world_to_pixel(coords)
307+
to_add[x_colname] = x
308+
to_add[y_colname] = y
279309
else:
280-
raise ValueError("WCS is not set. Cannot convert to pixel coordinates.")
281-
else:
282-
x = table[x_colname]
283-
y = table[y_colname]
284-
285-
if not coords and self._wcs is not None:
286-
coords = self._wcs.pixel_to_world(x, y)
310+
to_add[x_colname] = to_add[y_colname] = None
287311

288-
to_add = Table(
289-
dict(
290-
x=x,
291-
y=y,
292-
coord=coords if coords else [None] * len(x),
293-
)
294-
)
312+
if coords is None:
313+
if use_skycoord and self._wcs is None:
314+
raise ValueError("WCS is not set. Cannot convert to pixel coordinates.")
315+
elif xy is not None and self._wcs is not None:
316+
# If we have xy coordinates, convert them to sky coordinates
317+
coords = self._wcs.pixel_to_world(xy[0], xy[1])
318+
to_add[skycoord_colname] = coords
319+
else:
320+
to_add[skycoord_colname] = None
295321

296322
catalog_label = self._resolve_catalog_label(catalog_label)
297-
if catalog_label in self._catalogs:
298-
marker_table = self._markers[catalog_label]
299-
self._markers[catalog_label] = vstack([marker_table, to_add])
323+
if (
324+
catalog_label in self._catalogs
325+
and self._catalogs[catalog_label].data is not None
326+
):
327+
old_table = self._catalogs[catalog_label].data
328+
self._catalogs[catalog_label].data = vstack([old_table, to_add])
300329
else:
301-
self._markers[catalog_label] = to_add
330+
self._catalogs[catalog_label].data = to_add
302331

303332
def remove_catalog(self, catalog_label: str | None = None) -> None:
304333
"""
@@ -310,19 +339,24 @@ def remove_catalog(self, catalog_label: str | None = None) -> None:
310339
The name of the marker set to remove. If the value is ``"*"``,
311340
then all markers will be removed.
312341
"""
313-
if isinstance(catalog_label, str):
314-
if catalog_label in self._markers:
315-
del self._markers[catalog_label]
316-
elif catalog_label == "*":
317-
self._markers = {}
318-
else:
319-
raise ValueError(f"Marker name {catalog_label} not found.")
320-
elif isinstance(catalog_label, list):
321-
for name in catalog_label:
322-
if name in self._markers:
323-
del self._markers[name]
324-
else:
325-
raise ValueError(f"Marker name {name} not found.")
342+
if isinstance(catalog_label, list):
343+
raise ValueError(
344+
"Cannot remove multiple catalogs from a list. Please specify "
345+
"a single catalog label or use '*' to remove all catalogs."
346+
)
347+
elif catalog_label == "*":
348+
# If the user wants to remove all catalogs, we reset the
349+
# catalogs dictionary to an empty one.
350+
self._catalogs = defaultdict(CatalogInfo)
351+
return
352+
353+
# Special cases are done, so we can resolve the catalog label
354+
catalog_label = self._resolve_catalog_label(catalog_label)
355+
356+
try:
357+
del self._catalogs[catalog_label]
358+
except KeyError:
359+
raise ValueError(f"Marker name {catalog_label} not found.")
326360

327361
def get_catalog(self, x_colname: str = 'x', y_colname: str = 'y',
328362
skycoord_colname: str = 'coord',
@@ -331,7 +365,7 @@ def get_catalog(self, x_colname: str = 'x', y_colname: str = 'y',
331365
# duplicated here.
332366
catalog_label = self._resolve_catalog_label(catalog_label)
333367

334-
result = self._markers[catalog_label] if catalog_label in self._markers else Table(names=["x", "y", "coord", "marker name"])
368+
result = self._catalogs[catalog_label].data if catalog_label in self._catalogs else Table(names=["x", "y", "coord", "marker name"])
335369

336370
result.rename_columns(["x", "y", "coord"], [x_colname, y_colname, skycoord_colname])
337371

0 commit comments

Comments
 (0)