diff --git a/stingray/modeling/gpmodeling.py b/stingray/modeling/gpmodeling.py
index 360272373..7f172e3da 100644
--- a/stingray/modeling/gpmodeling.py
+++ b/stingray/modeling/gpmodeling.py
@@ -37,11 +37,10 @@
     tfp_available = False
 
-__all__ = ["get_kernel", "get_mean", "get_prior",
-           "get_log_likelihood", "GPResult", "get_gp_params"]
+__all__ = ["get_kernel", "get_mean", "get_prior", "get_log_likelihood", "GPResult", "get_gp_params"]
 
 
 def get_kernel(kernel_type, kernel_params):
@@ -51,7 +50,7 @@ def get_kernel(kernel_type, kernel_params):
     kernel_type: string
         The type of kernel to be used for the Gaussian Process
        To be selected from the kernels already implemented:
-        ["RN", "QPO", "QPO_plus_RN"]
+        ["RN", "QPO", "QPO_plus_RN", "CARMA", "QPO_plus_CARMA"]
 
     kernel_params: dict
         Dictionary containing the parameters for the kernel
@@ -87,6 +86,27 @@ def get_kernel(kernel_type, kernel_params):
             d=2 * jnp.pi * kernel_params["freq"],
         )
         return kernel
+    elif kernel_type == "CARMA":
+        # the CARMA(p, q) orders are implied by the lengths of alpha and beta
+        alpha = kernel_params["alpha"]
+        beta = kernel_params["beta"]
+        acarma = kernel_params["acarma"]
+        kernel = kernels.quasisep.CARMA.init(alpha=alpha, beta=beta, sigma=acarma**0.5)
+        return kernel
+    elif kernel_type == "QPO_plus_CARMA":
+        # sum of a CARMA kernel and a Celerite QPO kernel
+        kernel = kernels.quasisep.CARMA.init(
+            alpha=kernel_params["alpha"],
+            beta=kernel_params["beta"],
+            sigma=kernel_params["acarma"] ** 0.5,
+        ) + kernels.quasisep.Celerite(
+            a=kernel_params["aqpo"],
+            b=0.0,
+            c=kernel_params["cqpo"],
+            d=2 * jnp.pi * kernel_params["freq"],
+        )
+        return kernel
     else:
         raise ValueError("Kernel type not implemented")
@@ -366,6 +386,10 @@ def _get_kernel_params(kernel_type):
         return ["log_arn", "log_crn", "log_aqpo", "log_cqpo", "log_freq"]
     elif kernel_type == "QPO":
         return ["log_aqpo", "log_cqpo", "log_freq"]
+    elif kernel_type == "CARMA":
+        return ["log_alpha", "log_beta", "log_acarma"]
+    elif kernel_type == "QPO_plus_CARMA":
+        return ["log_alpha", "log_beta", "log_acarma", "log_aqpo", "log_cqpo", "log_freq"]
     else:
         raise ValueError("Kernel type not implemented")
@@ -427,7 +451,7 @@ def get_gp_params(kernel_type, mean_type):
     return kernel_params
 
 
-def get_prior(params_list, prior_dict):
+def get_prior(params_list, prior_dict, p=1, q=0):
     """
     A prior generator function based on given values.
     Makes a jaxns specific prior function based on the given prior dictionary.
@@ -447,6 +471,9 @@ def get_prior(params_list, prior_dict):
         the corresponding name in the params_list.
         Also, if a parameter is to be used in the log scale, it should have a prefix of "log_"
 
+    p, q : int, default 1 and 0
+        The orders of the CARMA(p, q) process; must satisfy p >= q
+
     Returns
     -------
     The Prior generator function.
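+
+    Examples
+    --------
+    A minimal sketch for a CARMA(2, 1) kernel with a hypothetical constant
+    mean parameterised by a single amplitude ``log_A``; the uniform prior
+    ranges below are purely illustrative, not recommended defaults. Note
+    that the single ``log_alpha`` (``log_beta``) prior is reused for each
+    of the p (q) CARMA coefficients:
+
+    >>> params_list = ["log_alpha", "log_beta", "log_acarma", "log_A"]
+    >>> prior_dict = {
+    ...     "log_alpha": tfpd.Uniform(low=-5.0, high=5.0),
+    ...     "log_beta": tfpd.Uniform(low=-5.0, high=5.0),
+    ...     "log_acarma": tfpd.Uniform(low=-5.0, high=5.0),
+    ...     "log_A": tfpd.Uniform(low=0.0, high=5.0),
+    ... }
+    >>> prior_model = get_prior(params_list, prior_dict, p=2, q=1)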
@@ -485,20 +512,41 @@ def get_prior(params_list, prior_dict):
 
     def prior_model():
         prior_list = []
-        for i in params_list:
-            if isinstance(prior_dict[i], tfpd.Distribution):
-                parameter = yield Prior(prior_dict[i], name=i)
-            elif isinstance(prior_dict[i], Prior):
-                parameter = yield prior_dict[i]
+        for param in params_list:
+            if param in ("log_alpha", "log_beta"):
+                # expand the CARMA coefficient priors into p (alpha) or
+                # q (beta) separately named parameters, e.g. "log_alpha0"
+                order = p if param == "log_alpha" else q
+                for j in range(order):
+                    if isinstance(prior_dict[param], tfpd.Distribution):
+                        parameter = yield Prior(prior_dict[param], name=param + str(j))
+                    elif isinstance(prior_dict[param], Prior):
+                        parameter = yield prior_dict[param]
+                    else:
+                        raise ValueError("Prior should be a tfpd distribution or a jaxns prior.")
+                    prior_list.append(parameter)
             else:
-                raise ValueError("Prior should be a tfpd distribution or a jaxns prior.")
-            prior_list.append(parameter)
+                if isinstance(prior_dict[param], tfpd.Distribution):
+                    parameter = yield Prior(prior_dict[param], name=param)
+                elif isinstance(prior_dict[param], Prior):
+                    parameter = yield prior_dict[param]
+                else:
+                    raise ValueError("Prior should be a tfpd distribution or a jaxns prior.")
+                prior_list.append(parameter)
         return tuple(prior_list)
 
     return prior_model
 
 
-def get_log_likelihood(params_list, kernel_type, mean_type, times, counts, **kwargs):
+def get_log_likelihood(
+    params_list, kernel_type, mean_type, times, counts, counts_err=None, p=1, q=0
+):
     """
     A log likelihood generator function based on given values.
     Makes a jaxns specific log likelihood function which takes in the
@@ -530,6 +578,10 @@ def get_log_likelihood(params_list, kernel_type, mean_type, times, counts, **kwa
     counts: np.array or jnp.array
         The photon counts array of the lightcurve
 
+    counts_err : np.array or jnp.array, default None
+        The uncertainties on the counts or fluxes given
+        in `counts`. Must be of the same shape as `counts`
+
+    p, q : int, default 1 and 0
+        The orders of the CARMA(p, q) process, if a CARMA
+        kernel is used; must satisfy p >= q
+
     Returns
     -------
     The Jaxns specific log likelihood function.
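+
+    Examples
+    --------
+    A minimal sketch, reusing the (hypothetical) ``params_list`` from the
+    ``get_prior`` example above; ``times`` and ``counts`` stand in for the
+    light curve arrays being modelled, with Poisson uncertainties assumed:
+
+    >>> loglike = get_log_likelihood(
+    ...     params_list,
+    ...     kernel_type="CARMA",
+    ...     mean_type="constant",
+    ...     times=times,
+    ...     counts=counts,
+    ...     counts_err=jnp.sqrt(counts),
+    ...     p=2,
+    ...     q=1,
+    ... )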
@@ -541,18 +593,41 @@ def get_log_likelihood(params_list, kernel_type, mean_type, times, counts, **kwa
     if not can_make_gp:
         raise ImportError("Tinygp is required to make the GP model.")
 
+    if counts_err is None:
+        counts_err = jnp.zeros_like(counts)
+
     @jit
     def likelihood_model(*args):
         param_dict = {}
-        for i, params in enumerate(params_list):
-            if params[0:4] == "log_":
-                param_dict[params[4:]] = jnp.exp(args[i])
-            else:
-                param_dict[params] = args[i]
+        i = 0
+        for params in params_list:
+            # the CARMA coefficients occupy p (alpha) or q (beta) consecutive
+            # entries of args, so they are gathered into arrays here
+            if params == "alpha":
+                param_dict["alpha"] = jnp.array(args[i : i + p])
+                i += p
+            elif params == "log_alpha":
+                param_dict["alpha"] = jnp.exp(jnp.array(args[i : i + p]))
+                i += p
+            elif params == "beta":
+                param_dict["beta"] = jnp.array(args[i : i + q])
+                i += q
+            elif params == "log_beta":
+                param_dict["beta"] = jnp.exp(jnp.array(args[i : i + q]))
+                i += q
+            else:
+                if params[0:4] == "log_":
+                    param_dict[params[4:]] = jnp.exp(args[i])
+                else:
+                    param_dict[params] = args[i]
+                i += 1
+
         kernel = get_kernel(kernel_type=kernel_type, kernel_params=param_dict)
         mean = get_mean(mean_type=mean_type, mean_params=param_dict)
-        gp = GaussianProcess(kernel, times, mean_value=mean(times))
-        return gp.log_probability(counts)
+        # tinygp expects the measurement *variance* on the diagonal; map
+        # non-finite log likelihoods to a very low value so the nested
+        # sampler can recover from numerically unstable parameter draws
+        gp = GaussianProcess(kernel, times, mean_value=mean(times), diag=counts_err**2)
+        log_like = jnp.nan_to_num(
+            gp.log_probability(counts), nan=-1e20, posinf=-1e20, neginf=-1e20
+        )
+        return log_like
 
     return likelihood_model
@@ -573,10 +648,13 @@ def __init__(self, lc: Lightcurve) -> None:
         self.lc = lc
         self.time = lc.time
         self.counts = lc.counts
+
+        if lc.err_dist == "poisson":
+            self.counts_err = jnp.sqrt(self.counts)
+        else:
+            self.counts_err = lc.counts_err
+
         self.result = None
 
-    def sample(self, prior_model=None, likelihood_model=None, max_samples=1e4,
-               num_live_points=500):
+    def sample(self, prior_model=None, likelihood_model=None, max_samples=1e4, num_live_points=500):
@@ -619,7 +697,9 @@ def sample(self, prior_model=None, likelihood_model=None, max_samples=1e4,
         nsmodel = Model(prior_model=self.prior_model, log_likelihood=self.log_likelihood_model)
         nsmodel.sanity_check(random.PRNGKey(10), S=100)
 
-        self.exact_ns = ExactNestedSampler(nsmodel, num_live_points=num_live_points, max_samples=max_samples)
+        self.exact_ns = ExactNestedSampler(
+            nsmodel, num_live_points=num_live_points, max_samples=max_samples
+        )
 
         termination_reason, state = self.exact_ns(
             random.PRNGKey(42), term_cond=TerminationCondition(live_evidence_frac=1e-4)
@@ -904,3 +984,203 @@ def comparison_plot(
             plt.savefig(filename)
 
         return plt
+
+    def plot_posterior_predictive(
+        self,
+        lc,
+        kernel_type,
+        mean_type,
+        nmean=50,
+        ngrid=1000,
+        ax=None,
+        rkey=None,
+        p=1,
+        q=0,
+        counts_err=None,
+    ):
+        """
+        Plot the posterior predictive distribution.
+        Will plot the maximum posterior for the Gaussian Process, and
+        `nmean` random draws from the mean function.
+
+        Parameters
+        ----------
+        lc : stingray.Lightcurve object
+            The light curve with the time series data
+            being modelled
+
+        kernel_type : str
+            The kernel type used in the modeling
+
+        mean_type : str
+            The type of the mean function used
+
+        nmean : int, default 50
+            The number of posterior draws of the mean
+            function to plot
+
+        ngrid : int, default 1000
+            The number of points in the linear grid to
+            use for plotting the Gaussian Process
+
+        ax : matplotlib.Axes object, default None
+            A matplotlib.Axes object to plot into.
+            If None is given, a new Figure object will be created
+
+        rkey : jax.random.PRNGKey object, default None
+            A random key for setting the sampling. If None,
+            set to random.PRNGKey(1234)
+
+        p, q : int, default 1 and 0
+            If the kernel involves a CARMA model, then this
+            sets the orders of the AR and MA processes involved.
+            Note that p >= q is required
+
+        counts_err : array, default None
+            If `None`, the error bars for the data will be set
+            to `sqrt(lc.counts)`, otherwise to whatever is set
+            for `counts_err`
+
+        Returns
+        -------
+        ax : matplotlib.Axes object
+            The matplotlib.Axes object that the plot is
+            drawn in
+        """
+        if rkey is None:
+            rkey = random.PRNGKey(1234)
+
+        # log posterior weights of the nested sampling draws
+        log_p = self.results.log_dp_mean
+
+        # dictionary for the resampled posterior samples
+        samples_resampled = {}
+
+        # go through samples, resample with weights to get
+        # a weighted posterior sample
+        for name in self.results.samples.keys():
+            samples = self.results.samples[name]
+
+            log_weights = jnp.where(jnp.isfinite(samples), log_p, -jnp.inf)
+            sr = resample(
+                rkey, samples, log_weights, S=max(10, int(self.results.ESS)), replace=True
+            )
+            samples_resampled[name] = sr
+
+        print("Resampling done, calculating maximum posterior model ...")
+        # split the parameters into those belonging to the kernel
+        # and those belonging to the mean function
+        kernel_params = _get_kernel_params(kernel_type)
+        mean_params = _get_mean_params(mean_type)
+
+        maxpost_log = self.get_max_posterior_parameters()
+
+        if "CARMA" not in kernel_type:
+            mean_samples = {k: samples_resampled[k] for k in mean_params}
+            sk_maxpost_log = {k: maxpost_log[k] for k in kernel_params}
+            sm_maxpost_log = {k: maxpost_log[k] for k in mean_params}
+        else:
+            mean_samples = {k: samples_resampled[k] for k in mean_params}
+
+            # the CARMA coefficients are sampled as individually numbered
+            # parameters ("log_alpha0", "log_alpha1", ...), so they are
+            # collected separately from the scalar kernel parameters
+            kernel_params_small = [
+                par for par in kernel_params if not ("alpha" in par or "beta" in par)
+            ]
+            sk_maxpost_log = {k: maxpost_log[k] for k in kernel_params_small}
+
+            sk_maxpost_log["log_alpha"] = jnp.array(
+                [maxpost_log["log_alpha" + str(j)] for j in range(p)]
+            )
+            sk_maxpost_log["log_beta"] = jnp.array(
+                [maxpost_log["log_beta" + str(k)] for k in range(q)]
+            )
+
+            sm_maxpost_log = {k: maxpost_log[k] for k in mean_params}
+
+        sk_maxpost, sm_maxpost = {}, {}
+        for params in kernel_params:
+            if params[0:4] == "log_":
+                sk_maxpost[params[4:]] = np.exp(sk_maxpost_log[params])
+            else:
+                sk_maxpost[params] = sk_maxpost_log[params]
+
+        for params in mean_params:
+            if params[0:4] == "log_":
+                sm_maxpost[params[4:]] = np.exp(sm_maxpost_log[params])
+            else:
+                sm_maxpost[params] = sm_maxpost_log[params]
+
+        kernel = get_kernel(kernel_type=kernel_type, kernel_params=sk_maxpost)
+        mean = get_mean(mean_type=mean_type, mean_params=sm_maxpost)
+
+        gp = GaussianProcess(kernel, lc.time, mean_value=mean(lc.time))
+        tgrid = np.linspace(lc.time[0], lc.time[-1], ngrid)
+        _, cond = gp.condition(lc.counts, tgrid)
+
+        mu = cond.loc + mean(tgrid)
+        std = np.sqrt(cond.variance)
+
+        print("GP calculated, plotting GP and data ...")
+
+        if ax is None:
+            fig, ax = plt.subplots(1, 1, figsize=(8, 4))
+
+        if counts_err is None:
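+            # default to Poisson errors on the counts, as documented above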
+            counts_err = np.sqrt(lc.counts)
+
+        ax.errorbar(
+            lc.time,
+            lc.counts,
+            yerr=counts_err,
+            fmt="o",
+            markersize=2,
+            color="black",
+            label="Observations",
+        )
+        ax.plot(tgrid, mu, color="C0", label="Gaussian Process Maximum Posterior")
+        ax.fill_between(tgrid, mu + std, mu - std, color="C0", alpha=0.3)
+
+        idx_all = np.random.choice(
+            np.arange(int(self.results.ESS)), size=nmean, replace=False
+        )
+        print("Calculating and plotting mean functions ...")
+        for i, idx in enumerate(idx_all):
+            sm_log = {k: mean_samples[k][idx] for k in mean_params}
+            sm = {}
+            for params in mean_params:
+                if params[0:4] == "log_":
+                    sm[params[4:]] = jnp.exp(sm_log[params])
+                else:
+                    sm[params] = sm_log[params]
+
+            mean = get_mean(mean_type=mean_type, mean_params=sm)
+            mean_vals = mean(tgrid)
+
+            # add a legend entry only for the first draw
+            if i == 0:
+                ax.plot(
+                    tgrid,
+                    mean_vals,
+                    color="orange",
+                    alpha=0.1,
+                    label="Mean function posterior draws",
+                )
+            else:
+                ax.plot(tgrid, mean_vals, color="orange", alpha=0.1)
+
+        # update the legend so all entries are fully opaque
+        leg = ax.legend()
+        for lh in leg.legendHandles:
+            lh.set_alpha(1)
+
+        return ax