1+ import numbers
12import os
3+ from collections import defaultdict
4+ from copy import copy
25from dataclasses import dataclass , field
36from pathlib import Path
47from typing import Any
58
9+ from astropy import units as u
610from astropy .coordinates import SkyCoord
711from astropy .nddata import CCDData , NDData
812from astropy .table import Table , vstack
913from astropy .units import Quantity , get_physical_type
1014from astropy .wcs import WCS
15+ from astropy .wcs .utils import proj_plane_pixel_scales
1116from astropy .visualization import AsymmetricPercentileInterval , BaseInterval , BaseStretch , LinearStretch , ManualInterval
1217from numpy .typing import ArrayLike
1318
1419from .interface_definition import ImageViewerInterface
1520
21+ @dataclass
22+ class ViewportInfo :
23+ """
24+ Class to hold image and viewport information.
25+ """
26+ center : SkyCoord | tuple [numbers .Real , numbers .Real ] | None = None
27+ fov : float | Quantity | None = None
28+ wcs : WCS | None = None
1629
1730@dataclass
1831class ImageViewer :
@@ -28,7 +41,7 @@ class ImageViewer:
2841 zoom_level : float = 1
2942 _cursor : str = ImageViewerInterface .ALLOWED_CURSOR_LOCATIONS [0 ]
3043 marker : Any = "marker"
31- _cuts : BaseInterval | tuple [float , float ] = AsymmetricPercentileInterval (upper_percentile = 95 )
44+ _cuts : BaseInterval | tuple [numbers . Real , numbers . Real ] = AsymmetricPercentileInterval (upper_percentile = 95 )
3245 _stretch : BaseStretch = LinearStretch
3346 # viewer: Any
3447
@@ -46,7 +59,15 @@ class ImageViewer:
4659 _previous_marker : Any = ""
4760 _markers : dict [str , Table ] = field (default_factory = dict )
4861 _wcs : WCS | None = None
49- _center : tuple [float , float ] = (0.0 , 0.0 )
62+ _center : tuple [numbers .Real , numbers .Real ] = (0.0 , 0.0 )
63+
64+
65+ def __post_init__ (self ):
66+ # Set up the initial state of the viewer
67+ self ._images = defaultdict (ViewportInfo )
68+ self ._images [None ].center = None
69+ self ._images [None ].fov = None
70+ self ._images [None ].wcs = None
5071
5172 def get_stretch (self ) -> BaseStretch :
5273 return self ._stretch
@@ -59,7 +80,7 @@ def set_stretch(self, value: BaseStretch) -> None:
5980 def get_cuts (self ) -> tuple :
6081 return self ._cuts
6182
62- def set_cuts (self , value : tuple [float , float ] | BaseInterval ) -> None :
83+ def set_cuts (self , value : tuple [numbers . Real , numbers . Real ] | BaseInterval ) -> None :
6384 if isinstance (value , tuple ) and len (value ) == 2 :
6485 self ._cuts = ManualInterval (value [0 ], value [1 ])
6586 elif isinstance (value , BaseInterval ):
@@ -80,7 +101,42 @@ def cursor(self, value: str) -> None:
80101 # The methods, grouped loosely by purpose
81102
82103 # Methods for loading data
83- def load_image (self , file : str | os .PathLike | ArrayLike | NDData ) -> None :
104+ def _user_image_labels (self ) -> list [str ]:
105+ """
106+ Get the list of user-defined image labels.
107+
108+ Returns
109+ -------
110+ list of str
111+ The list of user-defined image labels.
112+ """
113+ return [label for label in self ._images if label is not None ]
114+
115+ def _resolve_image_label (self , image_label : str | None ) -> str :
116+ """
117+ Figure out the catalog label if the user did not specify one. This
118+ is needed so that the user gets what they expect in the simple case
119+ where there is only one catalog loaded. In that case the user may
120+ or may not have actually specified a catalog label.
121+ """
122+ user_keys = self ._user_image_labels ()
123+ if image_label is None :
124+ match len (user_keys ):
125+ case 0 :
126+ # No user-defined catalog labels, so return the default label.
127+ image_label = None
128+ case 1 :
129+ # The user must have loaded a catalog, so return that instead of
130+ # the default label, which live in the key None.
131+ image_label = user_keys [0 ]
132+ case _:
133+ raise ValueError (
134+ "Multiple catalog styles defined. Please specify a image_label to get the style."
135+ )
136+
137+ return image_label
138+
139+ def load_image (self , file : str | os .PathLike | ArrayLike | NDData , image_label : str | None = None ) -> None :
84140 """
85141 Load a FITS file into the viewer.
86142
@@ -89,32 +145,42 @@ def load_image(self, file: str | os.PathLike | ArrayLike | NDData) -> None:
89145 file : str or `astropy.io.fits.HDU`
90146 The FITS file to load. If a string, it can be a URL or a
91147 file path.
148+
149+ image_label : str, optional
150+ A label for the image.
92151 """
152+ image_label = self ._resolve_image_label (image_label )
153+
154+ # Delete the current viewport if it exists
155+ if image_label in self ._images :
156+ del self ._images [image_label ]
157+
93158 if isinstance (file , (str , os .PathLike )):
94159 if isinstance (file , str ):
95160 is_adsf = file .endswith (".asdf" )
96161 else :
97162 is_asdf = file .suffix == ".asdf"
98163 if is_asdf :
99- self ._load_asdf (file )
164+ self ._load_asdf (file , image_label )
100165 else :
101- self ._load_fits (file )
166+ self ._load_fits (file , image_label )
102167 elif isinstance (file , NDData ):
103- self ._load_nddata (file )
168+ self ._load_nddata (file , image_label )
104169 else :
105170 # Assume it is a 2D array
106- self ._load_array (file )
171+ self ._load_array (file , image_label )
107172
108- def _load_fits (self , file : str | os .PathLike ) -> None :
173+ def _load_fits (self , file : str | os .PathLike , image_label : str | None ) -> None :
109174 ccd = CCDData .read (file )
110- self ._wcs = ccd .wcs
111- self .image_height , self .image_width = ccd .shape
112- # Totally made up number...as currently defined, zoom_level means, esentially, ratio
113- # of image size to viewer size.
114- self .zoom_level = 1.0
115- self .center_on ((self .image_width / 2 , self .image_height / 2 ))
116-
117- def _load_array (self , array : ArrayLike ) -> None :
175+ height , width = ccd .shape
176+ self ._images [image_label ].wcs = ccd .wcs
177+ self .set_viewport (
178+ center = (width / 2 , height / 2 ),
179+ fov = max (ccd .shape ),
180+ image_label = image_label
181+ )
182+
183+ def _load_array (self , array : ArrayLike , image_label : str | None ) -> None :
118184 """
119185 Load a 2D array into the viewer.
120186
@@ -123,14 +189,15 @@ def _load_array(self, array: ArrayLike) -> None:
123189 array : array-like
124190 The array to load.
125191 """
126- self .image_height , self .image_width = array .shape
127- # Totally made up number...as currently defined, zoom_level means, esentially, ratio
128- # of image size to viewer size.
129- self .zoom_level = 1.0
130- self .center_on ((self .image_width / 2 , self .image_height / 2 ))
131-
192+ height , width = array .shape
193+ self ._images [image_label ].wcs = None # No WCS for raw arrays
194+ self .set_viewport (
195+ center = (width / 2 , height / 2 ),
196+ fov = max (array .shape ),
197+ image_label = image_label
198+ )
132199
133- def _load_nddata (self , data : NDData ) -> None :
200+ def _load_nddata (self , data : NDData , image_label : str | None ) -> None :
134201 """
135202 Load an `astropy.nddata.NDData` object into the viewer.
136203
@@ -139,15 +206,16 @@ def _load_nddata(self, data: NDData) -> None:
139206 data : `astropy.nddata.NDData`
140207 The NDData object to load.
141208 """
142- self ._wcs = data .wcs
209+ self ._images [ image_label ]. wcs = data .wcs
143210 # Not all NDDData objects have a shape, apparently
144- self .image_height , self .image_width = data .data .shape
145- # Totally made up number...as currently defined, zoom_level means, esentially, ratio
146- # of image size to viewer size.
147- self .zoom_level = 1.0
148- self .center_on ((self .image_width / 2 , self .image_height / 2 ))
211+ height , width = data .data .shape
212+ self .set_viewport (
213+ center = (width / 2 , height / 2 ),
214+ fov = max (data .data .shape ),
215+ image_label = image_label
216+ )
149217
150- def _load_asdf (self , asdf_file : str | os .PathLike ) -> None :
218+ def _load_asdf (self , asdf_file : str | os .PathLike , image_label : str | None ) -> None :
151219 """
152220 Not implementing some load types is fine.
153221 """
@@ -313,67 +381,94 @@ def get_markers(self, x_colname: str = 'x', y_colname: str = 'y',
313381
314382
315383 # Methods that modify the view
316- def center_on (self , point : tuple | SkyCoord ):
317- """
318- Center the view on the point.
319-
320- Parameters
321- ----------
322- tuple or `~astropy.coordinates.SkyCoord`
323- If tuple of ``(X, Y)`` is given, it is assumed
324- to be in data coordinates.
325- """
326- # currently there is no way to get the position of the center, but we may as well make
327- # note of it
328- if isinstance (point , SkyCoord ):
329- if self ._wcs is not None :
330- point = self ._wcs .world_to_pixel (point )
384+ def set_viewport (
385+ self , center : SkyCoord | tuple [numbers .Real , numbers .Real ] | None = None ,
386+ fov : Quantity | numbers .Real | None = None ,
387+ image_label : str | None = None
388+ ) -> None :
389+ image_label = self ._resolve_image_label (image_label )
390+
391+ # Get current center/fov, if any, so that the user may input only one of them
392+ # after the initial setup if they wish.
393+ current_viewport = copy (self ._images [image_label ])
394+ if center is None :
395+ center = current_viewport .center
396+ if fov is None :
397+ fov = current_viewport .fov
398+
399+ # If either center or fov is None these checks will raise an appropriate error
400+ if not isinstance (center , (SkyCoord , tuple )):
401+ raise TypeError ("Invalid value for center. Center must be a SkyCoord or tuple of (X, Y)." )
402+ if not isinstance (fov , (Quantity , numbers .Real )):
403+ raise TypeError ("Invalid value for fov. FOV must be a Quantity or float." )
404+
405+ # Check that the center and fov are compatible with the current image
406+ if self ._images [image_label ].wcs is None :
407+ if current_viewport .center is not None :
408+ # If there is a WCS either input is fine. If there is no WCS then we only
409+ # check wther the new center is the same type as the current center.
410+ if isinstance (center , SkyCoord ) and not isinstance (current_viewport .center , SkyCoord ):
411+ raise ValueError ("Center must be a SkyCoord for this image when WCS is not set." )
412+ elif isinstance (center , tuple ) and not isinstance (current_viewport .center , tuple ):
413+ raise ValueError ("Center must be a tuple of (X, Y) for this image when WCS is not set." )
414+ if current_viewport .fov is not None :
415+ if isinstance (fov , Quantity ) and not isinstance (current_viewport .fov , Quantity ):
416+ raise ValueError ("FOV must be a angular Quantity for this image when WCS is not set." )
417+ elif isinstance (fov , numbers .Real ) and not isinstance (current_viewport .fov , numbers .Real ):
418+ raise ValueError ("FOV must be a float for this image when WCS is set." )
419+
420+ # 😅 if we made it this far we should be able to handle the actual setting
421+ self ._images [image_label ].center = center
422+ self ._images [image_label ].fov = fov
423+
424+
425+ set_viewport .__doc__ = ImageViewerInterface .set_viewport .__doc__
426+
427+ def get_viewport (
428+ self , sky_or_pixel : str | None = None , image_label : str | None = None
429+ ) -> dict [str , Any ]:
430+ if sky_or_pixel not in (None , "sky" , "pixel" ):
431+ raise ValueError ("sky_or_pixel must be 'sky', 'pixel', or None." )
432+ image_label = self ._resolve_image_label (image_label )
433+
434+ viewport = self ._images [image_label ]
435+ if sky_or_pixel == "sky" :
436+ if isinstance (viewport .center , SkyCoord ):
437+ center = viewport .center
438+ elif isinstance (viewport .center , tuple ):
439+ # If the center is a tuple, we need to convert it to SkyCoord
440+ if viewport .wcs is None :
441+ raise ValueError ("WCS is not set. Cannot convert pixel coordinates to sky coordinates." )
442+ center = viewport .wcs .pixel_to_world (viewport .center [0 ], viewport .center [1 ])
443+ if isinstance (viewport .fov , Quantity ):
444+ fov = viewport .fov
445+ elif isinstance (viewport .fov , numbers .Real ):
446+ if viewport .wcs is None :
447+ raise ValueError ("WCS is not set. Cannot convert FOV to sky coordinates." )
448+ pixel_scale = proj_plane_pixel_scales (viewport .wcs )
449+ fov = pixel_scale * viewport .fov * u .degree
450+ else :
451+ # Pixel coordinates
452+ if isinstance (viewport .center , SkyCoord ):
453+ if viewport .wcs is None :
454+ raise ValueError ("WCS is not set. Cannot convert sky coordinates to pixel coordinates." )
455+ center = viewport .wcs .world_to_pixel (viewport .center )
331456 else :
332- raise ValueError ("WCS is not set. Cannot convert to pixel coordinates." )
333-
334- self ._center = point
335-
336- def offset_by (self , dx : float | Quantity , dy : float | Quantity ) -> None :
337- """
338- Move the center to a point that is given offset
339- away from the current center.
340-
341- Parameters
342- ----------
343- dx, dy : float or `~astropy.units.Quantity`
344- Offset value. Without a unit, assumed to be pixel offsets.
345- If a unit is attached, offset by pixel or sky is assumed from
346- the unit.
347- """
348- # Convert to quantity to make the rest of the processing uniform
349- dx = Quantity (dx )
350- dy = Quantity (dy )
351-
352- # This raises a UnitConversionError if the units are not compatible
353- dx .to (dy .unit )
354-
355- # Do we have an angle or pixel offset?
356- if get_physical_type (dx ) == "angle" :
357- # This is a sky offset
358- if self ._wcs is not None :
359- old_center_coord = self ._wcs .pixel_to_world (self ._center [0 ], self ._center [1 ])
360- new_center = old_center_coord .spherical_offsets_by (dx , dy )
361- self .center_on (new_center )
457+ center = viewport .center
458+ if isinstance (viewport .fov , Quantity ):
459+ if viewport .wcs is None :
460+ raise ValueError ("WCS is not set. Cannot convert FOV to pixel coordinates." )
461+ pixel_scale = proj_plane_pixel_scales (viewport .wcs )
462+ fov = viewport .fov / pixel_scale
362463 else :
363- raise ValueError ("WCS is not set. Cannot convert to pixel coordinates." )
364- else :
365- # This is a pixel offset
366- new_center = (self ._center [0 ] + dx .value , self ._center [1 ] + dy .value )
367- self .center_on (new_center )
464+ fov = viewport .fov
368465
369- def zoom (self , val ) -> None :
370- """
371- Zoom in or out by the given factor.
466+ return dict (
467+ center = center ,
468+ fov = fov ,
469+ wcs = viewport .wcs ,
470+ image_label = image_label
471+ )
372472
373- Parameters
374- ----------
375- val : int
376- The zoom level to zoom the image.
377- See `zoom_level`.
378- """
379- self .zoom_level *= val
473+
474+ get_viewport .__doc__ = ImageViewerInterface .get_viewport .__doc__
0 commit comments