Skip to content

Commit 637d78b

Browse files
committed
- Added checks for finite values in fluxes and errors within dataset masks.
- Enforced consecutive numbering for noise, offset, and epoch groups. - Corrected transit parameter assignments (`period`, `zero_epoch`) and masking logic in `optimize_transit`. - Adjusted sampling of transit center and radius ratios to ensure valid ranges. - Incorporated proper handling of epoch group parameters for multi-epoch data.
1 parent f6d8872 commit 637d78b

File tree

1 file changed

+35
-9
lines changed

1 file changed

+35
-9
lines changed

exoiris/exoiris.py

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@
3232
from celerite2 import GaussianProcess, terms
3333
from emcee import EnsembleSampler
3434
from matplotlib.pyplot import subplots, setp, figure, Figure, Axes
35-
from numpy import (where, sqrt, clip, percentile, median, squeeze, floor, ndarray,
35+
from numpy import (any, where, sqrt, clip, percentile, median, squeeze, floor, ndarray, isfinite,
3636
array, inf, newaxis, arange, tile, sort, argsort, concatenate, full, nan, r_, nanpercentile, log10,
37-
ceil)
37+
ceil, unique)
3838
from numpy.random import normal, permutation
3939
from pytransit import UniformPrior, NormalPrior
4040
from pytransit.orbits import epoch
@@ -144,6 +144,26 @@ def __init__(self, name: str, ldmodel, data: TSDataSet | TSData, nk: int = 50, n
144144
The noise model to use. Should be either "white" for white noise or "fixed_gp" for Gaussian Process.
145145
"""
146146
data = TSDataSet([data]) if isinstance(data, TSData) else data
147+
148+
for d in data:
149+
if any(~isfinite(d.fluxes[d.mask])):
150+
raise ValueError(f"The {d.name} data set flux array contains unmasked noninfinite values.")
151+
152+
if any(~isfinite(d.errors[d.mask])):
153+
raise ValueError(f"The {d.name} data set error array contains unmasked noninfinite values.")
154+
155+
ngs = array(data.noise_groups)
156+
if not ((ngs.min() == 0) and (ngs.max() + 1 == unique(ngs).size)):
157+
raise ValueError("The noise groups must start from 0 and be consecutive.")
158+
159+
ogs = array(data.offset_groups)
160+
if not ((ogs.min() == 0) and (ogs.max() + 1 == unique(ogs).size)):
161+
raise ValueError("The offset groups must start from 0 and be consecutive.")
162+
163+
egs = array(data.epoch_groups)
164+
if not ((egs.min() == 0) and (egs.max() + 1 == unique(egs).size)):
165+
raise ValueError("The epoch groups must start from 0 and be consecutive.")
166+
147167
self._tsa: TSLPF = TSLPF(name, ldmodel, data, nk=nk, nldc=nldc, nthreads=nthreads, tmpars=tmpars,
148168
noise_model=noise_model, interpolation=interpolation)
149169
self._wa: WhiteLPF | None = None
@@ -542,10 +562,10 @@ def fit_white(self, niter: int = 500) -> None:
542562
self._wa.optimize_global(niter, plot_convergence=False, use_tqdm=False)
543563
self._wa.optimize()
544564
pv = self._wa._local_minimization.x
545-
self.period = pv[1]
565+
self.period = pv[0]
546566
self.zero_epoch = self._wa.transit_center
547567
self.transit_duration = self._wa.transit_duration
548-
self.data.mask_transit(pv[0], pv[1], self.transit_duration)
568+
self.data.mask_transit(self.zero_epoch, self.period, self.transit_duration)
549569

550570
def plot_white(self, axs=None, figsize: tuple[float, float] | None = None, ncols: int | None=None) -> Figure:
551571
"""Plot the white light curve data with the best-fit model.
@@ -680,12 +700,18 @@ def fit(self, niter: int = 200, npop: Optional[int] = None, pool: Optional[Pool]
680700

681701
pv0 = self._wa._local_minimization.x
682702
x0 = self._tsa.ps.sample_from_prior(npop)
683-
x0[:, 0] = normal(pv0[2], 0.05, size=npop)
684-
x0[:, 1] = normal(pv0[0], 1e-4, size=npop)
685-
x0[:, 2] = normal(pv0[1], 1e-5, size=npop)
686-
x0[:, 3] = clip(normal(pv0[3], 0.01, size=npop), 0.0, 1.0)
703+
x0[:, 0] = clip(normal(pv0[1], 0.05, size=npop), 0.01, inf)
704+
x0[:, 1] = clip(normal(pv0[0], 1e-4, size=npop), 0.01, inf)
705+
x0[:, 2] = clip(normal(pv0[2], 1e-3, size=npop), 0.0, 1.0)
706+
707+
nep = max(self.data.epoch_groups) + 1
708+
for i in range(nep):
709+
pida = self.ps.find_pid(f'tc_{i:02d}')
710+
pidb = self._wa.ps.find_pid(f'tc_{i:02d}')
711+
x0[:, pida] = normal(pv0[pidb], 0.001, size=npop)
712+
687713
sl = self._tsa._sl_rratios
688-
x0[:, sl] = normal(sqrt(pv0[4]), 0.001, size=(npop, self.nk))
714+
x0[:, sl] = normal(sqrt(pv0[self._wa.ps.find_pid('k2')]), 0.001, size=(npop, self.nk))
689715
for i in range(sl.start, sl.stop):
690716
x0[:, i] = clip(x0[:, i], 1.001*self.ps[i].prior.a, 0.999*self.ps[i].prior.b)
691717

0 commit comments

Comments
 (0)