11import os
2+ from collections import defaultdict
3+ from copy import deepcopy
24from dataclasses import dataclass , field
35from pathlib import Path
46from typing import Any
1214
1315from .interface_definition import ImageViewerInterface
1416
17+ @dataclass
18+ class CatalogInfo :
19+ """
20+ A named tuple to hold information about a catalog.
21+ """
22+ style : dict [str , Any ] = field (default_factory = dict )
23+ data : Table | None = None
1524
1625@dataclass
1726class ImageViewer :
@@ -28,9 +37,8 @@ class ImageViewer:
2837 stretch_options : tuple = ("linear" , "log" , "sqrt" )
2938 autocut_options : tuple = ("minmax" , "zscale" , "asinh" , "percentile" , "histogram" )
3039 _cursor : str = ImageViewerInterface .ALLOWED_CURSOR_LOCATIONS [0 ]
31- _markers : dict [str , dict ] = field (default_factory = dict )
32- _catalogs : dict [str , Table ] = field (default_factory = dict )
33- _catalog_names : list [str ] = field (default_factory = list )
40+ # marker style will be stored in catalog table metadata
41+ _catalogs : dict [str , CatalogInfo ] = field (default_factory = dict )
3442 _default_marker_style : dict [str , Any ] = field (default_factory = dict )
3543 _cuts : str | tuple [float , float ] = (0 , 1 )
3644 _stretch : str = "linear"
@@ -47,24 +55,32 @@ def __post_init__(self):
4755 # This is a dictionary of marker sets. The keys are the names of the
4856 # marker sets, and the values are the tables containing the markers.
4957 self ._default_marker_style = dict (shape = "circle" , color = "yellow" , size = 10 )
50- self ._markers [None ] = self ._default_marker_style .copy ()
58+ self ._catalogs = defaultdict (CatalogInfo )
59+ self ._catalogs [None ].data = None
60+ self ._catalogs [None ].style = self ._default_marker_style .copy ()
5161
5262 def _user_catalog_labels (self ) -> list [str ]:
5363 """
5464 Get the user-defined catalog labels.
5565 """
56- return [label for label in self ._markers if label is not None ]
66+ return [label for label in self ._catalogs if label is not None ]
5767
58- def _resolve_catalog_label (self , catalog_label : str | None ) -> list [str ]:
68+ def _resolve_catalog_label (self , catalog_label : str | None ) -> str :
69+ """
70+ Figure out the catalog label if the user did not specify one. This
71+ is needed so that the user gets what they expect in the simple case
72+ where there is only one catalog loaded. In that case the user may
73+ or may not have actually specified a catalog label.
74+ """
5975 user_keys = self ._user_catalog_labels ()
6076 if catalog_label is None :
6177 match len (user_keys ):
6278 case 0 :
63- # No user-defined styles , so return the default style
79+ # No user-defined catalog labels , so return the default label.
6480 catalog_label = None
6581 case 1 :
66- # The user must have set a style , so return that instead of
67- # the default style , which live in the key None.
82+ # The user must have loaded a catalog , so return that instead of
83+ # the default label , which live in the key None.
6884 catalog_label = user_keys [0 ]
6985 case _:
7086 raise ValueError (
@@ -128,7 +144,7 @@ def get_catalog_style(self, catalog_label=None) -> dict[str, dict[str, Any]]:
128144 """
129145 catalog_label = self ._resolve_catalog_label (catalog_label )
130146
131- style = self ._markers [catalog_label ]
147+ style = self ._catalogs [catalog_label ]. style
132148 style ["catalog_label" ] = catalog_label
133149 return style
134150
@@ -160,7 +176,12 @@ def set_catalog_style(
160176 color = color if color else self ._default_marker_style ["color" ]
161177 size = size if size else self ._default_marker_style ["size" ]
162178
163- self ._markers [catalog_label ] = {
179+ catalog_label = self ._resolve_catalog_label (catalog_label )
180+
181+ if self ._catalogs [catalog_label ].data is None :
182+ raise ValueError ("Must load a catalog before setting a catalog style." )
183+
184+ self ._catalogs [catalog_label ].style = {
164185 "shape" : shape ,
165186 "color" : color ,
166187 "size" : size ,
@@ -242,8 +263,9 @@ def save(self, filename: str | os.PathLike, overwrite: bool = False) -> None:
242263
243264 # Marker-related methods
244265 def load_catalog (self , table : Table , x_colname : str = 'x' , y_colname : str = 'y' ,
245- skycoord_colname : str = 'coord' , use_skycoord : bool = False ,
246- catalog_label : str | None = None ) -> None :
266+ skycoord_colname : str = 'coord' , use_skycoord : bool = True ,
267+ catalog_label : str | None = None ,
268+ catalog_style : dict | None = None ) -> None :
247269 """
248270 Add markers to the image.
249271
@@ -263,7 +285,7 @@ def load_catalog(self, table: Table, x_colname: str = 'x', y_colname: str = 'y',
263285 is ``'coord'``.
264286 use_skycoord : bool, optional
265287 If `True`, the ``skycoord_colname`` column will be used to
266- get the marker positions. Default is `False`.
288+ get the marker positions.
267289 catalog_label : str, optional
268290 The name of the marker set to use. If not given, a unique
269291 name will be generated.
@@ -273,32 +295,39 @@ def load_catalog(self, table: Table, x_colname: str = 'x', y_colname: str = 'y',
273295 except KeyError :
274296 coords = None
275297
276- if use_skycoord :
277- if self ._wcs is not None :
298+ try :
299+ xy = (table [x_colname ], table [y_colname ])
300+ except KeyError :
301+ xy = None
302+
303+ to_add = deepcopy (table )
304+ if xy is None :
305+ if self ._wcs is not None and coords is not None :
278306 x , y = self ._wcs .world_to_pixel (coords )
307+ to_add [x_colname ] = x
308+ to_add [y_colname ] = y
279309 else :
280- raise ValueError ("WCS is not set. Cannot convert to pixel coordinates." )
281- else :
282- x = table [x_colname ]
283- y = table [y_colname ]
284-
285- if not coords and self ._wcs is not None :
286- coords = self ._wcs .pixel_to_world (x , y )
310+ to_add [x_colname ] = to_add [y_colname ] = None
287311
288- to_add = Table (
289- dict (
290- x = x ,
291- y = y ,
292- coord = coords if coords else [None ] * len (x ),
293- )
294- )
312+ if coords is None :
313+ if use_skycoord and self ._wcs is None :
314+ raise ValueError ("WCS is not set. Cannot convert to pixel coordinates." )
315+ elif xy is not None and self ._wcs is not None :
316+ # If we have xy coordinates, convert them to sky coordinates
317+ coords = self ._wcs .pixel_to_world (xy [0 ], xy [1 ])
318+ to_add [skycoord_colname ] = coords
319+ else :
320+ to_add [skycoord_colname ] = None
295321
296322 catalog_label = self ._resolve_catalog_label (catalog_label )
297- if catalog_label in self ._catalogs :
298- marker_table = self ._markers [catalog_label ]
299- self ._markers [catalog_label ] = vstack ([marker_table , to_add ])
323+ if (
324+ catalog_label in self ._catalogs
325+ and self ._catalogs [catalog_label ].data is not None
326+ ):
327+ old_table = self ._catalogs [catalog_label ].data
328+ self ._catalogs [catalog_label ].data = vstack ([old_table , to_add ])
300329 else :
301- self ._markers [catalog_label ] = to_add
330+ self ._catalogs [catalog_label ]. data = to_add
302331
303332 def remove_catalog (self , catalog_label : str | None = None ) -> None :
304333 """
@@ -310,19 +339,24 @@ def remove_catalog(self, catalog_label: str | None = None) -> None:
310339 The name of the marker set to remove. If the value is ``"*"``,
311340 then all markers will be removed.
312341 """
313- if isinstance (catalog_label , str ):
314- if catalog_label in self ._markers :
315- del self ._markers [catalog_label ]
316- elif catalog_label == "*" :
317- self ._markers = {}
318- else :
319- raise ValueError (f"Marker name { catalog_label } not found." )
320- elif isinstance (catalog_label , list ):
321- for name in catalog_label :
322- if name in self ._markers :
323- del self ._markers [name ]
324- else :
325- raise ValueError (f"Marker name { name } not found." )
342+ if isinstance (catalog_label , list ):
343+ raise ValueError (
344+ "Cannot remove multiple catalogs from a list. Please specify "
345+ "a single catalog label or use '*' to remove all catalogs."
346+ )
347+ elif catalog_label == "*" :
348+ # If the user wants to remove all catalogs, we reset the
349+ # catalogs dictionary to an empty one.
350+ self ._catalogs = defaultdict (CatalogInfo )
351+ return
352+
353+ # Special cases are done, so we can resolve the catalog label
354+ catalog_label = self ._resolve_catalog_label (catalog_label )
355+
356+ try :
357+ del self ._catalogs [catalog_label ]
358+ except KeyError :
359+ raise ValueError (f"Marker name { catalog_label } not found." )
326360
327361 def get_catalog (self , x_colname : str = 'x' , y_colname : str = 'y' ,
328362 skycoord_colname : str = 'coord' ,
@@ -331,7 +365,7 @@ def get_catalog(self, x_colname: str = 'x', y_colname: str = 'y',
331365 # duplicated here.
332366 catalog_label = self ._resolve_catalog_label (catalog_label )
333367
334- result = self ._markers [catalog_label ] if catalog_label in self ._markers else Table (names = ["x" , "y" , "coord" , "marker name" ])
368+ result = self ._catalogs [catalog_label ]. data if catalog_label in self ._catalogs else Table (names = ["x" , "y" , "coord" , "marker name" ])
335369
336370 result .rename_columns (["x" , "y" , "coord" ], [x_colname , y_colname , skycoord_colname ])
337371
0 commit comments