Skip to content

Commit ee2461a

Browse files
committed
Added the support for free GP noise model.
1 parent 4e253b9 commit ee2461a

File tree

1 file changed

+19
-4
lines changed

1 file changed

+19
-4
lines changed

exoiris/tslpf.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def clean_knots(knots, min_distance, lmin=0, lmax=inf):
125125

126126
class TSLPF(LogPosteriorFunction):
127127
def __init__(self, name: str, ldmodel, data: TSDataSet, nk: int = 50, nldc: int = 10, nthreads: int = 1,
128-
tmpars = None, noise_model: str = 'white',
128+
tmpars = None, noise_model: Literal["white", "fixed_gp", "free_gp"] = 'white',
129129
interpolation: Literal['bspline', 'pchip', 'makima'] = 'bspline'):
130130
super().__init__(name)
131131
self._original_data: TSDataSet | None = None
@@ -206,6 +206,8 @@ def _init_parameters(self) -> None:
206206
self._init_p_limb_darkening()
207207
self._init_p_radius_ratios()
208208
self._init_p_noise()
209+
if self._nm == NM_GP_FREE:
210+
self._init_p_gp()
209211
self._init_p_baseline()
210212
self._init_p_bias()
211213
self.ps.freeze()
@@ -230,6 +232,10 @@ def set_noise_model(self, noise_model: str) -> None:
230232
self._nm = noise_models[noise_model]
231233
if self._nm in (NM_GP_FIXED, NM_GP_FREE):
232234
self._init_gp()
235+
if self._nm == NM_GP_FREE:
236+
self.ps.thaw()
237+
self._init_p_gp()
238+
self.ps.freeze()
233239

234240
def _init_gp(self) -> None:
235241
"""Initializes the Gaussian Process (GP) .
@@ -329,6 +335,15 @@ def _init_p_noise(self):
329335
self._start_wnm = ps.blocks[-1].start
330336
self._sl_wnm = ps.blocks[-1].slice
331337

338+
def _init_p_gp(self):
339+
ps = self.ps
340+
if not hasattr(self, '_sl_gp'):
341+
pp = [GParameter('gp_log_sigma', 'GP log sigma', '', NP(0.0, 0.01), (-inf, inf)),
342+
GParameter('gp_log_rho', 'GP log rho', '', NP(0.0, 0.01), (-inf, inf))]
343+
ps.add_global_block('gp_hyperparameters', pp)
344+
self._start_gp = ps.blocks[-1].start
345+
self._sl_gp = ps.blocks[-1].slice
346+
332347
def _init_p_baseline(self):
333348
ps = self.ps
334349
self.n_baselines = self.data.n_baselines
@@ -664,12 +679,12 @@ def lnlikelihood(self, pv) -> ndarray | float :
664679
if self._nm == NM_WHITE:
665680
for i, d in enumerate(self.data):
666681
lnl += lnlike_normal(d.fluxes, fmod[i], d.errors, wn_multipliers[:, d.ngid], d.mask)
667-
elif self._nm == NM_GP_FIXED:
682+
else:
668683
for j in range(npv):
684+
if self._nm == NM_GP_FREE:
685+
self.set_gp_hyperparameters(*pv[j, self._sl_gp])
669686
for i in range(self.data.size):
670687
lnl[j] += self._gp[i].log_likelihood(self._gp_flux[i] - fmod[i][j][self.data[i].mask])
671-
else:
672-
raise NotImplementedError("The free GP noise model hasn't been implemented yet.")
673688
return lnl if npv > 1 else lnl[0]
674689

675690
def create_initial_population(self, n: int, source: str, add_noise: bool = True) -> ndarray:

0 commit comments

Comments
 (0)