22import astropy .wcs
33import equinox as eqx
44import jax .numpy as jnp
5+ import numpy as np
56from astropy .coordinates import SkyCoord
67
78from .bbox import Box
@@ -27,10 +28,14 @@ class Frame(eqx.Module):
2728 def __init__ (self , bbox , psf = None , wcs = None , channels = None ):
2829 self .bbox = bbox
2930 self .psf = psf
31+ if wcs is None :
32+ wcs = astropy .wcs .WCS (naxis = 2 ) # Dummy WCS
33+ wcs ._naxis = bbox .spatial .shape [::- 1 ]
34+ wcs .wcs .ctype = ["RA---TAN" , "DEC--TAN" ]
3035 self .wcs = wcs
36+
3137 if channels is None :
3238 channels = list (range (bbox .shape [0 ]))
33-
3439 self .channels = channels
3540
3641 def __hash__ (self ):
@@ -47,14 +52,10 @@ def pixel_size(self):
4752
4853 Returns
4954 -------
50- float
51- Pixel size in arcsec, averaged over x and y direction
55+ float, astropy.units.quantity.Quantity
56+ Pixel size in units of the WCS sky coordinates
5257 """
53- if self .wcs is not None :
54- # return get_pixel_size(get_affine(self.wcs)) * 60 * 60 # in arcsec
55- return get_scale (self .wcs ).mean () * 60 * 60 # in arcsec
56- else :
57- return 1
58+ return get_scale (self .wcs )
5859
5960 def get_pixel (self , pos ):
6061 """Get the sky coordinates from a world coordinate
@@ -69,7 +70,6 @@ def get_pixel(self, pos):
6970 pixel coordinates in the model frame
7071 """
7172 if isinstance (pos , SkyCoord ):
72- assert self .wcs is not None , "SkyCoord can only be converted with valid WCS"
7373 wcs_ = self .wcs .celestial # only use celestial portion
7474 pixel = jnp .asarray (pos .to_pixel (wcs_ ), dtype = "float32" ).T
7575 return pixel [..., ::- 1 ]
@@ -85,14 +85,12 @@ def get_sky_coord(self, pos):
8585
8686 Returns
8787 ----------
88- astropy.coordinates.SkyCoord if WCS is set, otherwise pos
88+ astropy.coordinates.SkyCoord
8989 """
90- if self .wcs is not None :
91- pixels = pos .reshape (- 1 , 2 )
92- wcs = self .wcs .celestial # only use celestial portion
93- sky_coord = SkyCoord .from_pixel (pixels [:, 1 ], pixels [:, 0 ], wcs )
94- return sky_coord
95- return pos
90+ pixels = pos .reshape (- 1 , 2 )
91+ wcs = self .wcs .celestial # only use celestial portion
92+ sky_coord = SkyCoord .from_pixel (pixels [:, 1 ], pixels [:, 0 ], wcs )
93+ return sky_coord
9694
9795 def convert_pixel_to (self , target , pixel = None ):
9896 """Converts pixel coordinates from this frame to `target` frame
@@ -130,12 +128,11 @@ def u_to_pixel(self, distance):
130128 float
131129 size in pixels
132130 """
133- assert u .get_physical_type (distance ) == "angle"
134-
135- # first computer the pixel size
136- pixel_size = get_pixel_size (self .wcs .celestial ) * 60 * 60 # in arcsec/pixel
137-
138- return distance .to (u .arcsec ).value / pixel_size
131+ if self .wcs is not None :
132+ pixel_size = get_pixel_size (self .wcs )
133+ return distance / pixel_size
134+ else :
135+ return distance
139136
140137 def pixel_to_angle (self , size ):
141138 """Converts pixel size to celestial distance according to this frame WCS
@@ -148,12 +145,10 @@ def pixel_to_angle(self, size):
148145 Returns
149146 -------
150147 distance: :py:class:`astropy.units.Quantity`
151- Physical size, must be `PhysicalType("angle")`
152148 """
153- # first computer the pixel size
154- pixel_size = get_pixel_size (self .wcs .celestial ) * 60 * 60 # in arcsec/pixel
155149
156- distance = size * pixel_size * u .arcsec
150+ pixel_size = get_pixel_size (self .wcs )
151+ distance = size * pixel_size
157152 return distance
158153
159154 @staticmethod
@@ -189,7 +184,6 @@ def from_observations(observations, model_psf=None, model_wcs=None, obs_id=None,
189184 # Array of pixel sizes for each observation
190185 pix_tab = []
191186 # Array of psf size for each psf of each observation
192- fat_psf_size = None
193187 small_psf_size = None
194188 channels = []
195189 # Create frame channels and find smallest and largest psf
@@ -199,14 +193,14 @@ def from_observations(observations, model_psf=None, model_wcs=None, obs_id=None,
199193
200194 # concatenate all pixel sizes
201195 h_temp = get_pixel_size (obs .frame .wcs )
202-
196+ if isinstance (h_temp , u .Quantity ):
197+ h_temp = h_temp .to (u .arcsec ).value # standardize pixel sizes, using simple scalars below
203198 pix_tab .append (h_temp )
204- # Looking for the sharpest and the fatest psf
199+
200+ # Looking for the sharpest PSF
205201 psf = obs .frame .psf .morphology
206202 for psf_channel in psf :
207203 psf_size = get_psf_size (psf_channel ) * h_temp
208- if (fat_psf_size is None ) or (psf_size > fat_psf_size ):
209- fat_psf_size = psf_size
210204 if (
211205 model_psf is None
212206 and ((obs_id is None ) or (c == obs_id ))
@@ -225,23 +219,22 @@ def from_observations(observations, model_psf=None, model_wcs=None, obs_id=None,
225219
226220 # Reference wcs
227221 if model_wcs is None :
228- model_wcs = obs_ref .frame .wcs
222+ model_wcs = obs_ref .frame .wcs . deepcopy ()
229223
230- # Scale of the smallest pixel
224+ # Scale of the model pixel
231225 h = get_pixel_size (model_wcs )
232226
233227 # If needed and psf is not provided: interpolate psf to smallest pixel
234228 if model_psf is None :
235229 # create Gaussian PSF with a sigma smaller than the smallest observed PSF
236230 sigma = 0.7
237- assert (
238- small_psf_size / h > sigma
239- ), f"Default model PSF width ( { sigma } pixel) too large for best-seeing observation"
231+ assert small_psf_size / h > sigma , (
232+ f"Default model PSF width ( { sigma } pixel) too large for best-seeing observation"
233+ )
240234 model_psf = GaussianPSF (sigma = sigma )
241235
242236 # Dummy frame for WCS computations
243237 model_shape = (len (channels ), 0 , 0 )
244-
245238 model_frame = Frame (Box (model_shape ), channels = channels , psf = model_psf , wcs = model_wcs )
246239
247240 # Determine overlap of all observations in pixel coordinates of the model frame
@@ -262,11 +255,14 @@ def from_observations(observations, model_psf=None, model_wcs=None, obs_id=None,
262255 else :
263256 model_box &= this_box
264257
258+ # update model_wcs to change NAXIS1/2 and CRPIX1/2, but don't change frame_origin!
259+ model_wcs ._naxis = list (model_wcs ._naxis )
260+ model_wcs ._naxis [:2 ] = model_box .shape [::- 1 ] # x/y needed here
261+ model_wcs .wcs .crpix [:2 ] -= model_box .origin [::- 1 ] # x/y needed here
262+
263+ # frame_origin = (0,) + model_box.origin
265264 frame_shape = (len (channels ),) + model_box .shape
266- frame_origin = (0 ,) + model_box .origin
267- model_frame = Frame (
268- Box (shape = frame_shape , origin = frame_origin ), channels = channels , psf = model_psf , wcs = model_wcs
269- )
265+ model_frame = Frame (Box (shape = frame_shape ), channels = channels , psf = model_psf , wcs = model_wcs )
270266
271267 # Match observations to this frame
272268 for obs in observations :
@@ -311,68 +307,145 @@ def get_psf_size(psf):
311307
312308def get_affine (wcs ):
313309 """Return the WCS transformation matrix"""
310+ if wcs is None :
311+ return jnp .diag (jnp .ones (2 ))
312+ wcs_ = wcs .celestial
314313 try :
315- model_affine = wcs .wcs .pc
314+ model_affine = wcs_ .wcs .pc
316315 except AttributeError :
317316 try :
318- model_affine = wcs .cd
317+ model_affine = wcs_ .cd
319318 except AttributeError :
320- model_affine = wcs .wcs .cd
321- return model_affine
319+ model_affine = wcs_ .wcs .cd
320+ return model_affine [: 2 , : 2 ]
322321
323322
324- def get_pixel_size (wcs ):
325- """Extracts the pixel size from a wcs, and returns it in deg/pixel"""
326- if wcs is None :
327- return 1
328- model_affine = get_affine (wcs )
329- pix = jnp .sqrt (
330- jnp .abs (model_affine [0 , 0 ]) * jnp .abs (model_affine [1 , 1 ] - model_affine [0 , 1 ] * model_affine [1 , 0 ])
331- )
332- return pix
323+ # for WCS linear matrix calculations:
324+ # rotation matrix for counter-clockwise rotation from positive x-axis
325+ # uses (x,y) coordinates and phi in radian!!
326+ def _rot_matrix (phi ):
327+ sinphi , cosphi = jnp .sin (phi ), jnp .cos (phi )
328+ return jnp .array ([[cosphi , - sinphi ], [sinphi , cosphi ]])
329+
330+
331+ # flip in y!!!
332+ # uses (x,y) coordinates!
333+ _flip_matrix = lambda flip : jnp .diag (jnp .array ((1.0 , flip )))
333334
335+ # 2x2 matrix determinant
336+ _det = lambda m : m [0 , 0 ] * m [1 , 1 ] - m [0 , 1 ] * m [1 , 0 ]
337+
338+
339+ def get_scale_angle_flip (trans ):
340+ """Return, scale, angle, flip from the WCS transformation matrix
341+
342+ Parameters
343+ ----------
344+ trans: (`astropy.wcs.WCS`, array)
345+ WCS or WCS transformation matrix
334346
335- def get_scale (wcs ):
347+ Returns
348+ -------
349+ scale: `float`
350+ angle: `float`, in radian
351+ flip: -1 or 1
336352 """
337- Return WCS axis scales in deg/pixel
353+ if isinstance (trans , (np .ndarray , jnp .ndarray )): # noqa: SIM108
354+ M = trans # noqa: N806
355+ else :
356+ M = get_affine (trans ) # noqa: N806
357+
358+ det = _det (M )
359+ # this requires pixels to be square
360+ # if not, use scale = jnp.linalg.svd(M, compute_uv=False)
361+ # but be careful with rotations as anisotropic stretch and rotation do not commute
362+ scale = jnp .sqrt (jnp .abs (det )).item (0 )
363+
364+ # if rotation is improper: need to apply y-flip to M to get pure rotation matrix (and unique angle)
365+ improper = det < 0
366+ flip = - 1 if improper else 1
367+ F = _flip_matrix (flip ) # noqa: N806, flip in y, is identity if flip = 1!!!
368+ M_ = F @ M # noqa: N806, flip = inverse flip
369+ angle = jnp .arctan2 (M_ [1 , 0 ], M_ [0 , 0 ]).item ()
370+
371+ return scale , angle , flip
372+
373+
374+ def get_pixel_size (wcs ):
375+ """Extracts the pixel size from a wcs, and returns it in deg/pixel
376+
377+ Parameters
378+ ----------
379+ wcs: `astropy.wcs.WCS`
380+ WCS structure or transformation matrix
381+
382+ Returns
383+ -------
384+ pixel_size: `float`
338385 """
339- if wcs is None :
340- return 1
341- model_affine = get_affine (wcs )
342- c1 = (model_affine [0 , :2 ] ** 2 ).sum () ** 0.5
343- c2 = (model_affine [1 , :2 ] ** 2 ).sum () ** 0.5
344- return jnp .array ([c1 , c2 ])
386+ scale , angle , flip = get_scale_angle_flip (wcs )
387+ return scale
345388
346389
347- def get_angle (wcs ):
390+ def get_scale (wcs , separate = False ):
348391 """
349- Return WCS rotation angle in rad
392+ Get WCS axis scales in deg/pixel
393+
394+ Parameters
395+ ----------
396+ wcs: `astropy.wcs.WCS`
397+ WCS structure or transformation matrix
398+ separate: `bool`
399+ Compute separate axis scales
400+
401+ Returns
402+ -------
403+ float
350404 """
351- if wcs is None :
352- return 0
353- model_affine = get_affine (wcs )
354- c = get_scale (wcs )
355- c = c .reshape ([c .shape [- 1 ], 1 ])
356- r = model_affine [:2 , :2 ] / c # removing the scaling factors from the pc
357-
358- if r [0 , 0 ] == 0.0 :
359- return jnp .arcsin (r [0 , 1 ])
405+ if separate :
406+ M = get_affine (wcs ) # noqa: N806
407+ c1 = (M [0 , :] ** 2 ).sum () ** 0.5
408+ c2 = (M [1 , :] ** 2 ).sum () ** 0.5
409+ return jnp .array ([c1 , c2 ])
360410 else :
361- return jnp .arctan (r [0 , 1 ] / r [0 , 0 ])
411+ scale , angle , flip = get_scale_angle_flip (wcs )
412+ return scale
362413
363414
364- def get_sign (wcs ):
415+ def get_angle (wcs ):
416+ """
417+ Get WCS rotation angle
418+
419+ The angle is computed counter-clockwise from the positive x-axis, in radians.
420+
421+ Parameters
422+ ----------
423+ wcs: `astropy.wcs.WCS`
424+ WCS structure or transformation matrix
425+
426+ Returns
427+ -------
428+ `astropy.units.quantity.Quantity`, unit = u.rad
365429 """
366- Return WCS flip signs
430+ scale , angle , flip = get_scale_angle_flip (wcs )
431+ return angle
432+
433+
434+ def get_flip (wcs ):
367435 """
368- model_affine = get_affine (wcs )
369- c = get_scale (wcs )
370- c = c .reshape ([c .shape [- 1 ], 1 ])
371- r = model_affine [:2 , :2 ] / c # removing the absolute scaling factors from the pc
436+ Return WCS sign convention
372437
373- phi = jnp .arcsin (r [0 , 1 ]) if r [0 , 0 ] == 0.0 else jnp .arctan (r [0 , 1 ] / r [0 , 0 ])
438+ A negative sign means that the rotation is improper and requires a flip.
439+ By convention, we define this to be a flip in the y-axis.
374440
375- r_inv = jnp .array ([[jnp .cos (phi ), - jnp .sin (phi )], [jnp .sin (phi ), jnp .cos (phi )]])
441+ Parameters
442+ ----------
443+ wcs: `astropy.wcs.WCS`
444+ WCS structure or transformation matrix
376445
377- r = r_inv @ r
378- return jnp .round (jnp .diag (r ))
446+ Returns
447+ -------
448+ -1 or 1
449+ """
450+ scale , angle , flip = get_scale_angle_flip (wcs )
451+ return flip
0 commit comments