Skip to content

Commit 921b810

Browse files
committed
- Introduced epoch_group to replace ephemeris_group for better clarity.
- Added `mask_nonfinite_errors` flag to handle masking invalid errors in data. - Enhanced transit parameters by separating orbital and epoch-related data. - Updated outlier removal logic to a dedicated `mask_outliers` method. - Improved initialization and handling of transit center parameters. - Deprecated `remove_outliers` in favor of `mask_outliers`. - Refactored `nanmean` and other NumPy operations for improved efficiency.
1 parent 637d78b commit 921b810

File tree

3 files changed

+154
-49
lines changed

3 files changed

+154
-49
lines changed

exoiris/tsdata.py

Lines changed: 73 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,9 @@
2929
from matplotlib.figure import Figure
3030
from matplotlib.pyplot import subplots, setp
3131
from matplotlib.ticker import LinearLocator, FuncFormatter
32-
from numpy import isfinite, median, where, all, zeros_like, diff, asarray, interp, arange, floor, ndarray, \
33-
ceil, newaxis, inf, array, ones, poly1d, polyfit, nanpercentile, atleast_2d, nan, linspace, any, sqrt, nanmedian
32+
from numpy import any, isfinite, median, where, all, zeros_like, diff, asarray, interp, arange, floor, ndarray, \
33+
ceil, newaxis, inf, array, ones, poly1d, polyfit, nanpercentile, atleast_2d, nan, linspace, any, sqrt, nanmedian, \
34+
nanmean
3435
from pytransit.orbits import fold
3536
from scipy.ndimage import median_filter
3637
from scipy.signal import medfilt
@@ -45,9 +46,10 @@ class TSData:
4546
fluxes, and errors. It provides methods for manipulating and analyzing the data.
4647
"""
4748
def __init__(self, time: Sequence, wavelength: Sequence, fluxes: Sequence, errors: Sequence, name: str,
48-
noise_group: str = 'a', wl_edges : Sequence | None = None, tm_edges : Sequence | None = None,
49+
noise_group: int = 0, wl_edges : Sequence | None = None, tm_edges : Sequence | None = None,
4950
transit_mask: ndarray | None = None, ephemeris: Ephemeris | None = None, n_baseline: int = 1,
50-
mask: ndarray = None, ephemeris_group: int = 0, offset_group: int = 0) -> None:
51+
mask: ndarray = None, epoch_group: int = 0, offset_group: int = 0,
52+
mask_nonfinite_errors: bool = True) -> None:
5153
"""
5254
Parameters
5355
----------
@@ -81,8 +83,11 @@ def __init__(self, time: Sequence, wavelength: Sequence, fluxes: Sequence, error
8183
if n_baseline < 1:
8284
raise ValueError("n_baseline must be greater than zero.")
8385

84-
if ephemeris_group < 0:
85-
raise ValueError("ephemeris_group must be a non-negative integer.")
86+
if noise_group < 0:
87+
raise ValueError("noise_group must be a positive integer.")
88+
89+
if epoch_group < 0:
90+
raise ValueError("epoch_group must be a non-negative integer.")
8691

8792
if offset_group < 0:
8893
raise ValueError("offset_group must be a non-negative integer.")
@@ -94,17 +99,20 @@ def __init__(self, time: Sequence, wavelength: Sequence, fluxes: Sequence, error
9499
raise ValueError("The wavelength array must contain only finite values.")
95100

96101
self.name: str = name
102+
self.mask_nonfinite_errors: bool = mask_nonfinite_errors
97103
self.time: ndarray = time.copy()
98104
self.wavelength: ndarray = wavelength
99-
self.mask: ndarray = mask if mask is not None else isfinite(fluxes) & isfinite(errors)
105+
self.mask: ndarray = mask if mask is not None else isfinite(fluxes)
106+
if self.mask_nonfinite_errors:
107+
self.mask &= isfinite(errors)
100108
self.fluxes: ndarray = where(self.mask, fluxes, nan)
101109
self.errors: ndarray = where(self.mask, errors, nan)
102110
self.transit_mask: ndarray = transit_mask if transit_mask is not None else ones(time.size, dtype=bool)
103111
self.ngid: int = 0
104-
self.ephemeris: Ephemeris | None = ephemeris
112+
self._ephemeris: Ephemeris | None = ephemeris
105113
self.n_baseline: int = n_baseline
106114
self._noise_group: str = noise_group
107-
self.ephemeris_group: int = ephemeris_group
115+
self.epoch_group: int = epoch_group
108116
self.offset_group: int = offset_group
109117
self._dataset: Optional['TSDataSet'] = None
110118
self._update()
@@ -143,7 +151,7 @@ def export_fits(self) -> pf.HDUList:
143151
mask = pf.ImageHDU(self.mask.astype(int), name=f'mask_{self.name}')
144152
data.header['ngroup'] = self.noise_group
145153
data.header['nbasel'] = self.n_baseline
146-
data.header['epgroup'] = self.ephemeris_group
154+
data.header['epgroup'] = self.epoch_group
147155
data.header['offgroup'] = self.offset_group
148156
#TODO: export ephemeris
149157
return pf.HDUList([time, wave, data, ootm, mask])
@@ -191,7 +199,7 @@ def import_fits(name: str, hdul: pf.HDUList) -> 'TSData':
191199

192200
#TODO: import ephemeris
193201
return TSData(time, wave, data[0], data[1], name=name, noise_group=noise_group, transit_mask=ootm,
194-
n_baseline=n_baseline, mask=mask, ephemeris_group=ephemeris_group, offset_group=offset_group)
202+
n_baseline=n_baseline, mask=mask, epoch_group=ephemeris_group, offset_group=offset_group)
195203

196204
def __repr__(self) -> str:
197205
return f"TSData Name:'{self.name}' [{self.wavelength[0]:.2f} - {self.wavelength[-1]:.2f}] nwl={self.nwl} npt={self.npt}"
@@ -207,6 +215,16 @@ def noise_group(self, ng: str) -> None:
207215
if self._dataset is not None:
208216
self._dataset._update_nids()
209217

218+
@property
219+
def ephemeris(self) -> Ephemeris:
220+
"""Ephemeris."""
221+
return self._ephemeris
222+
223+
@ephemeris.setter
224+
def ephemeris(self, ep: Ephemeris) -> None:
225+
self._ephemeris = ep
226+
self.mask_transit(ephemeris=ep)
227+
210228
def mask_transit(self, t0: float | None = None, p: float | None = None, t14: float | None = None,
211229
ephemeris : Ephemeris | None = None, elims: tuple[int, int] | None = None) -> 'TSData':
212230
"""Create a transit mask based on a given ephemeris or exposure index limits.
@@ -226,9 +244,9 @@ def mask_transit(self, t0: float | None = None, p: float | None = None, t14: flo
226244
"""
227245
if (t0 and p and t14) or ephemeris is not None:
228246
if ephemeris is not None:
229-
self.ephemeris = ephemeris
247+
self._ephemeris = ephemeris
230248
else:
231-
self.ephemeris = Ephemeris(t0, p, t14)
249+
self._ephemeris = Ephemeris(t0, p, t14)
232250
phase = fold(self.time, self.ephemeris.period, self.ephemeris.zero_epoch)
233251
self.transit_mask = abs(phase) > 0.502 * self.ephemeris.duration
234252
elif elims is not None:
@@ -257,6 +275,15 @@ def _update(self) -> None:
257275
self.nwl = self.wavelength.size
258276
self.npt = self.time.size
259277
self.wllims = self.wavelength.min(), self.wavelength.max()
278+
if self._ephemeris is not None:
279+
self.mask_transit(ephemeris=self._ephemeris)
280+
281+
def _update_data_mask(self) -> None:
282+
self.mask = isfinite(self.fluxes)
283+
if self.mask_nonfinite_errors:
284+
self.mask &= isfinite(self.errors)
285+
self.fluxes = where(self.mask, self.fluxes, nan)
286+
self.errors = where(self.mask, self.errors, nan)
260287

261288
def normalize_to_poly(self, deg: int = 1) -> 'TSData':
262289
"""Normalize the baseline flux for each spectroscopic light curve.
@@ -289,6 +316,7 @@ def normalize_to_poly(self, deg: int = 1) -> 'TSData':
289316
deg=deg))(self.time)
290317
self.fluxes[ipb, :] /= bl
291318
self.errors[ipb, :] /= bl
319+
self._update_data_mask()
292320
return self
293321

294322
def normalize_to_median(self, s: slice) -> 'TSData':
@@ -317,20 +345,22 @@ def partition_time(self, tlims: tuple[tuple[float,float]]) -> 'TSDataSet':
317345
d = TSData(name=f'{self.name}_1', time=self.time[m], wavelength=self.wavelength,
318346
fluxes=self.fluxes[:, m], errors=self.errors[:, m], mask=self.mask[:, m],
319347
noise_group=self.noise_group,
320-
ephemeris_group=self.ephemeris_group,
348+
epoch_group=self.epoch_group,
321349
offset_group=self.offset_group,
322350
transit_mask=self.transit_mask[m],
323351
ephemeris=self.ephemeris,
324-
n_baseline=self.n_baseline)
352+
n_baseline=self.n_baseline,
353+
mask_nonfinite_errors=self.mask_nonfinite_errors)
325354
for i, m in enumerate(masks[1:]):
326355
d = d + TSData(name=f'{self.name}_{i+2}', time=self.time[m], wavelength=self.wavelength,
327356
fluxes=self.fluxes[:, m], errors=self.errors[:, m], mask=self.mask[:, m],
328357
noise_group=self.noise_group,
329-
ephemeris_group=self.ephemeris_group,
358+
epoch_group=self.epoch_group,
330359
offset_group=self.offset_group,
331360
transit_mask=self.transit_mask[m],
332361
ephemeris=self.ephemeris,
333-
n_baseline=self.n_baseline)
362+
n_baseline=self.n_baseline,
363+
mask_nonfinite_errors=self.mask_nonfinite_errors)
334364
return d
335365

336366
def crop_wavelength(self, lmin: float, lmax: float, inplace: bool = True) -> 'TSData':
@@ -362,12 +392,13 @@ def crop_wavelength(self, lmin: float, lmax: float, inplace: bool = True) -> 'TS
362392
errors=self.errors[m],
363393
mask=self.mask[m],
364394
noise_group=self.noise_group,
365-
ephemeris_group=self.ephemeris_group,
395+
epoch_group=self.epoch_group,
366396
offset_group=self.offset_group,
367397
wl_edges=(self._wl_l_edges[m], self._wl_r_edges[m]),
368398
tm_edges=(self._tm_l_edges, self._tm_r_edges),
369399
transit_mask=self.transit_mask, ephemeris=self.ephemeris,
370-
n_baseline=self.n_baseline)
400+
n_baseline=self.n_baseline,
401+
mask_nonfinite_errors=self.mask_nonfinite_errors)
371402

372403
def crop_time(self, tmin: float, tmax: float, inplace: bool = True) -> 'TSData':
373404
"""Crop the data to include only the time range between lmin and lmax.
@@ -399,19 +430,20 @@ def crop_time(self, tmin: float, tmax: float, inplace: bool = True) -> 'TSData':
399430
errors=self.errors[:, m],
400431
mask = self.mask[:, m],
401432
noise_group=self.noise_group,
402-
ephemeris_group=self.ephemeris_group,
433+
epoch_group=self.epoch_group,
403434
offset_group=self.offset_group,
404435
wl_edges=(self._wl_l_edges, self._wl_r_edges),
405436
tm_edges=(self._tm_l_edges[m], self._tm_r_edges[m]),
406437
transit_mask=self.transit_mask[m], ephemeris=self.ephemeris,
407-
n_baseline=self.n_baseline)
438+
n_baseline=self.n_baseline,
439+
mask_nonfinite_errors=self.mask_nonfinite_errors)
408440

