Skip to content

Commit 2a62d49

Browse files
committed
Implement load/get catalog in dummy viewer and update tests
1 parent ea41859 commit 2a62d49

File tree

2 files changed

+207
-127
lines changed

2 files changed

+207
-127
lines changed

src/astro_image_display_api/dummy_viewer.py

Lines changed: 118 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import os
2+
from collections import defaultdict
3+
from copy import deepcopy
24
from dataclasses import dataclass, field
35
from pathlib import Path
46
from typing import Any
@@ -12,6 +14,13 @@
1214

1315
from .interface_definition import ImageViewerInterface
1416

17+
@dataclass
18+
class CatalogInfo:
19+
"""
20+
A named tuple to hold information about a catalog.
21+
"""
22+
style: dict[str, Any] = field(default_factory=dict)
23+
data: Table | None = None
1524

1625
@dataclass
1726
class ImageViewer:
@@ -28,7 +37,8 @@ class ImageViewer:
2837
stretch_options: tuple = ("linear", "log", "sqrt")
2938
autocut_options: tuple = ("minmax", "zscale", "asinh", "percentile", "histogram")
3039
_cursor: str = ImageViewerInterface.ALLOWED_CURSOR_LOCATIONS[0]
31-
_markers: dict[str, dict] = field(default_factory=dict)
40+
# marker style will be stored in catalog table metadata
41+
_catalogs: dict[str, CatalogInfo] = field(default_factory=dict)
3242
_default_marker_style: dict[str, Any] = field(default_factory=dict)
3343
_cuts: str | tuple[float, float] = (0, 1)
3444
_stretch: str = "linear"
@@ -45,7 +55,39 @@ def __post_init__(self):
4555
# This is a dictionary of marker sets. The keys are the names of the
4656
# marker sets, and the values are the tables containing the markers.
4757
self._default_marker_style = dict(shape="circle", color="yellow", size=10)
48-
self._markers[None] = self._default_marker_style.copy()
58+
self._catalogs = defaultdict(CatalogInfo)
59+
self._catalogs[None].data = None
60+
self._catalogs[None].style = self._default_marker_style.copy()
61+
62+
def _user_catalog_labels(self) -> list[str]:
63+
"""
64+
Get the user-defined catalog labels.
65+
"""
66+
return [label for label in self._catalogs if label is not None]
67+
68+
def _resolve_catalog_label(self, catalog_label: str | None) -> str:
69+
"""
70+
Figure out the catalog label if the user did not specify one. This
71+
is needed so that the user gets what they expect in the simple case
72+
where there is only one catalog loaded. In that case the user may
73+
or may not have actually specified a catalog label.
74+
"""
75+
user_keys = self._user_catalog_labels()
76+
if catalog_label is None:
77+
match len(user_keys):
78+
case 0:
79+
# No user-defined catalog labels, so return the default label.
80+
catalog_label = None
81+
case 1:
82+
# The user must have loaded a catalog, so return that instead of
83+
# the default label, which live in the key None.
84+
catalog_label = user_keys[0]
85+
case _:
86+
raise ValueError(
87+
"Multiple catalog styles defined. Please specify a catalog_label to get the style."
88+
)
89+
90+
return catalog_label
4991

