11import os
2+ from collections import defaultdict
3+ from copy import deepcopy
24from dataclasses import dataclass , field
35from pathlib import Path
46from typing import Any
1214
1315from .interface_definition import ImageViewerInterface
1416
17+ @dataclass
18+ class CatalogInfo :
19+ """
20+ A named tuple to hold information about a catalog.
21+ """
22+ style : dict [str , Any ] = field (default_factory = dict )
23+ data : Table | None = None
1524
1625@dataclass
1726class ImageViewer :
@@ -28,7 +37,8 @@ class ImageViewer:
2837 stretch_options : tuple = ("linear" , "log" , "sqrt" )
2938 autocut_options : tuple = ("minmax" , "zscale" , "asinh" , "percentile" , "histogram" )
3039 _cursor : str = ImageViewerInterface .ALLOWED_CURSOR_LOCATIONS [0 ]
31- _markers : dict [str , dict ] = field (default_factory = dict )
40+ # marker style will be stored in catalog table metadata
41+ _catalogs : dict [str , CatalogInfo ] = field (default_factory = dict )
3242 _default_marker_style : dict [str , Any ] = field (default_factory = dict )
3343 _cuts : str | tuple [float , float ] = (0 , 1 )
3444 _stretch : str = "linear"
@@ -45,7 +55,39 @@ def __post_init__(self):
4555 # This is a dictionary of marker sets. The keys are the names of the
4656 # marker sets, and the values are the tables containing the markers.
4757 self ._default_marker_style = dict (shape = "circle" , color = "yellow" , size = 10 )
48- self ._markers [None ] = self ._default_marker_style .copy ()
58+ self ._catalogs = defaultdict (CatalogInfo )
59+ self ._catalogs [None ].data = None
60+ self ._catalogs [None ].style = self ._default_marker_style .copy ()
61+
62+ def _user_catalog_labels (self ) -> list [str ]:
63+ """
64+ Get the user-defined catalog labels.
65+ """
66+ return [label for label in self ._catalogs if label is not None ]
67+
68+ def _resolve_catalog_label (self , catalog_label : str | None ) -> str :
69+ """
70+ Figure out the catalog label if the user did not specify one. This
71+ is needed so that the user gets what they expect in the simple case
72+ where there is only one catalog loaded. In that case the user may
73+ or may not have actually specified a catalog label.
74+ """
75+ user_keys = self ._user_catalog_labels ()
76+ if catalog_label is None :
77+ match len (user_keys ):
78+ case 0 :
79+ # No user-defined catalog labels, so return the default label.
80+ catalog_label = None
81+ case 1 :
82+ # The user must have loaded a catalog, so return that instead of
83+ # the default label, which live in the key None.
84+ catalog_label = user_keys [0 ]
85+ case _:
86+ raise ValueError (
87+ "Multiple catalog styles defined. Please specify a catalog_label to get the style."
88+ )
89+
90+ return catalog_label
4991
5092 @property
5193 def stretch (self ) -> str :
@@ -100,22 +142,9 @@ def get_catalog_style(self, catalog_label=None) -> dict[str, dict[str, Any]]:
100142 dict
101143 The style for the catalog.
102144 """
103- user_keys = list (set (self ._markers .keys ()) - {None })
104- if catalog_label is None :
105- match len (user_keys ):
106- case 0 :
107- # No user-defined styles, so return the default style
108- catalog_label = None
109- case 1 :
110- # The user must have set a style, so return that instead of
111- # the default style, which live in the key None.
112- catalog_label = user_keys [0 ]
113- case _:
114- raise ValueError (
115- "Multiple catalog styles defined. Please specify a catalog_label to get the style."
116- )
145+ catalog_label = self ._resolve_catalog_label (catalog_label )
117146
118- style = self ._markers [catalog_label ]
147+ style = self ._catalogs [catalog_label ]. style
119148 style ["catalog_label" ] = catalog_label
120149 return style
121150
@@ -147,7 +176,12 @@ def set_catalog_style(
147176 color = color if color else self ._default_marker_style ["color" ]
148177 size = size if size else self ._default_marker_style ["size" ]
149178
150- self ._markers [catalog_label ] = {
179+ catalog_label = self ._resolve_catalog_label (catalog_label )
180+
181+ if self ._catalogs [catalog_label ].data is None :
182+ raise ValueError ("Must load a catalog before setting a catalog style." )
183+
184+ self ._catalogs [catalog_label ].style = {
151185 "shape" : shape ,
152186 "color" : color ,
153187 "size" : size ,
@@ -228,9 +262,10 @@ def save(self, filename: str | os.PathLike, overwrite: bool = False) -> None:
228262 p .write_text ("This is a dummy file. The viewer does not save anything." )
229263
230264 # Marker-related methods
231- def add_markers (self , table : Table , x_colname : str = 'x' , y_colname : str = 'y' ,
232- skycoord_colname : str = 'coord' , use_skycoord : bool = False ,
233- marker_name : str | None = None ) -> None :
265+ def load_catalog (self , table : Table , x_colname : str = 'x' , y_colname : str = 'y' ,
266+ skycoord_colname : str = 'coord' , use_skycoord : bool = True ,
267+ catalog_label : str | None = None ,
268+ catalog_style : dict | None = None ) -> None :
234269 """
235270 Add markers to the image.
236271
@@ -250,8 +285,8 @@ def add_markers(self, table: Table, x_colname: str = 'x', y_colname: str = 'y',
250285 is ``'coord'``.
251286 use_skycoord : bool, optional
252287 If `True`, the ``skycoord_colname`` column will be used to
253- get the marker positions. Default is `False`.
254- marker_name : str, optional
288+ get the marker positions.
289+ catalog_label : str, optional
255290 The name of the marker set to use. If not given, a unique
256291 name will be generated.
257292 """
@@ -260,105 +295,86 @@ def add_markers(self, table: Table, x_colname: str = 'x', y_colname: str = 'y',
260295 except KeyError :
261296 coords = None
262297
263- if use_skycoord :
264- if self ._wcs is not None :
298+ try :
299+ xy = (table [x_colname ], table [y_colname ])
300+ except KeyError :
301+ xy = None
302+
303+ to_add = deepcopy (table )
304+ if xy is None :
305+ if self ._wcs is not None and coords is not None :
265306 x , y = self ._wcs .world_to_pixel (coords )
307+ to_add [x_colname ] = x
308+ to_add [y_colname ] = y
266309 else :
267- raise ValueError ("WCS is not set. Cannot convert to pixel coordinates." )
268- else :
269- x = table [x_colname ]
270- y = table [y_colname ]
271-
272- if not coords and self ._wcs is not None :
273- coords = self ._wcs .pixel_to_world (x , y )
274-
275- to_add = Table (
276- dict (
277- x = x ,
278- y = y ,
279- coord = coords if coords else [None ] * len (x ),
280- )
281- )
282- to_add ["marker name" ] = marker_name
283-
284- if marker_name in self ._markers :
285- marker_table = self ._markers [marker_name ]
286- self ._markers [marker_name ] = vstack ([marker_table , to_add ])
310+ to_add [x_colname ] = to_add [y_colname ] = None
311+
312+ if coords is None :
313+ if use_skycoord and self ._wcs is None :
314+ raise ValueError ("Cannot use sky coordinates without a SkyCoord column or WCS." )
315+ elif xy is not None and self ._wcs is not None :
316+ # If we have xy coordinates, convert them to sky coordinates
317+ coords = self ._wcs .pixel_to_world (xy [0 ], xy [1 ])
318+ to_add [skycoord_colname ] = coords
319+ else :
320+ to_add [skycoord_colname ] = None
321+
322+ catalog_label = self ._resolve_catalog_label (catalog_label )
323+ if (
324+ catalog_label in self ._catalogs
325+ and self ._catalogs [catalog_label ].data is not None
326+ ):
327+ old_table = self ._catalogs [catalog_label ].data
328+ self ._catalogs [catalog_label ].data = vstack ([old_table , to_add ])
287329 else :
288- self ._markers [ marker_name ] = to_add
330+ self ._catalogs [ catalog_label ]. data = to_add
289331
290- def reset_markers (self ) -> None :
291- """
292- Remove all markers from the image.
293- """
294- self ._markers = {}
295-
296- def remove_markers (self , marker_name : str | list [str ] | None = None ) -> None :
332+ def remove_catalog (self , catalog_label : str | None = None ) -> None :
297333 """
298334 Remove markers from the image.
299335
300336 Parameters
301337 ----------
302338 marker_name : str, optional
303- The name of the marker set to remove. If the value is ``"all "``,
339+ The name of the marker set to remove. If the value is ``"* "``,
304340 then all markers will be removed.
305341 """
306- if isinstance (marker_name , str ):
307- if marker_name in self ._markers :
308- del self ._markers [marker_name ]
309- elif marker_name == "all" :
310- self ._markers = {}
311- else :
312- raise ValueError (f"Marker name { marker_name } not found." )
313- elif isinstance (marker_name , list ):
314- for name in marker_name :
315- if name in self ._markers :
316- del self ._markers [name ]
317- else :
318- raise ValueError (f"Marker name { name } not found." )
319-
320- def get_markers (self , x_colname : str = 'x' , y_colname : str = 'y' ,
321- skycoord_colname : str = 'coord' ,
322- marker_name : str | list [str ] | None = None ) -> Table :
323- """
324- Get the marker positions.
342+ if isinstance (catalog_label , list ):
343+ raise ValueError (
344+ "Cannot remove multiple catalogs from a list. Please specify "
345+ "a single catalog label or use '*' to remove all catalogs."
346+ )
347+ elif catalog_label == "*" :
348+ # If the user wants to remove all catalogs, we reset the
349+ # catalogs dictionary to an empty one.
350+ self ._catalogs = defaultdict (CatalogInfo )
351+ return
325352
326- Parameters
327- ----------
328- x_colname : str, optional
329- The name of the column containing the x positions. Default
330- is ``'x'``.
331- y_colname : str, optional
332- The name of the column containing the y positions. Default
333- is ``'y'``.
334- skycoord_colname : str, optional
335- The name of the column containing the sky coordinates. Default
336- is ``'coord'``.
337- marker_name : str or list of str, optional
338- The name of the marker set to use. If that value is ``"all"``,
339- then all markers will be returned.
353+ # Special cases are done, so we can resolve the catalog label
354+ catalog_label = self ._resolve_catalog_label (catalog_label )
340355
341- Returns
342- -------
343- table : `astropy.table.Table`
344- The table containing the marker positions. If no markers match the
345- ``marker_name`` parameter, an empty table is returned.
346- """
347- if isinstance (marker_name , str ):
348- if marker_name == "all" :
349- marker_name = self ._markers .keys ()
350- else :
351- marker_name = [marker_name ]
352- elif marker_name is None :
353- marker_name = ["_default_marker" ]
356+ try :
357+ del self ._catalogs [catalog_label ]
358+ except KeyError :
359+ raise ValueError (f"Marker name { catalog_label } not found." )
360+
361+ def get_catalog (self , x_colname : str = 'x' , y_colname : str = 'y' ,
362+ skycoord_colname : str = 'coord' ,
363+ catalog_label : str | None = None ) -> Table :
364+ # Dostring is copied from the interface definition, so it is not
365+ # duplicated here.
366+ catalog_label = self ._resolve_catalog_label (catalog_label )
354367
355- to_stack = [ self ._markers [ name ] for name in marker_name if name in self . _markers ]
368+ result = self ._catalogs [ catalog_label ]. data if catalog_label in self . _catalogs else Table ( names = [ "x" , "y" , "coord" ])
356369
357- result = vstack (to_stack ) if to_stack else Table (names = ["x" , "y" , "coord" , "marker name" ])
358370 result .rename_columns (["x" , "y" , "coord" ], [x_colname , y_colname , skycoord_colname ])
359371
360372 return result
373+ get_catalog .__doc__ = ImageViewerInterface .get_catalog .__doc__
361374
375+ def get_catalog_names (self ) -> list [str ]:
376+ return list (self ._user_catalog_labels ())
377+ get_catalog_names .__doc__ = ImageViewerInterface .get_catalog_names .__doc__
362378
363379 # Methods that modify the view
364380 def center_on (self , point : tuple | SkyCoord ):
0 commit comments