Skip to content

Commit c29ac6b

Browse files
committed
Added dataset-specific GP hyperparameters support and refined data handling
- Enhanced `set_gp_hyperparameters` to allow dataset-specific hyperparameters via `idata` parameter. - Improved flux and error handling in `wlpf` by ensuring finite data treatment. - Fixed initialization logic for `optimize` method in `wlpf`. - Adjusted setup order for data and noise models in `tslpf` for consistency. - Streamlined GP initialization logic by enhancing condition handling.
1 parent 921b810 commit c29ac6b

File tree

3 files changed

+16
-18
lines changed

3 files changed

+16
-18
lines changed

exoiris/exoiris.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ def set_ldtk_prior(self,
303303
metal = (metal.n, metal.s) if isinstance(metal, UFloat) else metal
304304
self._tsa.set_ldtk_prior(teff, logg, metal, dataset, width, uncertainty_multiplier)
305305

306-
def set_gp_hyperparameters(self, sigma: float, rho: float) -> None:
306+
def set_gp_hyperparameters(self, sigma: float, rho: float, idata: None | int = None) -> None:
307307
"""Set Gaussian Process (GP) hyperparameters assuming a Matern-3/2 kernel.
308308
309309
Parameters
@@ -312,8 +312,10 @@ def set_gp_hyperparameters(self, sigma: float, rho: float) -> None:
312312
The kernel amplitude parameter.
313313
rho
314314
The length scale parameter.
315+
idata
316+
The data set for which to set the hyperparameters. If None, the hyperparameters are set for all data sets.
315317
"""
316-
self._tsa.set_gp_hyperparameters(sigma, rho)
318+
self._tsa.set_gp_hyperparameters(sigma, rho, idata)
317319

318320
def set_gp_kernel(self, kernel: terms.Term) -> None:
319321
"""Set the Gaussian Process (GP) kernel.

exoiris/tslpf.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -142,13 +142,11 @@ def __init__(self, name: str, ldmodel, data: TSDataSet, nk: int = 50, nldc: int
142142
self._gp_time: Optional[list[ndarray]] = None
143143
self._gp_flux: Optional[list[ndarray]] = None
144144

145-
self.set_noise_model(noise_model)
146-
147-
self.ldmodel = ldmodel
148-
149145
self.tms = [TSModel(ldmodel, nthreads=nthreads, **(tmpars or {})) for i in range(len(data))]
150146
self.set_data(data)
147+
self.set_noise_model(noise_model)
151148

149+
self.ldmodel = ldmodel
152150
if isinstance(ldmodel, LDTkLD):
153151
for tm in self.tms:
154152
tm.ldmodel = None
@@ -196,8 +194,6 @@ def set_data(self, data: TSDataSet):
196194
self.npt: list[int] = [f.shape[1] for f in self.flux]
197195
for i, time in enumerate(self.times):
198196
self.tms[i].set_data(time)
199-
if self._nm in (NM_GP_FIXED, NM_GP_FREE):
200-
self._init_gp()
201197

202198
def _init_parameters(self) -> None:
203199
self.ps = ParameterSet([])
@@ -276,7 +272,7 @@ def set_gp_hyperparameters(self, sigma: float, rho: float, idata: int | None = N
276272
if self._gp is None:
277273
raise RuntimeError('The GP needs to be initialized before setting hyperparameters.')
278274

279-
for i in ([idata] or range(self.data.size)):
275+
for i in ([idata] if idata is not None else range(self.data.size)):
280276
self._gp[i].kernel = terms.Matern32Term(sigma=sigma, rho=rho)
281277
self._gp[i].compute(self._gp_time[i], yerr=self._gp_ferr[i], quiet=True)
282278

exoiris/wlpf.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,8 @@ def __init__(self, tsa: TSLPF):
4949
fluxes, times, errors = [], [], []
5050
for t, f, e in zip(tsa.data.times, tsa.data.fluxes, tsa.data.errors):
5151
weights = where(isfinite(f) & isfinite(e), 1/e**2, 0.0)
52-
mf = average(f, axis=0, weights=weights)
53-
me = sqrt(1./(1./e**2).sum(0))
52+
mf = average(where(isfinite(f), f, 0), axis=0, weights=weights)
53+
me = sqrt(1 / weights.sum(0))
5454
m = isfinite(mf)
5555
times.append(t[m])
5656
fluxes.append(mf[m])
@@ -108,13 +108,13 @@ def transit_model(self, pv, copy=True):
108108
return self.tm.evaluate(radius_ratio, ldc, zero_epoch, period, smaxis, inclination)
109109

110110
def optimize(self, pv0=None, method='powell', maxfev: int = 5000):
111-
if pv0 is None:
112-
if self.de is not None:
113-
pv0 = self.de.minimum_location
114-
else:
115-
pv0 = self.ps.mean_pv
116-
res = minimize(lambda pv: -self.lnposterior(pv), pv0, method=method, options={'maxfev':maxfev})
117-
self._local_minimization = res
111+
if pv0 is None:
112+
if self.de is not None:
113+
pv0 = self.de.minimum_location
114+
else:
115+
pv0 = self.ps.mean_pv
116+
res = minimize(lambda pv: -self.lnposterior(pv), pv0, method=method, options={'maxfev':maxfev})
117+
self._local_minimization = res
118118

119119
@property
120120
def transit_center(self):

0 commit comments

Comments
 (0)