409-
def remove_outliers(self, sigma: float = 5.0) -> 'TSData':
410-
"""Remove outliers along the wavelength axis.
441+
# TODO: separate mask into bad data mask and outlier mask.
442+
def mask_outliers(self, sigma: float = 5.0) -> 'TSData':
443+
"""Mask outliers along the wavelength axis.
411444
412-
Replace outliers along the wavelength axis with the value of a 5-point running median filter. Outliers are
413-
defined as data points that deviate from the median by more than sigma times the median absolute deviation
414-
along the wavelength axis.
445+
Outliers are defined as data points that deviate from the running 5-point median by more
446+
than sigma times the median absolute deviation along the wavelength axis.
415447
416448
Parameters
417449
----------
@@ -422,13 +454,18 @@ def remove_outliers(self, sigma: float = 5.0) -> 'TSData':
422454
----
423455
The data will be modified in place.
424456
"""
425-
fm = median(self.fluxes, axis=0)
426-
fe = mad_std(self.fluxes, axis=0)
457+
fm = nanmedian(self.fluxes, axis=0)
458+
fe = mad_std(self.fluxes, axis=0, ignore_nan=True)
427459
self.mask &= abs(self.fluxes - fm) / fe < sigma
428460
self.fluxes = where(self.mask, self.fluxes, nan)
429461
self.errors = where(self.mask, self.errors, nan)
430462
return self
431463

464+
@deprecated("0.10", alternative="TSData.mask_outliers")
465+
def remove_outliers(self, sigma: float = 5.0) -> 'TSData':
466+
"""Remove outliers along the wavelength axis."""
467+
self.mask_outliers(sigma=sigma)
468+
432469
def plot(self, ax=None, vmin: float = None, vmax: float = None, cmap=None, figsize=None, data=None,
433470
plims: tuple[float, float] | None = None) -> Figure:
434471
"""Plot the spectroscopic light curves as a 2D image.
@@ -528,7 +565,7 @@ def plot_white(self, ax: Axes | None = None, figsize: tuple[float, float] | None
528565
fig = ax.figure
529566
tref = floor(self.time.min())
530567

531-
ax.plot(self.time, self.fluxes.mean(0))
568+
ax.plot(self.time, nanmean(self.fluxes, 0))
532569
if self.ephemeris is not None:
533570
[ax.axvline(tl, ls='--', c='k') for tl in self.ephemeris.transit_limits(self.time.mean())]
534571

@@ -620,7 +657,7 @@ def bin_wavelength(self, binning: Optional[Union[Binning, CompoundBinning]] = No
620657
name=self.name,
621658
tm_edges=(self._tm_l_edges, self._tm_r_edges),
622659
noise_group=self.noise_group,
623-
ephemeris_group=self.ephemeris_group,
660+
epoch_group=self.epoch_group,
624661
offset_group=self.offset_group,
625662
transit_mask=self.transit_mask,
626663
ephemeris=self.ephemeris,
@@ -662,7 +699,7 @@ def bin_time(self, binning: Optional[Union[Binning, CompoundBinning]] = None,
662699
noise_group=self.noise_group,
663700
ephemeris=self.ephemeris,
664701
n_baseline=self.n_baseline,
665-
ephemeris_group=self.ephemeris_group,
702+
epoch_group=self.epoch_group,
666703
offset_group=self.offset_group)
667704
if self.ephemeris is not None:
668705
d.mask_transit(ephemeris=self.ephemeris)
@@ -739,6 +776,11 @@ def offset_groups(self) -> list[int]:
739776
"""List of offset groups."""
740777
return [d.offset_group for d in self.data]
741778

779+
@property
780+
def epoch_groups(self) -> list[int]:
781+
"""List of epoch groups."""
782+
return [d.epoch_group for d in self.data]
783+
742784
@property
743785
def n_baselines(self) -> list[int]:
744786
"""Number of baseline coefficients for each data set."""

exoiris/tslpf.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ def _init_parameters(self) -> None:
203203
self.ps = ParameterSet([])
204204
self._init_p_star()
205205
self._init_p_orbit()
206+
self._init_p_transit_centers()
206207
self._init_p_limb_darkening()
207208
self._init_p_radius_ratios()
208209
self._init_p_noise()
@@ -311,15 +312,22 @@ def _init_p_limb_darkening(self) -> None:
311312

312313
def _init_p_orbit(self):
313314
ps = self.ps
314-
pp = [GParameter('tc', 'zero_epoch', '', NP(0.0, 0.1), (-inf, inf)),
315-
GParameter('p', 'period', 'd', NP(1.0, 1e-5), (0, inf)),
315+
pp = [GParameter('p', 'period', 'd', NP(1.0, 1e-5), (0, inf)),
316316
GParameter('b', 'impact_parameter', 'R_s', UP(0.0, 1.0), (0, inf)),
317317
GParameter('secw', 'sqrt(e) cos(w)', '', NP(0.0, 1e-5), (-1, 1)),
318318
GParameter('sesw', 'sqrt(e) sin(w)', '', NP(0.0, 1e-5), (-1, 1))]
319319
ps.add_global_block('orbit', pp)
320320
self._start_orbit = ps.blocks[-1].start
321321
self._sl_orbit = ps.blocks[-1].slice
322322

323+
def _init_p_transit_centers(self):
324+
ps = self.ps
325+
neps = max(self.data.epoch_groups) + 1
326+
pp = [GParameter(f'tc_{i:02d}', f'zero_epoch_{i:02d}', '', NP(0.0, 0.1), (-inf, inf)) for i in range(neps)]
327+
ps.add_global_block('transit_centers', pp)
328+
self._start_tcs = ps.blocks[-1].start
329+
self._sl_tcs = ps.blocks[-1].slice
330+
323331
def _init_p_radius_ratios(self):
324332
ps = self.ps
325333
pp = [GParameter(f'k_{k:08.5f}', fr'radius ratio at {k:08.5f} $\mu$m', 'A_s', UP(0.02, 0.2), (0, inf)) for k in self.k_knots]
@@ -592,22 +600,24 @@ def transit_model(self, pv, copy=True):
592600
"""
593601
pv = atleast_2d(pv)
594602
ldp = self._eval_ldc(pv)
595-
t0 = pv[:, 1]
596-
p = pv[:, 2]
603+
t0s = pv[:, self._sl_tcs]
597604
k = self._eval_k(pv[:, self._sl_rratios])
605+
p = pv[:, 1]
598606
aor = as_from_rhop(pv[:, 0], p)
599-
inc = i_from_ba(pv[:, 3], aor)
600-
ecc = pv[:, 4] ** 2 + pv[:, 5] ** 2
601-
w = arctan2(pv[:, 5], pv[:, 4])
607+
inc = i_from_ba(pv[:, 2], aor)
608+
ecc = pv[:, 3] ** 2 + pv[:, 4] ** 2
609+
w = arctan2(pv[:, 4], pv[:, 3])
610+
epids = self.data.epoch_groups
602611
fluxes = []
603612
if isinstance(self.ldmodel, LDTkLD):
604613
ldp, istar = self.ldmodel(self.tms[0].mu, ldp)
605614
ldpi = dstack([ldp, istar])
606615
for i, tm in enumerate(self.tms):
607-
fluxes.append(tm.evaluate(k[i], ldpi[:, self.ldmodel.wlslices[i], :], t0, p, aor, inc, ecc, w, copy))
616+
fluxes.append(tm.evaluate(k[i], ldpi[:, self.ldmodel.wlslices[i], :],
617+
t0s[:, epids[i]], p, aor, inc, ecc, w, copy))
608618
else:
609619
for i, tm in enumerate(self.tms):
610-
fluxes.append(tm.evaluate(k[i], ldp[i], t0, p, aor, inc, ecc, w, copy))
620+
fluxes.append(tm.evaluate(k[i], ldp[i], t0s[:, epids[i]], p, aor, inc, ecc, w, copy))
611621

612622
for i, d in enumerate(self.data):
613623
if d.offset_group > 0:

0 commit comments

Comments
 (0)