Skip to content

Commit 7bbefa1

Browse files
committed
feat(exoiris): added support for handling white light curve models
- Implemented reading and storing white light curve models. - Adjusted white light curve properties to accommodate new functionality. - Enhanced FITS export to include white curve data and parameters if available.
1 parent 41add4e commit 7bbefa1

File tree

1 file changed

+49
-10
lines changed

1 file changed

+49
-10
lines changed

exoiris/exoiris.py

Lines changed: 49 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,16 @@ def load_model(fname: Path | str, name: str | None = None):
9292
a.set_radius_ratio_knots(hdul['K_KNOTS'].data.astype('d'))
9393
a.set_limb_darkening_knots(hdul['LD_KNOTS'].data.astype('d'))
9494

95+
# Read the white light curve models if they exist.
96+
try:
97+
tb = Table.read(hdul['WHITE_DATA'])
98+
white_ids = tb['id'].data
99+
model_flux = tb['mod_flux'].data
100+
uids = unique(white_ids)
101+
a._white_models = [model_flux[white_ids == i] for i in uids]
102+
except KeyError:
103+
pass
104+
95105
try:
96106
a.period = hdul[0].header['P']
97107
a.zero_epoch = hdul[0].header['T0']
@@ -161,15 +171,17 @@ def __init__(self, name: str, ldmodel, data: TSDataGroup | TSData, nk: int = 50,
161171
if not ((egs.min() == 0) and (egs.max() + 1 == unique(egs).size)):
162172
raise ValueError("The epoch groups must start from 0 and be consecutive.")
163173

164-
self._tsa: TSLPF = TSLPF(self, name, ldmodel, data, nk=nk, nldc=nldc, nthreads=nthreads, tmpars=tmpars,
174+
self._tsa = TSLPF(self, name, ldmodel, data, nk=nk, nldc=nldc, nthreads=nthreads, tmpars=tmpars,
165175
noise_model=noise_model, interpolation=interpolation)
166-
self._wa: WhiteLPF | None = None
176+
self._wa = WhiteLPF(self._tsa)
177+
167178
self.nthreads: int = nthreads
168179

169180
self.period: float | None = None
170181
self.zero_epoch: float | None = None
171182
self.transit_duration: float | None= None
172183
self._tref = floor(self.data.tmin)
184+
self._white_models: None | list[ndarray] = None
173185

174186
def lnposterior(self, pvp: ndarray) -> ndarray:
175187
"""Calculate the log posterior probability for a single parameter vector or an array of parameter vectors.
@@ -411,8 +423,11 @@ def white_fluxes(self) -> list[ndarray]:
411423
@property
412424
def white_models(self) -> list[ndarray]:
413425
"""Fitted white light curve flux model arrays."""
414-
fm = self._wa.flux_model(self._wa._local_minimization.x)
415-
return [fm[sl] for sl in self._wa.lcslices]
426+
if self._wa._local_minimization is not None:
427+
fm = self._wa.flux_model(self._wa._local_minimization.x)
428+
return [fm[sl] for sl in self._wa.lcslices]
429+
else:
430+
return self._white_models
416431

417432
@property
418433
def white_errors(self) -> list[ndarray]:
@@ -528,7 +543,6 @@ def fit_white(self, niter: int = 500) -> None:
528543
niter : int, optional
529544
The number of iterations for the global optimization algorithm (default is 500).
530545
"""
531-
self._wa = WhiteLPF(self._tsa)
532546
self._wa.optimize_global(niter, plot_convergence=False, use_tqdm=False)
533547
self._wa.optimize()
534548
pv = self._wa._local_minimization.x
@@ -1081,6 +1095,34 @@ def save(self, overwrite: bool = False) -> None:
10811095
hdul = pf.HDUList([pri, k_knots, ld_knots, pr])
10821096
hdul += self.data.export_fits()
10831097

1098+
if self._wa._local_minimization is not None:
1099+
wa_data = pf.BinTableHDU(
1100+
Table(
1101+
[
1102+
self._wa.lcids,
1103+
self._wa.timea,
1104+
concatenate(self.white_models),
1105+
self._wa.ofluxa,
1106+
concatenate(self._wa.std_errors),
1107+
],
1108+
names="id time mod_flux obs_flux obs_error".split(),
1109+
), name='white_data'
1110+
)
1111+
hdul.append(wa_data)
1112+
1113+
names = []
1114+
counts = {}
1115+
for p in self._wa.ps.names:
1116+
if p not in counts.keys():
1117+
counts[p] = 0
1118+
names.append(p)
1119+
else:
1120+
counts[p] += 1
1121+
names.append(f'{p}_{counts[p]}')
1122+
1123+
wa_params = pf.BinTableHDU(Table(self._wa._local_minimization.x, names=names), name='white_params')
1124+
hdul.append(wa_params)
1125+
10841126
if self._tsa.de is not None:
10851127
de = pf.BinTableHDU(Table(self._tsa._de_population, names=self.ps.names), name='DE')
10861128
de.header['npop'] = self._tsa.de.n_pop
@@ -1164,9 +1206,6 @@ def optimize_gp_hyperparameters(self,
11641206
if self._tsa.noise_model not in ('fixed_gp', 'free_gp'):
11651207
raise ValueError("The noise model must be set to 'fixed_gp' or 'free_gp' before the hyperparameter optimization.")
11661208

1167-
if self._wa is None:
1168-
raise ValueError("The white light curves must be fit using 'fit_white()' before the hyperparameter optimization.")
1169-
11701209
if log10_rho_prior is not None:
11711210
if isinstance(log10_rho_prior, Sequence):
11721211
rp = norm(*log10_rho_prior)
@@ -1192,15 +1231,15 @@ def optimize_gp_hyperparameters(self,
11921231

11931232
match log10_sigma_bounds:
11941233
case None:
1195-
sb = [log10_sigma_guess-1, log10_sigma_guess+1]
1234+
sb = [log10_sigma_guess - 1, log10_sigma_guess + 1]
11961235
case _ if isinstance(log10_sigma_bounds, Sequence):
11971236
sb = log10_sigma_bounds
11981237
case _ if isinstance(log10_sigma_bounds, float):
11991238
sb = [log10_sigma_bounds-1, log10_sigma_bounds+1]
12001239

12011240
match log10_rho_bounds:
12021241
case None:
1203-
rb = [-5, -2]
1242+
rb = [-5, -2]
12041243
case _ if isinstance(log10_rho_bounds, Sequence):
12051244
rb = log10_rho_bounds
12061245
case _ if isinstance(log10_rho_bounds, float):

0 commit comments

Comments
 (0)