11import os
2+ from collections import defaultdict
3+ from copy import deepcopy
24from dataclasses import dataclass , field
35from pathlib import Path
46from typing import Any
1315
1416from .interface_definition import ImageViewerInterface
1517
18+ @dataclass
19+ class CatalogInfo :
20+ """
21+ Class to hold information about a catalog.
22+ """
23+ style : dict [str , Any ] = field (default_factory = dict )
24+ data : Table | None = None
1625
1726@dataclass
1827class ImageViewer :
@@ -27,27 +36,65 @@ class ImageViewer:
2736 image_height : int = 0
2837 zoom_level : float = 1
2938 _cursor : str = ImageViewerInterface .ALLOWED_CURSOR_LOCATIONS [0 ]
30- marker : Any = "marker"
3139 _cuts : BaseInterval | tuple [float , float ] = AsymmetricPercentileInterval (upper_percentile = 95 )
3240 _stretch : BaseStretch = LinearStretch
3341 # viewer: Any
3442
3543 # Allowed locations for cursor display
3644 ALLOWED_CURSOR_LOCATIONS : tuple = ImageViewerInterface .ALLOWED_CURSOR_LOCATIONS
3745
38- # List of marker names that are for internal use only
39- RESERVED_MARKER_SET_NAMES : tuple = ImageViewerInterface .RESERVED_MARKER_SET_NAMES
40-
41- # Default marker name for marking via API
42- DEFAULT_MARKER_NAME : str = ImageViewerInterface .DEFAULT_MARKER_NAME
43-
4446 # some internal variable for keeping track of viewer state
45- _interactive_marker_name : str = ""
46- _previous_marker : Any = ""
47- _markers : dict [str , Table ] = field (default_factory = dict )
4847 _wcs : WCS | None = None
4948 _center : tuple [float , float ] = (0.0 , 0.0 )
5049
50+ def __post_init__ (self ):
51+ # This is a dictionary of marker sets. The keys are the names of the
52+ # marker sets, and the values are the tables containing the markers.
53+ self ._catalogs = defaultdict (CatalogInfo )
54+ self ._catalogs [None ].data = None
55+ self ._catalogs [None ].style = self ._default_catalog_style .copy ()
56+
57+ def _user_catalog_labels (self ) -> list [str ]:
58+ """
59+ Get the user-defined catalog labels.
60+ """
61+ return [label for label in self ._catalogs if label is not None ]
62+
63+ def _resolve_catalog_label (self , catalog_label : str | None ) -> str :
64+ """
65+ Figure out the catalog label if the user did not specify one. This
66+ is needed so that the user gets what they expect in the simple case
67+ where there is only one catalog loaded. In that case the user may
68+ or may not have actually specified a catalog label.
69+ """
70+ user_keys = self ._user_catalog_labels ()
71+ if catalog_label is None :
72+ match len (user_keys ):
73+ case 0 :
74+ # No user-defined catalog labels, so return the default label.
75+ catalog_label = None
76+ case 1 :
77+ # The user must have loaded a catalog, so return that instead of
78+ # the default label, which live in the key None.
79+ catalog_label = user_keys [0 ]
80+ case _:
81+ raise ValueError (
82+ "Multiple catalog styles defined. Please specify a catalog_label to get the style."
83+ )
84+
85+ return catalog_label
86+
87+ @property
88+ def _default_catalog_style (self ) -> dict [str , Any ]:
89+ """
90+ The default style for the catalog markers.
91+ """
92+ return {
93+ "shape" : "circle" ,
94+ "color" : "red" ,
95+ "size" : 5 ,
96+ }
97+
5198 def get_stretch (self ) -> BaseStretch :
5299 return self ._stretch
53100
@@ -79,6 +126,62 @@ def cursor(self, value: str) -> None:
79126
80127 # The methods, grouped loosely by purpose
81128
129+ def get_catalog_style (self , catalog_label = None ) -> dict [str , Any ]:
130+ """
131+ Get the style for the catalog.
132+
133+ Parameters
134+ ----------
135+ catalog_label : str, optional
136+ The label of the catalog. Default is ``None``.
137+
138+ Returns
139+ -------
140+ dict
141+ The style for the catalog.
142+ """
143+ catalog_label = self ._resolve_catalog_label (catalog_label )
144+
145+ style = self ._catalogs [catalog_label ].style .copy ()
146+ style ["catalog_label" ] = catalog_label
147+ return style
148+
149+ def set_catalog_style (
150+ self ,
151+ catalog_label : str | None = None ,
152+ shape : str = "circle" ,
153+ color : str = "red" ,
154+ size : float = 5 ,
155+ ** kwargs
156+ ) -> None :
157+ """
158+ Set the style for the catalog.
159+
160+ Parameters
161+ ----------
162+ catalog_label : str, optional
163+ The label of the catalog.
164+ shape : str, optional
165+ The shape of the markers.
166+ color : str, optional
167+ The color of the markers.
168+ size : float, optional
169+ The size of the markers.
170+ **kwargs
171+ Additional keyword arguments to pass to the marker style.
172+ """
173+ catalog_label = self ._resolve_catalog_label (catalog_label )
174+
175+ if self ._catalogs [catalog_label ].data is None :
176+ raise ValueError ("Must load a catalog before setting a catalog style." )
177+
178+ self ._catalogs [catalog_label ].style = dict (
179+ shape = shape ,
180+ color = color ,
181+ size = size ,
182+ ** kwargs
183+ )
184+
82185 # Methods for loading data
83186 def load_image (self , file : str | os .PathLike | ArrayLike | NDData ) -> None :
84187 """
@@ -175,142 +278,108 @@ def save(self, filename: str | os.PathLike, overwrite: bool = False) -> None:
175278 p .write_text ("This is a dummy file. The viewer does not save anything." )
176279
177280 # Marker-related methods
178- def add_markers (self , table : Table , x_colname : str = 'x' , y_colname : str = 'y' ,
281+ def load_catalog (self , table : Table , x_colname : str = 'x' , y_colname : str = 'y' ,
179282 skycoord_colname : str = 'coord' , use_skycoord : bool = False ,
180- marker_name : str | None = None ) -> None :
181- """
182- Add markers to the image.
183-
184- Parameters
185- ----------
186- table : `astropy.table.Table`
187- The table containing the marker positions.
188- x_colname : str, optional
189- The name of the column containing the x positions. Default
190- is ``'x'``.
191- y_colname : str, optional
192- The name of the column containing the y positions. Default
193- is ``'y'``.
194- skycoord_colname : str, optional
195- The name of the column containing the sky coordinates. If
196- given, the ``use_skycoord`` parameter is ignored. Default
197- is ``'coord'``.
198- use_skycoord : bool, optional
199- If `True`, the ``skycoord_colname`` column will be used to
200- get the marker positions. Default is `False`.
201- marker_name : str, optional
202- The name of the marker set to use. If not given, a unique
203- name will be generated.
204- """
283+ catalog_label : str | None = None ,
284+ catalog_style : dict | None = None ) -> None :
205285 try :
206286 coords = table [skycoord_colname ]
207287 except KeyError :
208288 coords = None
209289
210- if use_skycoord :
211- if self ._wcs is not None :
290+ try :
291+ xy = (table [x_colname ], table [y_colname ])
292+ except KeyError :
293+ xy = None
294+
295+ to_add = deepcopy (table )
296+ if xy is None :
297+ if self ._wcs is not None and coords is not None :
212298 x , y = self ._wcs .world_to_pixel (coords )
299+ to_add [x_colname ] = x
300+ to_add [y_colname ] = y
213301 else :
214- raise ValueError ("WCS is not set. Cannot convert to pixel coordinates." )
302+ to_add [x_colname ] = to_add [y_colname ] = None
303+
304+ if coords is None :
305+ if use_skycoord and self ._wcs is None :
306+ raise ValueError ("Cannot use sky coordinates without a SkyCoord column or WCS." )
307+ elif xy is not None and self ._wcs is not None :
308+ # If we have xy coordinates, convert them to sky coordinates
309+ coords = self ._wcs .pixel_to_world (xy [0 ], xy [1 ])
310+ to_add [skycoord_colname ] = coords
311+ else :
312+ to_add [skycoord_colname ] = None
313+
314+ catalog_label = self ._resolve_catalog_label (catalog_label )
315+
316+ # Either set new data or append to existing data
317+ if (
318+ catalog_label in self ._catalogs
319+ and self ._catalogs [catalog_label ].data is not None
320+ ):
321+ # If the catalog already exists, we append to it
322+ old_table = self ._catalogs [catalog_label ].data
323+ self ._catalogs [catalog_label ].data = vstack ([old_table , to_add ])
215324 else :
216- x = table [x_colname ]
217- y = table [y_colname ]
218-
219- if not coords and self ._wcs is not None :
220- coords = self ._wcs .pixel_to_world (x , y )
221-
222- if marker_name in self .RESERVED_MARKER_SET_NAMES :
223- raise ValueError (f"Marker name { marker_name } not allowed." )
325+ # If the catalog does not exist, we create a new one
326+ self ._catalogs [catalog_label ].data = to_add
224327
225- marker_name = marker_name if marker_name else self .DEFAULT_MARKER_NAME
328+ # Ensure a catalog always has a style
329+ if catalog_style is None :
330+ if not self ._catalogs [catalog_label ].style :
331+ catalog_style = self ._default_catalog_style .copy ()
226332
227- to_add = Table (
228- dict (
229- x = x ,
230- y = y ,
231- coord = coords if coords else [None ] * len (x ),
232- )
233- )
234- to_add ["marker name" ] = marker_name
333+ self ._catalogs [catalog_label ].style = catalog_style
235334
236- if marker_name in self ._markers :
237- marker_table = self ._markers [marker_name ]
238- self ._markers [marker_name ] = vstack ([marker_table , to_add ])
239- else :
240- self ._markers [marker_name ] = to_add
335+ load_catalog .__doc__ = ImageViewerInterface .load_catalog .__doc__
241336
242- def reset_markers (self ) -> None :
243- """
244- Remove all markers from the image.
245- """
246- self ._markers = {}
247-
248- def remove_markers (self , marker_name : str | list [str ] | None = None ) -> None :
337+ def remove_catalog (self , catalog_label : str | None = None ) -> None :
249338 """
250339 Remove markers from the image.
251340
252341 Parameters
253342 ----------
254343 marker_name : str, optional
255- The name of the marker set to remove. If the value is ``"all "``,
344+ The name of the marker set to remove. If the value is ``"* "``,
256345 then all markers will be removed.
257346 """
258- if isinstance (marker_name , str ):
259- if marker_name in self ._markers :
260- del self ._markers [marker_name ]
261- elif marker_name == "all" :
262- self ._markers = {}
263- else :
264- raise ValueError (f"Marker name { marker_name } not found." )
265- elif isinstance (marker_name , list ):
266- for name in marker_name :
267- if name in self ._markers :
268- del self ._markers [name ]
269- else :
270- raise ValueError (f"Marker name { name } not found." )
271-
272- def get_markers (self , x_colname : str = 'x' , y_colname : str = 'y' ,
273- skycoord_colname : str = 'coord' ,
274- marker_name : str | list [str ] | None = None ) -> Table :
275- """
276- Get the marker positions.
347+ if isinstance (catalog_label , list ):
348+ raise ValueError (
349+ "Cannot remove multiple catalogs from a list. Please specify "
350+ "a single catalog label or use '*' to remove all catalogs."
351+ )
352+ elif catalog_label == "*" :
353+ # If the user wants to remove all catalogs, we reset the
354+ # catalogs dictionary to an empty one.
355+ self ._catalogs = defaultdict (CatalogInfo )
356+ return
277357
278- Parameters
279- ----------
280- x_colname : str, optional
281- The name of the column containing the x positions. Default
282- is ``'x'``.
283- y_colname : str, optional
284- The name of the column containing the y positions. Default
285- is ``'y'``.
286- skycoord_colname : str, optional
287- The name of the column containing the sky coordinates. Default
288- is ``'coord'``.
289- marker_name : str or list of str, optional
290- The name of the marker set to use. If that value is ``"all"``,
291- then all markers will be returned.
358+ # Special cases are done, so we can resolve the catalog label
359+ catalog_label = self ._resolve_catalog_label (catalog_label )
292360
293- Returns
294- -------
295- table : `astropy.table.Table`
296- The table containing the marker positions. If no markers match the
297- ``marker_name`` parameter, an empty table is returned.
298- """
299- if isinstance (marker_name , str ):
300- if marker_name == "all" :
301- marker_name = self ._markers .keys ()
302- else :
303- marker_name = [marker_name ]
304- elif marker_name is None :
305- marker_name = [self .DEFAULT_MARKER_NAME ]
361+ try :
362+ del self ._catalogs [catalog_label ]
363+ except KeyError :
364+ raise ValueError (f"Catalog label { catalog_label } not found." )
365+
366+ def get_catalog (self , x_colname : str = 'x' , y_colname : str = 'y' ,
367+ skycoord_colname : str = 'coord' ,
368+ catalog_label : str | None = None ) -> Table :
369+ # Dostring is copied from the interface definition, so it is not
370+ # duplicated here.
371+ catalog_label = self ._resolve_catalog_label (catalog_label )
306372
307- to_stack = [ self ._markers [ name ] for name in marker_name if name in self . _markers ]
373+ result = self ._catalogs [ catalog_label ]. data if catalog_label in self . _catalogs else Table ( names = [ "x" , "y" , "coord" ])
308374
309- result = vstack (to_stack ) if to_stack else Table (names = ["x" , "y" , "coord" , "marker name" ])
310375 result .rename_columns (["x" , "y" , "coord" ], [x_colname , y_colname , skycoord_colname ])
311376
312377 return result
378+ get_catalog .__doc__ = ImageViewerInterface .get_catalog .__doc__
313379
380+ def get_catalog_names (self ) -> list [str ]:
381+ return list (self ._user_catalog_labels ())
382+ get_catalog_names .__doc__ = ImageViewerInterface .get_catalog_names .__doc__
314383
315384 # Methods that modify the view
316385 def center_on (self , point : tuple | SkyCoord ):
0 commit comments