Skip to content

Commit 6891e7e

Browse files
authored
Correct rotations even when improper (#225)
* correct treatment of WCS angle, flip, scale (fixes #218); fixed tests of renderer and moments * make Moments fluent * angle and scale get astropy units * fixed pixels size comparison * removed units from angle and scale because not supported by jax; one could use unxt, but with several new dependencies * moved set_spectra_to_match to measure.forced_photometry * removed set_spectra_to_match * corrected plotting centers and boxes for WCS * create dummy WCS if not set * minor fixes * added check for sub-pixel shift * added moment normalization
1 parent a7646a8 commit 6891e7e

File tree

15 files changed

+511
-527
lines changed

15 files changed

+511
-527
lines changed

docs/0-quickstart.ipynb

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,7 +340,6 @@
340340
"outputs": [],
341341
"source": [
342342
"maxiter = 1000\n",
343-
"scene.set_spectra_to_match(obs, parameters)\n",
344343
"scene_ = scene.fit(obs, parameters, max_iter=maxiter, e_rel=1e-4, progress_bar=True)"
345344
]
346345
},

docs/howto/multiresolution.ipynb

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -296,7 +296,7 @@
296296
"id": "81bb89a6412b8ce8",
297297
"metadata": {},
298298
"source": [
299-
" But the initial linear solver for the spectrum amplitudes and the fitting method receive lists of observations now:"
299+
" The fitting method now receives a list of observations:"
300300
]
301301
},
302302
{
@@ -308,7 +308,6 @@
308308
},
309309
"outputs": [],
310310
"source": [
311-
"scene.set_spectra_to_match([obs_hsc, obs_hst], parameters)\n",
312311
"scene_ = scene.fit([obs_hsc, obs_hst], parameters, max_iter=50, progress_bar=True)"
313312
]
314313
},

docs/howto/priors.ipynb

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,6 @@
197197
"source": [
198198
"maxiter = 1000\n",
199199
"print(\"Initial likelihood:\", obs.log_likelihood(scene()))\n",
200-
"scene.set_spectra_to_match(obs, parameters)\n",
201200
"scene_ = scene.fit(obs, parameters, max_iter=maxiter, e_rel=1e-4, progress_bar=True)\n",
202201
"print(\"Optimized likelihood:\", obs.log_likelihood(scene_()))"
203202
]

src/scarlet2/frame.py

Lines changed: 157 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import astropy.wcs
33
import equinox as eqx
44
import jax.numpy as jnp
5+
import numpy as np
56
from astropy.coordinates import SkyCoord
67

78
from .bbox import Box
@@ -27,10 +28,14 @@ class Frame(eqx.Module):
2728
def __init__(self, bbox, psf=None, wcs=None, channels=None):
2829
self.bbox = bbox
2930
self.psf = psf
31+
if wcs is None:
32+
wcs = astropy.wcs.WCS(naxis=2) # Dummy WCS
33+
wcs._naxis = bbox.spatial.shape[::-1]
34+
wcs.wcs.ctype = ["RA---TAN", "DEC--TAN"]
3035
self.wcs = wcs
36+
3137
if channels is None:
3238
channels = list(range(bbox.shape[0]))
33-
3439
self.channels = channels
3540

3641
def __hash__(self):
@@ -47,14 +52,10 @@ def pixel_size(self):
4752
4853
Returns
4954
-------
50-
float
51-
Pixel size in arcsec, averaged over x and y direction
55+
float, astropy.units.quantity.Quantity
56+
Pixel size in units of the WCS sky coordinates
5257
"""
53-
if self.wcs is not None:
54-
# return get_pixel_size(get_affine(self.wcs)) * 60 * 60 # in arcsec
55-
return get_scale(self.wcs).mean() * 60 * 60 # in arcsec
56-
else:
57-
return 1
58+
return get_scale(self.wcs)
5859

5960
def get_pixel(self, pos):
6061
"""Get the sky coordinates from a world coordinate
@@ -69,7 +70,6 @@ def get_pixel(self, pos):
6970
pixel coordinates in the model frame
7071
"""
7172
if isinstance(pos, SkyCoord):
72-
assert self.wcs is not None, "SkyCoord can only be converted with valid WCS"
7373
wcs_ = self.wcs.celestial # only use celestial portion
7474
pixel = jnp.asarray(pos.to_pixel(wcs_), dtype="float32").T
7575
return pixel[..., ::-1]
@@ -85,14 +85,12 @@ def get_sky_coord(self, pos):
8585
8686
Returns
8787
----------
88-
astropy.coordinates.SkyCoord if WCS is set, otherwise pos
88+
astropy.coordinates.SkyCoord
8989
"""
90-
if self.wcs is not None:
91-
pixels = pos.reshape(-1, 2)
92-
wcs = self.wcs.celestial # only use celestial portion
93-
sky_coord = SkyCoord.from_pixel(pixels[:, 1], pixels[:, 0], wcs)
94-
return sky_coord
95-
return pos
90+
pixels = pos.reshape(-1, 2)
91+
wcs = self.wcs.celestial # only use celestial portion
92+
sky_coord = SkyCoord.from_pixel(pixels[:, 1], pixels[:, 0], wcs)
93+
return sky_coord
9694

