|
32 | 32 | from celerite2 import GaussianProcess, terms
|
33 | 33 | from emcee import EnsembleSampler
|
34 | 34 | from matplotlib.pyplot import subplots, setp, figure, Figure, Axes
|
35 |
| -from numpy import (where, sqrt, clip, percentile, median, squeeze, floor, ndarray, |
| 35 | +from numpy import (any, where, sqrt, clip, percentile, median, squeeze, floor, ndarray, isfinite, |
36 | 36 | array, inf, newaxis, arange, tile, sort, argsort, concatenate, full, nan, r_, nanpercentile, log10,
|
37 |
| - ceil) |
| 37 | + ceil, unique) |
38 | 38 | from numpy.random import normal, permutation
|
39 | 39 | from pytransit import UniformPrior, NormalPrior
|
40 | 40 | from pytransit.orbits import epoch
|
@@ -144,6 +144,26 @@ def __init__(self, name: str, ldmodel, data: TSDataSet | TSData, nk: int = 50, n
|
144 | 144 | The noise model to use. Should be either "white" for white noise or "fixed_gp" for Gaussian Process.
|
145 | 145 | """
|
146 | 146 | data = TSDataSet([data]) if isinstance(data, TSData) else data
|
| 147 | + |
| 148 | + for d in data: |
| 149 | + if any(~isfinite(d.fluxes[d.mask])): |
| 150 | + raise ValueError(f"The {d.name} data set flux array contains unmasked noninfinite values.") |
| 151 | + |
| 152 | + if any(~isfinite(d.errors[d.mask])): |
| 153 | + raise ValueError(f"The {d.name} data set error array contains unmasked noninfinite values.") |
| 154 | + |
| 155 | + ngs = array(data.noise_groups) |
| 156 | + if not ((ngs.min() == 0) and (ngs.max() + 1 == unique(ngs).size)): |
| 157 | + raise ValueError("The noise groups must start from 0 and be consecutive.") |
| 158 | + |
| 159 | + ogs = array(data.offset_groups) |
| 160 | + if not ((ogs.min() == 0) and (ogs.max() + 1 == unique(ogs).size)): |
| 161 | + raise ValueError("The offset groups must start from 0 and be consecutive.") |
| 162 | + |
| 163 | + egs = array(data.epoch_groups) |
| 164 | + if not ((egs.min() == 0) and (egs.max() + 1 == unique(egs).size)): |
| 165 | + raise ValueError("The epoch groups must start from 0 and be consecutive.") |
| 166 | + |
147 | 167 | self._tsa: TSLPF = TSLPF(name, ldmodel, data, nk=nk, nldc=nldc, nthreads=nthreads, tmpars=tmpars,
|
148 | 168 | noise_model=noise_model, interpolation=interpolation)
|
149 | 169 | self._wa: WhiteLPF | None = None
|
@@ -542,10 +562,10 @@ def fit_white(self, niter: int = 500) -> None:
|
542 | 562 | self._wa.optimize_global(niter, plot_convergence=False, use_tqdm=False)
|
543 | 563 | self._wa.optimize()
|
544 | 564 | pv = self._wa._local_minimization.x
|
545 |
| - self.period = pv[1] |
| 565 | + self.period = pv[0] |
546 | 566 | self.zero_epoch = self._wa.transit_center
|
547 | 567 | self.transit_duration = self._wa.transit_duration
|
548 |
| - self.data.mask_transit(pv[0], pv[1], self.transit_duration) |
| 568 | + self.data.mask_transit(self.zero_epoch, self.period, self.transit_duration) |
549 | 569 |
|
550 | 570 | def plot_white(self, axs=None, figsize: tuple[float, float] | None = None, ncols: int | None=None) -> Figure:
|
551 | 571 | """Plot the white light curve data with the best-fit model.
|
@@ -680,12 +700,18 @@ def fit(self, niter: int = 200, npop: Optional[int] = None, pool: Optional[Pool]
|
680 | 700 |
|
681 | 701 | pv0 = self._wa._local_minimization.x
|
682 | 702 | x0 = self._tsa.ps.sample_from_prior(npop)
|
683 |
| - x0[:, 0] = normal(pv0[2], 0.05, size=npop) |
684 |
| - x0[:, 1] = normal(pv0[0], 1e-4, size=npop) |
685 |
| - x0[:, 2] = normal(pv0[1], 1e-5, size=npop) |
686 |
| - x0[:, 3] = clip(normal(pv0[3], 0.01, size=npop), 0.0, 1.0) |
| 703 | + x0[:, 0] = clip(normal(pv0[1], 0.05, size=npop), 0.01, inf) |
| 704 | + x0[:, 1] = clip(normal(pv0[0], 1e-4, size=npop), 0.01, inf) |
| 705 | + x0[:, 2] = clip(normal(pv0[2], 1e-3, size=npop), 0.0, 1.0) |
| 706 | + |
| 707 | + nep = max(self.data.epoch_groups) + 1 |
| 708 | + for i in range(nep): |
| 709 | + pida = self.ps.find_pid(f'tc_{i:02d}') |
| 710 | + pidb = self._wa.ps.find_pid(f'tc_{i:02d}') |
| 711 | + x0[:, pida] = normal(pv0[pidb], 0.001, size=npop) |
| 712 | + |
687 | 713 | sl = self._tsa._sl_rratios
|
688 |
| - x0[:, sl] = normal(sqrt(pv0[4]), 0.001, size=(npop, self.nk)) |
| 714 | + x0[:, sl] = normal(sqrt(pv0[self._wa.ps.find_pid('k2')]), 0.001, size=(npop, self.nk)) |
689 | 715 | for i in range(sl.start, sl.stop):
|
690 | 716 | x0[:, i] = clip(x0[:, i], 1.001*self.ps[i].prior.a, 0.999*self.ps[i].prior.b)
|
691 | 717 |
|
|
0 commit comments