5092
@property
5193
def stretch(self) -> str:
@@ -100,22 +142,9 @@ def get_catalog_style(self, catalog_label=None) -> dict[str, dict[str, Any]]:
100142
dict
101143
The style for the catalog.
102144
"""
103-
user_keys = list(set(self._markers.keys()) - {None})
104-
if catalog_label is None:
105-
match len(user_keys):
106-
case 0:
107-
# No user-defined styles, so return the default style
108-
catalog_label = None
109-
case 1:
110-
# The user must have set a style, so return that instead of
111-
# the default style, which live in the key None.
112-
catalog_label = user_keys[0]
113-
case _:
114-
raise ValueError(
115-
"Multiple catalog styles defined. Please specify a catalog_label to get the style."
116-
)
145+
catalog_label = self._resolve_catalog_label(catalog_label)
117146

118-
style = self._markers[catalog_label]
147+
style = self._catalogs[catalog_label].style
119148
style["catalog_label"] = catalog_label
120149
return style
121150

@@ -147,7 +176,12 @@ def set_catalog_style(
147176
color = color if color else self._default_marker_style["color"]
148177
size = size if size else self._default_marker_style["size"]
149178

150-
self._markers[catalog_label] = {
179+
catalog_label = self._resolve_catalog_label(catalog_label)
180+
181+
if self._catalogs[catalog_label].data is None:
182+
raise ValueError("Must load a catalog before setting a catalog style.")
183+
184+
self._catalogs[catalog_label].style = {
151185
"shape": shape,
152186
"color": color,
153187
"size": size,
@@ -228,9 +262,10 @@ def save(self, filename: str | os.PathLike, overwrite: bool = False) -> None:
228262
p.write_text("This is a dummy file. The viewer does not save anything.")
229263

230264
# Marker-related methods
231-
def add_markers(self, table: Table, x_colname: str = 'x', y_colname: str = 'y',
232-
skycoord_colname: str = 'coord', use_skycoord: bool = False,
233-
marker_name: str | None = None) -> None:
265+
def load_catalog(self, table: Table, x_colname: str = 'x', y_colname: str = 'y',
266+
skycoord_colname: str = 'coord', use_skycoord: bool = True,
267+
catalog_label: str | None = None,
268+
catalog_style: dict | None = None) -> None:
234269
"""
235270
Add markers to the image.
236271
@@ -250,8 +285,8 @@ def add_markers(self, table: Table, x_colname: str = 'x', y_colname: str = 'y',
250285
is ``'coord'``.
251286
use_skycoord : bool, optional
252287
If `True`, the ``skycoord_colname`` column will be used to
253-
get the marker positions. Default is `False`.
254-
marker_name : str, optional
288+
get the marker positions.
289+
catalog_label : str, optional
255290
The name of the marker set to use. If not given, a unique
256291
name will be generated.
257292
"""
@@ -260,105 +295,86 @@ def add_markers(self, table: Table, x_colname: str = 'x', y_colname: str = 'y',
260295
except KeyError:
261296
coords = None
262297

263-
if use_skycoord:
264-
if self._wcs is not None:
298+
try:
299+
xy = (table[x_colname], table[y_colname])
300+
except KeyError:
301+
xy = None
302+
303+
to_add = deepcopy(table)
304+
if xy is None:
305+
if self._wcs is not None and coords is not None:
265306
x, y = self._wcs.world_to_pixel(coords)
307+
to_add[x_colname] = x
308+
to_add[y_colname] = y
266309
else:
267-
raise ValueError("WCS is not set. Cannot convert to pixel coordinates.")
268-
else:
269-
x = table[x_colname]
270-
y = table[y_colname]
271-
272-
if not coords and self._wcs is not None:
273-
coords = self._wcs.pixel_to_world(x, y)
274-
275-
to_add = Table(
276-
dict(
277-
x=x,
278-
y=y,
279-
coord=coords if coords else [None] * len(x),
280-
)
281-
)
282-
to_add["marker name"] = marker_name
283-
284-
if marker_name in self._markers:
285-
marker_table = self._markers[marker_name]
286-
self._markers[marker_name] = vstack([marker_table, to_add])
310+
to_add[x_colname] = to_add[y_colname] = None
311+
312+
if coords is None:
313+
if use_skycoord and self._wcs is None:
314+
raise ValueError("Cannot use sky coordinates without a SkyCoord column or WCS.")
315+
elif xy is not None and self._wcs is not None:
316+
# If we have xy coordinates, convert them to sky coordinates
317+
coords = self._wcs.pixel_to_world(xy[0], xy[1])
318+
to_add[skycoord_colname] = coords
319+
else:
320+
to_add[skycoord_colname] = None
321+
322+
catalog_label = self._resolve_catalog_label(catalog_label)
323+
if (
324+
catalog_label in self._catalogs
325+
and self._catalogs[catalog_label].data is not None
326+
):
327+
old_table = self._catalogs[catalog_label].data
328+
self._catalogs[catalog_label].data = vstack([old_table, to_add])
287329
else:
288-
self._markers[marker_name] = to_add
330+
self._catalogs[catalog_label].data = to_add
289331

290-
def reset_markers(self) -> None:
291-
"""
292-
Remove all markers from the image.
293-
"""
294-
self._markers = {}
295-
296-
def remove_markers(self, marker_name: str | list[str] | None = None) -> None:
332+
def remove_catalog(self, catalog_label: str | None = None) -> None:
297333
"""
298334
Remove markers from the image.
299335
300336
Parameters
301337
----------
302338
marker_name : str, optional
303-
The name of the marker set to remove. If the value is ``"all"``,
339+
The name of the marker set to remove. If the value is ``"*"``,
304340
then all markers will be removed.
305341
"""
306-
if isinstance(marker_name, str):
307-
if marker_name in self._markers:
308-
del self._markers[marker_name]
309-
elif marker_name == "all":
310-
self._markers = {}
311-
else:
312-
raise ValueError(f"Marker name {marker_name} not found.")
313-
elif isinstance(marker_name, list):
314-
for name in marker_name:
315-
if name in self._markers:
316-
del self._markers[name]
317-
else:
318-
raise ValueError(f"Marker name {name} not found.")
319-
320-
def get_markers(self, x_colname: str = 'x', y_colname: str = 'y',
321-
skycoord_colname: str = 'coord',
322-
marker_name: str | list[str] | None = None) -> Table:
323-
"""
324-
Get the marker positions.
342+
if isinstance(catalog_label, list):
343+
raise ValueError(
344+
"Cannot remove multiple catalogs from a list. Please specify "
345+
"a single catalog label or use '*' to remove all catalogs."
346+
)
347+
elif catalog_label == "*":
348+
# If the user wants to remove all catalogs, we reset the
349+
# catalogs dictionary to an empty one.
350+
self._catalogs = defaultdict(CatalogInfo)
351+
return
325352

326-
Parameters
327-
----------
328-
x_colname : str, optional
329-
The name of the column containing the x positions. Default
330-
is ``'x'``.
331-
y_colname : str, optional
332-
The name of the column containing the y positions. Default
333-
is ``'y'``.
334-
skycoord_colname : str, optional
335-
The name of the column containing the sky coordinates. Default
336-
is ``'coord'``.
337-
marker_name : str or list of str, optional
338-
The name of the marker set to use. If that value is ``"all"``,
339-
then all markers will be returned.
353+
# Special cases are done, so we can resolve the catalog label
354+
catalog_label = self._resolve_catalog_label(catalog_label)
340355

341-
Returns
342-
-------
343-
table : `astropy.table.Table`
344-
The table containing the marker positions. If no markers match the
345-
``marker_name`` parameter, an empty table is returned.
346-
"""
347-
if isinstance(marker_name, str):
348-
if marker_name == "all":
349-
marker_name = self._markers.keys()
350-
else:
351-
marker_name = [marker_name]
352-
elif marker_name is None:
353-
marker_name = ["_default_marker"]
356+
try:
357+
del self._catalogs[catalog_label]
358+
except KeyError:
359+
raise ValueError(f"Marker name {catalog_label} not found.")
360+
361+
def get_catalog(self, x_colname: str = 'x', y_colname: str = 'y',
362+
skycoord_colname: str = 'coord',
363+
catalog_label: str | None = None) -> Table:
364+
# Dostring is copied from the interface definition, so it is not
365+
# duplicated here.
366+
catalog_label = self._resolve_catalog_label(catalog_label)
354367

355-
to_stack = [self._markers[name] for name in marker_name if name in self._markers]
368+
result = self._catalogs[catalog_label].data if catalog_label in self._catalogs else Table(names=["x", "y", "coord"])
356369

357-
result = vstack(to_stack) if to_stack else Table(names=["x", "y", "coord", "marker name"])
358370
result.rename_columns(["x", "y", "coord"], [x_colname, y_colname, skycoord_colname])
359371

360372
return result
373+
get_catalog.__doc__ = ImageViewerInterface.get_catalog.__doc__
361374

375+
def get_catalog_names(self) -> list[str]:
376+
return list(self._user_catalog_labels())
377+
get_catalog_names.__doc__ = ImageViewerInterface.get_catalog_names.__doc__
362378

363379
# Methods that modify the view
364380
def center_on(self, point: tuple | SkyCoord):

0 commit comments

Comments
 (0)