9795
def convert_pixel_to(self, target, pixel=None):
9896
"""Converts pixel coordinates from this frame to `target` frame
@@ -130,12 +128,11 @@ def u_to_pixel(self, distance):
130128
float
131129
size in pixels
132130
"""
133-
assert u.get_physical_type(distance) == "angle"
134-
135-
# first computer the pixel size
136-
pixel_size = get_pixel_size(self.wcs.celestial) * 60 * 60 # in arcsec/pixel
137-
138-
return distance.to(u.arcsec).value / pixel_size
131+
if self.wcs is not None:
132+
pixel_size = get_pixel_size(self.wcs)
133+
return distance / pixel_size
134+
else:
135+
return distance
139136

140137
def pixel_to_angle(self, size):
141138
"""Converts pixel size to celestial distance according to this frame WCS
@@ -148,12 +145,10 @@ def pixel_to_angle(self, size):
148145
Returns
149146
-------
150147
distance: :py:class:`astropy.units.Quantity`
151-
Physical size, must be `PhysicalType("angle")`
152148
"""
153-
# first computer the pixel size
154-
pixel_size = get_pixel_size(self.wcs.celestial) * 60 * 60 # in arcsec/pixel
155149

156-
distance = size * pixel_size * u.arcsec
150+
pixel_size = get_pixel_size(self.wcs)
151+
distance = size * pixel_size
157152
return distance
158153

159154
@staticmethod
@@ -189,7 +184,6 @@ def from_observations(observations, model_psf=None, model_wcs=None, obs_id=None,
189184
# Array of pixel sizes for each observation
190185
pix_tab = []
191186
# Array of psf size for each psf of each observation
192-
fat_psf_size = None
193187
small_psf_size = None
194188
channels = []
195189
# Create frame channels and find smallest and largest psf
@@ -199,14 +193,14 @@ def from_observations(observations, model_psf=None, model_wcs=None, obs_id=None,
199193

200194
# concatenate all pixel sizes
201195
h_temp = get_pixel_size(obs.frame.wcs)
202-
196+
if isinstance(h_temp, u.Quantity):
197+
h_temp = h_temp.to(u.arcsec).value # standardize pixel sizes, using simple scalars below
203198
pix_tab.append(h_temp)
204-
# Looking for the sharpest and the fatest psf
199+
200+
# Looking for the sharpest PSF
205201
psf = obs.frame.psf.morphology
206202
for psf_channel in psf:
207203
psf_size = get_psf_size(psf_channel) * h_temp
208-
if (fat_psf_size is None) or (psf_size > fat_psf_size):
209-
fat_psf_size = psf_size
210204
if (
211205
model_psf is None
212206
and ((obs_id is None) or (c == obs_id))
@@ -225,23 +219,22 @@ def from_observations(observations, model_psf=None, model_wcs=None, obs_id=None,
225219

226220
# Reference wcs
227221
if model_wcs is None:
228-
model_wcs = obs_ref.frame.wcs
222+
model_wcs = obs_ref.frame.wcs.deepcopy()
229223

230-
# Scale of the smallest pixel
224+
# Scale of the model pixel
231225
h = get_pixel_size(model_wcs)
232226

233227
# If needed and psf is not provided: interpolate psf to smallest pixel
234228
if model_psf is None:
235229
# create Gaussian PSF with a sigma smaller than the smallest observed PSF
236230
sigma = 0.7
237-
assert (
238-
small_psf_size / h > sigma
239-
), f"Default model PSF width ({sigma} pixel) too large for best-seeing observation"
231+
assert small_psf_size / h > sigma, (
232+
f"Default model PSF width ({sigma} pixel) too large for best-seeing observation"
233+
)
240234
model_psf = GaussianPSF(sigma=sigma)
241235

242236
# Dummy frame for WCS computations
243237
model_shape = (len(channels), 0, 0)
244-
245238
model_frame = Frame(Box(model_shape), channels=channels, psf=model_psf, wcs=model_wcs)
246239

247240
# Determine overlap of all observations in pixel coordinates of the model frame
@@ -262,11 +255,14 @@ def from_observations(observations, model_psf=None, model_wcs=None, obs_id=None,
262255
else:
263256
model_box &= this_box
264257

258+
# update model_wcs to change NAXIS1/2 and CRPIX1/2, but don't change frame_origin!
259+
model_wcs._naxis = list(model_wcs._naxis)
260+
model_wcs._naxis[:2] = model_box.shape[::-1] # x/y needed here
261+
model_wcs.wcs.crpix[:2] -= model_box.origin[::-1] # x/y needed here
262+
263+
# frame_origin = (0,) + model_box.origin
265264
frame_shape = (len(channels),) + model_box.shape
266-
frame_origin = (0,) + model_box.origin
267-
model_frame = Frame(
268-
Box(shape=frame_shape, origin=frame_origin), channels=channels, psf=model_psf, wcs=model_wcs
269-
)
265+
model_frame = Frame(Box(shape=frame_shape), channels=channels, psf=model_psf, wcs=model_wcs)
270266

271267
# Match observations to this frame
272268
for obs in observations:
@@ -311,68 +307,145 @@ def get_psf_size(psf):
311307

312308
def get_affine(wcs):
313309
"""Return the WCS transformation matrix"""
310+
if wcs is None:
311+
return jnp.diag(jnp.ones(2))
312+
wcs_ = wcs.celestial
314313
try:
315-
model_affine = wcs.wcs.pc
314+
model_affine = wcs_.wcs.pc
316315
except AttributeError:
317316
try:
318-
model_affine = wcs.cd
317+
model_affine = wcs_.cd
319318
except AttributeError:
320-
model_affine = wcs.wcs.cd
321-
return model_affine
319+
model_affine = wcs_.wcs.cd
320+
return model_affine[:2, :2]
322321

323322

324-
def get_pixel_size(wcs):
325-
"""Extracts the pixel size from a wcs, and returns it in deg/pixel"""
326-
if wcs is None:
327-
return 1
328-
model_affine = get_affine(wcs)
329-
pix = jnp.sqrt(
330-
jnp.abs(model_affine[0, 0]) * jnp.abs(model_affine[1, 1] - model_affine[0, 1] * model_affine[1, 0])
331-
)
332-
return pix
323+
# for WCS linear matrix calculations:
324+
# rotation matrix for counter-clockwise rotation from positive x-axis
325+
# uses (x,y) coordinates and phi in radian!!
326+
def _rot_matrix(phi):
327+
sinphi, cosphi = jnp.sin(phi), jnp.cos(phi)
328+
return jnp.array([[cosphi, -sinphi], [sinphi, cosphi]])
329+
330+
331+
# flip in y!!!
332+
# uses (x,y) coordinates!
333+
_flip_matrix = lambda flip: jnp.diag(jnp.array((1.0, flip)))
333334

335+
# 2x2 matrix determinant
336+
_det = lambda m: m[0, 0] * m[1, 1] - m[0, 1] * m[1, 0]
337+
338+
339+
def get_scale_angle_flip(trans):
340+
"""Return, scale, angle, flip from the WCS transformation matrix
341+
342+
Parameters
343+
----------
344+
trans: (`astropy.wcs.WCS`, array)
345+
WCS or WCS transformation matrix
334346
335-
def get_scale(wcs):
347+
Returns
348+
-------
349+
scale: `float`
350+
angle: `float`, in radian
351+
flip: -1 or 1
336352
"""
337-
Return WCS axis scales in deg/pixel
353+
if isinstance(trans, (np.ndarray, jnp.ndarray)): # noqa: SIM108
354+
M = trans # noqa: N806
355+
else:
356+
M = get_affine(trans) # noqa: N806
357+
358+
det = _det(M)
359+
# this requires pixels to be square
360+
# if not, use scale = jnp.linalg.svd(M, compute_uv=False)
361+
# but be careful with rotations as anisotropic stretch and rotation do not commute
362+
scale = jnp.sqrt(jnp.abs(det)).item(0)
363+
364+
# if rotation is improper: need to apply y-flip to M to get pure rotation matrix (and unique angle)
365+
improper = det < 0
366+
flip = -1 if improper else 1
367+
F = _flip_matrix(flip) # noqa: N806, flip in y, is identity if flip = 1!!!
368+
M_ = F @ M # noqa: N806, flip = inverse flip
369+
angle = jnp.arctan2(M_[1, 0], M_[0, 0]).item()
370+
371+
return scale, angle, flip
372+
373+
374+
def get_pixel_size(wcs):
375+
"""Extracts the pixel size from a wcs, and returns it in deg/pixel
376+
377+
Parameters
378+
----------
379+
wcs: `astropy.wcs.WCS`
380+
WCS structure or transformation matrix
381+
382+
Returns
383+
-------
384+
pixel_size: `float`
338385
"""
339-
if wcs is None:
340-
return 1
341-
model_affine = get_affine(wcs)
342-
c1 = (model_affine[0, :2] ** 2).sum() ** 0.5
343-
c2 = (model_affine[1, :2] ** 2).sum() ** 0.5
344-
return jnp.array([c1, c2])
386+
scale, angle, flip = get_scale_angle_flip(wcs)
387+
return scale
345388

346389

347-
def get_angle(wcs):
390+
def get_scale(wcs, separate=False):
348391
"""
349-
Return WCS rotation angle in rad
392+
Get WCS axis scales in deg/pixel
393+
394+
Parameters
395+
----------
396+
wcs: `astropy.wcs.WCS`
397+
WCS structure or transformation matrix
398+
separate: `bool`
399+
Compute separate axis scales
400+
401+
Returns
402+
-------
403+
float
350404
"""
351-
if wcs is None:
352-
return 0
353-
model_affine = get_affine(wcs)
354-
c = get_scale(wcs)
355-
c = c.reshape([c.shape[-1], 1])
356-
r = model_affine[:2, :2] / c # removing the scaling factors from the pc
357-
358-
if r[0, 0] == 0.0:
359-
return jnp.arcsin(r[0, 1])
405+
if separate:
406+
M = get_affine(wcs) # noqa: N806
407+
c1 = (M[0, :] ** 2).sum() ** 0.5
408+
c2 = (M[1, :] ** 2).sum() ** 0.5
409+
return jnp.array([c1, c2])
360410
else:
361-
return jnp.arctan(r[0, 1] / r[0, 0])
411+
scale, angle, flip = get_scale_angle_flip(wcs)
412+
return scale
362413

363414

364-
def get_sign(wcs):
415+
def get_angle(wcs):
416+
"""
417+
Get WCS rotation angle
418+
419+
The angle is computed counter-clockwise from the positive x-axis, in radians.
420+
421+
Parameters
422+
----------
423+
wcs: `astropy.wcs.WCS`
424+
WCS structure or transformation matrix
425+
426+
Returns
427+
-------
428+
`astropy.units.quantity.Quantity`, unit = u.rad
365429
"""
366-
Return WCS flip signs
430+
scale, angle, flip = get_scale_angle_flip(wcs)
431+
return angle
432+
433+
434+
def get_flip(wcs):
367435
"""
368-
model_affine = get_affine(wcs)
369-
c = get_scale(wcs)
370-
c = c.reshape([c.shape[-1], 1])
371-
r = model_affine[:2, :2] / c # removing the absolute scaling factors from the pc
436+
Return WCS sign convention
372437
373-
phi = jnp.arcsin(r[0, 1]) if r[0, 0] == 0.0 else jnp.arctan(r[0, 1] / r[0, 0])
438+
A negative sign means that the rotation is improper and requires a flip.
439+
By convention, we define this to be a flip in the y-axis.
374440
375-
r_inv = jnp.array([[jnp.cos(phi), -jnp.sin(phi)], [jnp.sin(phi), jnp.cos(phi)]])
441+
Parameters
442+
----------
443+
wcs: `astropy.wcs.WCS`
444+
WCS structure or transformation matrix
376445
377-
r = r_inv @ r
378-
return jnp.round(jnp.diag(r))
446+
Returns
447+
-------
448+
-1 or 1
449+
"""
450+
scale, angle, flip = get_scale_angle_flip(wcs)
451+
return flip

0 commit comments

Comments
 (0)