diff --git a/stingray/modeling/gpmodeling.py b/stingray/modeling/gpmodeling.py
index 0c5dbe5b1..39a5f5305 100644
--- a/stingray/modeling/gpmodeling.py
+++ b/stingray/modeling/gpmodeling.py
@@ -3,6 +3,9 @@ import functools
 
 from stingray import Lightcurve
 
+from matplotlib.lines import Line2D
+from matplotlib.patches import Patch
+
 try:
     import jax
     import jax.numpy as jnp
@@ -21,7 +24,7 @@
     can_make_gp = False
 
 try:
-    from jaxns import ExactNestedSampler, TerminationCondition, Prior, Model
+    from jaxns import DefaultNestedSampler, TerminationCondition, Prior, Model
     from jaxns.utils import resample
 
     can_sample = True
@@ -37,10 +40,607 @@
 tfp_available = False
 
 
-__all__ = ["get_kernel", "get_mean", "get_prior", "get_log_likelihood", "GPResult", "get_gp_params"]
+__all__ = [
+    "get_kernel",
+    "get_mean",
+    "get_prior",
+    "get_log_likelihood",
+    "GPResult",
+    "get_gp_params",
+    "run_prior_checks",
+    "run_posterior_check",
+]
+
+
+def get_priors_samples(key, kernel_params, priors, loglike, n_samples=3000):
+    """Sample from the prior distribution.
+
+    Parameters
+    ----------
+    key : jax.random.PRNGKey
+        Random key.
+    kernel_params : list of str
+        Names of the model parameters.
+    priors : list
+        List of priors, in the same order as ``kernel_params``.
+    loglike : callable
+        Log likelihood function.
+    n_samples : int
+        Number of samples. Default is 3000.
+    """
+
+    num_params = len(kernel_params)
+
+    # get the prior model
+    prior_dict = dict(zip(kernel_params, priors))
+    prior_model = get_prior(kernel_params, prior_dict)
+
+    # define the model
+    nsmodel = Model(prior_model=prior_model, log_likelihood=loglike)
+
+    # draw uniform samples in the unit cube and map them through the prior transform
+    unit_samples = jax.random.uniform(key, (n_samples, num_params))
+    prior_samples = jax.vmap(nsmodel.transform)(unit_samples)
+
+    return prior_samples
+
+
+def get_psd_and_approx(
+    kernel_type,
+    kernel_params,
+    prior_samples,
+    f0,
+    fM,
+    n_frequencies=1000,
+    n_approx_components=20,
+    approximate_with="SHO",
+    with_normalisation=False,
+):
+    """Get the PSD and the approximate PSD for a given set of parameters and samples.
+
+    Parameters
+    ----------
+    kernel_type : str
+        The type of kernel, either "PowL" or "DoubPowL".
+    kernel_params : list of str
+        Names of the model parameters.
+    prior_samples : dict
+        Samples of the parameters, keyed by parameter name.
+    f0 : float
+        Minimum frequency.
+    fM : float
+        Maximum frequency.
+    n_frequencies : int
+        Number of frequencies. Default is 1000.
+    n_approx_components : int
+        Number of components used to approximate the power law. Default is 20.
+    approximate_with : str
+        The type of kernel used for the approximation. Default is "SHO".
+    with_normalisation : bool
+        Whether to normalise the PSDs by the variance. Default is False.
+    """
+    n_samples = prior_samples[kernel_params[0]].shape[0]
+    f = np.geomspace(f0, fM, n_frequencies)
+    psd_models = []
+    psd_approx = []
+    for k in range(n_samples):
+        param_dict = {}
+        for params in kernel_params:
+            if params.startswith("log_"):
+                param_dict[params[4:]] = jnp.exp(prior_samples[params][k])
+            else:
+                param_dict[params] = prior_samples[params][k]
+
+        psd_model, psd_SHO = get_psd_approx_samples(
+            f,
+            kernel_type,
+            param_dict,
+            f0,
+            fM,
+            n_approx_components=n_approx_components,
+            approximate_with=approximate_with,
+        )
+        if with_normalisation:
+            f_c, a = _get_coefficients_approximation(
+                kernel_type,
+                param_dict,
+                f0,
+                fM,
+                n_approx_components=n_approx_components,
+                approximate_with=approximate_with,
+            )
+            norm = np.sum(a * f_c)
+            psd_models.append(psd_model * param_dict["variance"] / norm)
+            psd_approx.append(psd_SHO * param_dict["variance"] / norm)
+        else:
+            psd_models.append(psd_model)
+            psd_approx.append(psd_SHO)
+    psd_models = np.array(psd_models)
+    psd_approx = np.array(psd_approx)
+    return f, psd_models, psd_approx
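# --- Example (reviewer sketch, not part of the patch) -------------------------
# How the two helpers above fit together: draw parameter samples from the
# prior, then evaluate both the exact PSD and its SHO approximation on a
# frequency grid. The parameter names mirror get_gp_params("PowL", "constant");
# the prior choices, times and counts below are illustrative assumptions.
import jax
import jax.numpy as jnp
import tensorflow_probability.substrates.jax as tfp
from stingray.modeling.gpmodeling import (
    get_gp_params, get_log_likelihood, get_priors_samples, get_psd_and_approx,
)

tfpd = tfp.distributions
times = jnp.linspace(0.0, 1.0, 64)
counts = jnp.ones(64)
params = get_gp_params("PowL", "constant")
# -> ["alpha_1", "log_f_bend", "alpha_2", "variance", "log_A"]
priors = [
    tfpd.Uniform(low=0.0, high=1.25),                      # alpha_1
    tfpd.Uniform(low=jnp.log(5.5e-3), high=jnp.log(0.5)),  # log_f_bend
    tfpd.Uniform(low=1.5, high=4.0),                       # alpha_2
    tfpd.LogNormal(-0.2, 1.0),                             # variance
    tfpd.Normal(0.0, 2.0),                                 # log_A
]
# the likelihood is only needed here to build the jaxns Model; it is not evaluated
loglike = get_log_likelihood(params, "PowL", "constant", times, counts)
samples = get_priors_samples(jax.random.PRNGKey(0), params, priors, loglike, n_samples=100)
f, psd_true, psd_sho = get_psd_and_approx("PowL", params, samples, 1e-4, 10.0, n_frequencies=500)
# ------------------------------------------------------------------------------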
+
+
+def plot_psd_ppc(f, psd_quantiles, psd_approx_quantiles, psd_noise_levels, f_min, f_max, path):
+    """Replot the PSD posterior predictive check (PPC) figure from precomputed quantiles.
+
+    Parameters
+    ----------
+    f : array
+        The frequency array.
+    psd_quantiles : array
+        The quantiles of the PSD model.
+    psd_approx_quantiles : array
+        The quantiles of the PSD approximation.
+    psd_noise_levels : array
+        The noise levels.
+    f_min : float
+        The minimum frequency.
+    f_max : float
+        The maximum frequency.
+    path : str
+        The path to save the figure.
+    """
+    approx_color = "C6"
+    psd_color = "C3"
+    noise_color = "C5"
+    window_color = "k"
+
+    fig, ax = plt.subplots(1, 1, figsize=(6, 4))
+
+    ax.loglog(f, psd_quantiles[:, 2], label="Median", color=psd_color)
+    ax.fill_between(f, psd_quantiles[:, 1], psd_quantiles[:, 3], color=psd_color, alpha=0.3)
+    ax.fill_between(f, psd_quantiles[:, 0], psd_quantiles[:, 4], color=psd_color, alpha=0.15)
+    ax.axhline(psd_noise_levels[0], color=noise_color, ls="-", label="Noise level")
+    ax.loglog(f, psd_approx_quantiles[:, 2], color=approx_color)
+    ax.fill_between(
+        f,
+        psd_approx_quantiles[:, 1],
+        psd_approx_quantiles[:, 3],
+        color=approx_color,
+        alpha=0.3,
+    )
+    ax.fill_between(
+        f,
+        psd_approx_quantiles[:, 0],
+        psd_approx_quantiles[:, 4],
+        color=approx_color,
+        alpha=0.15,
+    )
+    ax.axvline(f_min, color=window_color, ls=":")
+    ax.axvline(f_max, color=window_color, ls=":")
+    ax.update({"xlabel": r"Frequency $(d^{-1})$", "ylabel": "Power Spectral Density"})
+    ax.set_xlim(np.min(f), np.max(f) / 10)
+    ax.set_ylim(np.min(psd_noise_levels) / 10)
+
+    legend_elements = [
+        Line2D([0], [0], color=psd_color, lw=2, label="PSD model"),
+        Line2D([0], [0], color=approx_color, lw=2, label="PSD approximation"),
+        Line2D([0], [0], color=noise_color, lw=2, label="Noise level"),
+        Patch(facecolor="k", edgecolor="k", alpha=0.1, label="95%"),
+        Patch(facecolor="k", edgecolor="k", alpha=0.4, label="68%"),
+        Line2D(
+            [0], [0], color=window_color, lw=2, ls=":", label=r"$f_\mathrm{min}, f_\mathrm{max}$"
+        ),
+    ]
+
+    ax.legend(
+        handles=legend_elements,
+        ncol=2,
+        bbox_to_anchor=(0.5, -0.175),
+        loc="lower center",
+        bbox_transform=fig.transFigure,
+    )
+    fig.tight_layout()
+    fig.savefig(f"{path}replot_psd_ppc.pdf", bbox_inches="tight")
+
+    return fig
+
+
+def run_posterior_check(
+    kernel_type,
+    kernel_params,
+    posterior_dict,
+    t,
+    y,
+    yerr,
+    S_low=20,
+    S_high=20,
+    n_frequencies=1000,
+    n_approx_components=20,
+    approximate_with="SHO",
+    path="./",
+):
+    """Run a posterior predictive check of the PSD approximation.
+
+    Computes the PSD model and its approximation over the posterior samples
+    and plots their quantiles together with the estimated noise level.
+
+    Parameters
+    ----------
+    kernel_type : str
+        The type of kernel, either "PowL" or "DoubPowL".
+    kernel_params : list of str
+        Names of the model parameters.
+    posterior_dict : dict
+        Posterior samples of the parameters, keyed by parameter name.
+    t, y, yerr : array
+        Times, values and errors of the time series.
+    S_low, S_high : int
+        Low and high frequency extension factors. Default is 20.
+    n_frequencies : int
+        Number of frequencies. Default is 1000.
+    n_approx_components : int
+        Number of components used to approximate the power law. Default is 20.
+    approximate_with : str
+        The type of kernel used for the approximation. Default is "SHO".
+    path : str
+        Path to save the plots. Default is "./".
+    """
+    f_min, f_max = 1 / (t[-1] - t[0]), 1 / (2 * np.min(np.diff(t)))
+    f0, fM = f_min / S_low, f_max * S_high
+    f, psd_models, psd_approx = get_psd_and_approx(
+        kernel_type,
+        kernel_params,
+        posterior_dict,
+        f0,
+        fM,
+        n_frequencies=n_frequencies,
+        n_approx_components=n_approx_components,
+        approximate_with=approximate_with,
+        with_normalisation=True,
+    )
+
+    if "log_shift" in kernel_params:
+        psd_noise_levels = [2 * np.median((yerr / y) ** 2) * np.median(np.diff(t))]
+    else:
+        psd_noise_levels = [2 * np.median(yerr**2) * np.median(np.diff(t))]
+    psd_quantiles = jnp.percentile(psd_models, jnp.array([2.5, 16, 50, 84, 97.5]), axis=0).T
+    psd_approx_quantiles = jnp.percentile(psd_approx, jnp.array([2.5, 16, 50, 84, 97.5]), axis=0).T
+
+    fig = plot_psd_ppc(
+        f, psd_quantiles, psd_approx_quantiles, psd_noise_levels, f_min, f_max, path=path
+    )
+    return fig
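# --- Example (reviewer sketch, not part of the patch) -------------------------
# Typical use of run_posterior_check: collect posterior draws into a dict keyed
# by parameter name and let the function rebuild the normalised PSDs and the
# PPC figure. The stand-in draws below are illustrative; in practice they come
# from the nested-sampling posterior (see GPResult further down).
import numpy as np
from stingray.modeling.gpmodeling import run_posterior_check

rng = np.random.default_rng(0)
t = np.linspace(0.0, 1.0, 64)
y = rng.normal(10.0, 1.0, 64)
yerr = np.full(64, 0.3)
params = ["alpha_1", "log_f_bend", "alpha_2", "variance", "log_A"]
posterior_dict = {
    "alpha_1": rng.uniform(0.0, 1.25, 200),
    "log_f_bend": np.log(rng.uniform(5.5e-3, 0.5, 200)),
    "alpha_2": rng.uniform(1.5, 4.0, 200),
    "variance": rng.lognormal(-0.2, 1.0, 200),
    "log_A": rng.normal(0.0, 2.0, 200),
}
fig = run_posterior_check("PowL", params, posterior_dict, t, y, yerr, path="./")
# ------------------------------------------------------------------------------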
+
+
+def run_prior_checks(
+    kernel_type,
+    kernel_params,
+    priors,
+    loglike,
+    f_min,
+    f_max,
+    seed=42,
+    path="./",
+    n_samples=3000,
+    n_frequencies=1000,
+    S_low=20,
+    S_high=20,
+    n_approx_components=20,
+    approximate_with="SHO",
+):
+    """Check the quality of the approximation to the power spectral density.
+
+    This function plots various diagnostics of the residuals and the ratio
+    between the true PSD and the approximate PSD over prior samples.
+
+    Parameters
+    ----------
+    kernel_type : str
+        The type of kernel to be used for the Gaussian Process.
+        Only designed for the following power spectra: ["PowL", "DoubPowL"].
+    kernel_params : list of str
+        Names of the parameters for the selected kernel.
+    priors : list
+        List of priors, in the same order as ``kernel_params``.
+    loglike : callable
+        Log likelihood function.
+    f_min : float
+        Minimum frequency.
+    f_max : float
+        Maximum frequency.
+    seed : int
+        Seed for the jax random key. Default is 42.
+    path : str
+        Path to save the plots. Default is "./".
+    n_samples : int
+        Number of samples. Default is 3000.
+    n_frequencies : int
+        Number of frequencies. Default is 1000.
+    S_low : int
+        Low frequency extension factor. Default is 20.
+    S_high : int
+        High frequency extension factor. Default is 20.
+    n_approx_components : int
+        Number of components used to approximate the power law. Default is 20.
+    approximate_with : str
+        The type of kernel used for the approximation. Default is "SHO".
+    """
+    key = jax.random.PRNGKey(seed)
+    if kernel_type not in ["PowL", "DoubPowL"]:
+        raise ValueError("Only 'PowL' and 'DoubPowL' kernels need to be checked")
+
+    f0, fM = f_min / S_low, f_max * S_high
+    prior_samples = get_priors_samples(key, kernel_params, priors, loglike, n_samples)
+    f, psd_models, psd_approx = get_psd_and_approx(
+        kernel_type,
+        kernel_params,
+        prior_samples,
+        f0,
+        fM,
+        n_frequencies=n_frequencies,
+        n_approx_components=n_approx_components,
+        approximate_with=approximate_with,
+    )
+
+    residuals = psd_approx - psd_models
+    ratio = np.exp(np.log(psd_approx) - np.log(psd_models))
+
+    fig1, _ = plot_psd_approx_quantiles(f, f_min, f_max, residuals, ratio)
+    fig1.savefig(f"{path}/psd_check_approx_quantiles.png", dpi=300)
+
+    fig2, _ = plot_boxplot_psd_approx(residuals, ratio)
+    fig2.savefig(f"{path}/psd_check_approx_boxplot.png", dpi=300)
+    return fig1, fig2
+
+
+def plot_boxplot_psd_approx(residuals, ratios):
+    """Plot boxplots of summary statistics of the residuals and the ratios
+    between the true PSD and the approximate PSD."""
+    meta_mean_res = np.mean(residuals, axis=1)
+    meta_median_res = np.median(residuals, axis=1)
+    meta_min_res = np.min(np.abs(residuals), axis=1)
+    meta_max_res = np.max(np.abs(residuals), axis=1)
+
+    meta_mean_rat = np.mean(ratios, axis=1)
+    meta_median_rat = np.median(ratios, axis=1)
+    meta_min_rat = np.min(ratios, axis=1)
+    meta_max_rat = np.max(ratios, axis=1)
+
+    fig, ax = plt.subplots(2, 1, figsize=(7, 5.5))
+    ax[0].boxplot(
+        [meta_mean_res, meta_median_res, meta_min_res, meta_max_res],
+        positions=[1, 2, 3, 4],
+        flierprops=dict(marker=".", markersize=3),
+    )
+    ax[0].set_xticks([])
+    ax[0].axhline(0, color="C2", lw=2, ls=":")
+
+    ax[1].boxplot(
+        [meta_mean_rat, meta_median_rat, meta_min_rat, meta_max_rat],
+        positions=[1, 2, 3, 4],
+        flierprops=dict(marker=".", markersize=3),
+    )
+    ax[1].set_xticks([1, 2, 3, 4])
+    ax[1].axhline(1, color="C2", lw=2, ls=":")
+    ax[1].set_xticklabels(["mean", "median", "minimum", "maximum"])
+    ax[0].set_ylabel(r"$P_{\text{true}} - P_{\text{approx}}$")
+    ax[1].set_ylabel(r"$P_{\text{approx}} / P_{\text{true}}$")
+    fig.align_ylabels(ax)
+    return fig, ax
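# --- Example (reviewer sketch, not part of the patch) -------------------------
# Entry point for the prior predictive diagnostics above, reusing `params`,
# `priors` and `loglike` from the sketch after get_psd_and_approx. f_min and
# f_max would normally come from the sampling times; values are illustrative.
from stingray.modeling.gpmodeling import run_prior_checks

fig1, fig2 = run_prior_checks(
    "PowL", params, priors, loglike,
    f_min=5.5e-3, f_max=0.5,
    n_samples=1000, path="./",
)
# ------------------------------------------------------------------------------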
+
+
+def plot_psd_approx_quantiles(f, f_min, f_max, residuals, ratios):
+    """Plot the quantiles of the residuals and the ratios between the true PSD
+    and the approximate PSD as a function of frequency."""
+    res_quantiles = jnp.percentile(residuals, jnp.asarray([2.5, 16, 50, 84, 97.5]), axis=0)
+    rat_quantiles = jnp.percentile(ratios, jnp.asarray([2.5, 16, 50, 84, 97.5]), axis=0)
+
+    colors = "C0"
+    fig, ax = plt.subplots(2, 1, sharex=True, figsize=(6.5, 4.5), gridspec_kw={"hspace": 0.1})
+
+    ax[0].fill_between(f, res_quantiles[0], res_quantiles[4], color=colors, alpha=0.25)
+    ax[0].fill_between(f, res_quantiles[1], res_quantiles[3], color=colors, alpha=0.5)
+    ax[0].plot(f, res_quantiles[2], color="black", lw=1)
+    ax[0].update(
+        {
+            "xscale": "log",
+            "yscale": "linear",
+            "ylabel": r"$P_{\text{true}} - P_{\text{approx}}$",
+        }
+    )
+
+    ax[0].axvline(f_min, color="black", linestyle="--")
+    ax[0].axvline(f_max, color="black", linestyle="--")
+    ax[1].axvline(f_min, color="black", linestyle="--")
+    ax[1].axvline(f_max, color="black", linestyle="--")
+
+    ax[1].fill_between(f, rat_quantiles[0], rat_quantiles[4], color=colors, alpha=0.25)
+    ax[1].fill_between(f, rat_quantiles[1], rat_quantiles[3], color=colors, alpha=0.5)
+    ax[1].plot(f, rat_quantiles[2], color="black", lw=1)
+    ax[1].update(
+        {
+            "xscale": "log",
+            "yscale": "linear",
+            "xlabel": "Frequency",
+            "ylabel": r"${P_{\text{approx}}}/{P_{\text{true}}}$",
+        }
+    )
+
+    legend_elements = [
+        Line2D([0], [0], color=colors, lw=2, label="Median"),
+        Line2D([0], [0], color="k", lw=1, ls="--", label=r"$f_\mathrm{min}, f_\mathrm{max}$"),
+        Patch(facecolor=colors, edgecolor=colors, alpha=0.25, label="95%"),
+        Patch(facecolor=colors, edgecolor=colors, alpha=0.5, label="68%"),
+    ]
+    ax[1].legend(handles=legend_elements, ncol=3, bbox_to_anchor=(1.0, -0.4))
+    fig.align_ylabels(ax)
+    return fig, ax
+
+
+def SHO_power_spectrum(f, A, f0):
+    """Power spectrum of a stochastic harmonic oscillator.
+
+    Parameters
+    ----------
+    f : jax.Array
+        Frequency array.
+    A : float
+        Amplitude.
+    f0 : float
+        Characteristic frequency.
+    """
+    P = A / (1 + jnp.power((f / f0), 4))
+
+    return P
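# --- Example (reviewer sketch, not part of the patch) -------------------------
# The idea behind the approximation computed by the helper defined next: pick n
# grid frequencies f_j, evaluate the target PSD there, and solve a linear
# system so that a sum of SHO spectra matches the target exactly at the grid
# points. A minimal standalone demonstration for one bending power law
# (parameter values are illustrative):
import jax.numpy as jnp

grid = jnp.geomspace(1e-3, 10.0, 20)                                 # f_j
target = jnp.power(grid / 0.05, -0.3) / (1 + jnp.power(grid / 0.05, 2.9 - 0.3))
target = target / target[0]                                          # normalise to first point
basis = 1 / (1 + jnp.power(jnp.atleast_2d(grid).T / grid, 4))        # SHO(f_i; f0=f_j)
coeffs = jnp.linalg.solve(basis, target)                             # exact match on the grid
# the SHO mixture then interpolates smoothly between the grid points
f = jnp.geomspace(1e-3, 10.0, 500)
approx = (coeffs[:, None] / (1 + jnp.power(f[None, :] / grid[:, None], 4))).sum(axis=0)
# ------------------------------------------------------------------------------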
+
+
+def _get_coefficients_approximation(
+    kernel_type, kernel_params, f_min, f_max, n_approx_components=20, approximate_with="SHO"
+):
+    """
+    Get the coefficients of the approximation of the power law kernel
+    with a sum of SHO kernels or a sum of DRW+SHO kernels.
+
+    Parameters
+    ----------
+    kernel_type : str
+        The type of kernel to be used for the Gaussian Process.
+        Only designed for the following power spectra: ["PowL", "DoubPowL"].
+    kernel_params : dict
+        Dictionary containing the parameters for the selected kernel.
+    f_min : float
+        The minimum frequency of the approximation grid.
+        Should be the lowest frequency of the power spectrum.
+    f_max : float
+        The maximum frequency of the approximation grid.
+        Should be the highest frequency of the power spectrum.
+    n_approx_components : int
+        The number of components to use to approximate the power law.
+    approximate_with : str
+        The type of kernel to use to approximate the power law power spectra.
+    """
+    # grid of frequencies for the approximation
+    spectral_points = jnp.geomspace(f_min, f_max, n_approx_components)
+    # build the matrix of the approximation
+    if approximate_with == "SHO":
+        spectral_matrix = 1 / (
+            1 + jnp.power(jnp.atleast_2d(spectral_points).T / spectral_points, 4)
+        )
+    else:
+        raise NotImplementedError(f"Approximation {approximate_with} not implemented")
+
+    # get the psd values and normalize them to the first value
+    psd_values = _psd_model(kernel_type, kernel_params)(spectral_points)
+    psd_values /= psd_values[0]
+    # compute the coefficients of the approximation
+    spectral_coefficients = jnp.linalg.solve(spectral_matrix, psd_values)
+    return spectral_points, spectral_coefficients
+
+
+def get_psd_approx_samples(
+    f, kernel_type, kernel_params, f_min, f_max, n_approx_components=20, approximate_with="SHO"
+):
+    """Get the true PSD model and the approximated PSD using the SHO decomposition.
+
+    Parameters
+    ----------
+    f : jax.Array
+        Frequency array.
+    kernel_type : str
+        The type of kernel to be used for the Gaussian Process.
+        Only designed for the following power spectra: ["PowL", "DoubPowL"].
+    kernel_params : dict
+        Dictionary containing the parameters for the selected kernel.
+    f_min : float
+        The minimum frequency of the approximation grid.
+    f_max : float
+        The maximum frequency of the approximation grid.
+    n_approx_components : int
+        The number of components to use to approximate the power law,
+        must be at least 2. Default is 20.
+    approximate_with : str
+        The type of kernel to use to approximate the power law power spectra.
+        Default is "SHO".
+    """
+    f_c, a = _get_coefficients_approximation(
+        kernel_type,
+        kernel_params,
+        f_min,
+        f_max,
+        n_approx_components=n_approx_components,
+        approximate_with=approximate_with,
+    )
+    psd_SHO = SHO_power_spectrum(f, a[..., None], f_c[..., None]).sum(axis=0)
 
-def get_kernel(kernel_type, kernel_params):
+    psd_model = _psd_model(kernel_type, kernel_params)(f)
+    psd_model /= psd_model[..., 0, None]
+    return psd_model, psd_SHO
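# --- Example (reviewer sketch, not part of the patch) -------------------------
# Evaluate the exact and approximate PSD for a single parameter set and check
# their agreement; parameter values are illustrative.
import jax.numpy as jnp

f = jnp.geomspace(1e-4, 10.0, 500)
pars = {"alpha_1": 0.3, "f_bend": 0.05, "alpha_2": 2.9, "variance": 0.2}
psd_true, psd_sho = get_psd_approx_samples(f, "PowL", pars, 1e-4, 10.0)
max_ratio_err = jnp.max(jnp.abs(psd_sho / psd_true - 1))  # should be small in band
# ------------------------------------------------------------------------------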
+
+
+def _approximate_powerlaw(
+    kernel_type, kernel_params, f_min, f_max, n_approx_components=20, approximate_with="SHO"
+):
+    """
+    Approximate the power law kernel with a sum of SHO kernels or a sum of DRW+SHO kernels.
+
+    Parameters
+    ----------
+    kernel_type : str
+        The type of kernel to be used for the Gaussian Process.
+        Only designed for the following power spectra: ["PowL", "DoubPowL"].
+    kernel_params : dict
+        Dictionary containing the parameters for the selected kernel.
+    f_min : float
+        The minimum frequency of the approximation grid.
+        Should be the lowest frequency of the power spectrum.
+    f_max : float
+        The maximum frequency of the approximation grid.
+        Should be the highest frequency of the power spectrum.
+    n_approx_components : int
+        The number of components to use to approximate the power law.
+    approximate_with : str
+        The type of kernel to use to approximate the power law power spectra.
+        Default is "SHO".
+    """
+    spectral_points, spectral_coefficients = _get_coefficients_approximation(
+        kernel_type,
+        kernel_params,
+        f_min,
+        f_max,
+        n_approx_components=n_approx_components,
+        approximate_with=approximate_with,
+    )
+
+    if approximate_with == "SHO":
+        amplitudes = (
+            spectral_coefficients
+            * spectral_points
+            * kernel_params["variance"]
+            / jnp.sum(spectral_coefficients * spectral_points)
+        )
+
+        kernel = amplitudes[0] * kernels.quasisep.SHO(
+            quality=1 / jnp.sqrt(2), omega=2 * jnp.pi * spectral_points[0]
+        )
+        for j in range(1, n_approx_components):
+            kernel += amplitudes[j] * kernels.quasisep.SHO(
+                quality=1 / jnp.sqrt(2), omega=2 * jnp.pi * spectral_points[j]
+            )
+        return kernel
+    else:
+        raise NotImplementedError(f"Approximation {approximate_with} not implemented")
+
+
+def _psd_model(kernel_type, kernel_params):
+    """Return the power spectrum model for the given kernel type and parameters.
+
+    Parameters
+    ----------
+    kernel_type : str
+        The type of kernel to be used for the Gaussian Process.
+        Only designed for the following power spectra: ["PowL", "DoubPowL"].
+    kernel_params : dict
+        Dictionary containing the parameters for the selected kernel.
+    """
+    if kernel_type == "PowL":
+        return lambda f: jnp.power(f / kernel_params["f_bend"], -kernel_params["alpha_1"]) / (
+            1
+            + jnp.power(
+                f / kernel_params["f_bend"], kernel_params["alpha_2"] - kernel_params["alpha_1"]
+            )
+        )
+    elif kernel_type == "DoubPowL":
+        return (
+            lambda f: jnp.power(f / kernel_params["f_bend_1"], -kernel_params["alpha_1"])
+            / (
+                1
+                + jnp.power(
+                    f / kernel_params["f_bend_1"],
+                    kernel_params["alpha_2"] - kernel_params["alpha_1"],
+                )
+            )
+            / (
+                1
+                + jnp.power(
+                    f / kernel_params["f_bend_2"],
+                    kernel_params["alpha_3"] - kernel_params["alpha_2"],
+                )
+            )
+        )
+    else:
+        raise ValueError("PSD type not implemented")
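# --- Example (reviewer sketch, not part of the patch) -------------------------
# _psd_model returns a callable; for "PowL" the spectrum bends from a slope of
# -alpha_1 below f_bend to -alpha_2 above it. A quick check of the asymptotic
# log-log slopes (parameter values are illustrative):
import jax.numpy as jnp

f = jnp.geomspace(1e-4, 10.0, 256)
psd = _psd_model("PowL", {"alpha_1": 0.3, "f_bend": 0.05, "alpha_2": 2.9})(f)
slopes = jnp.diff(jnp.log(psd)) / jnp.diff(jnp.log(f))
# slopes[0] is close to -0.3 and slopes[-1] is close to -2.9
# ------------------------------------------------------------------------------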
+
+
+def get_kernel(
+    kernel_type,
+    kernel_params,
+    f_min=0,
+    f_max=0,
+    n_approx_components=20,
+    approximate_with="SHO",
+    S_low=20,
+    S_high=20,
+):
     """
     Function for producing the kernel for the Gaussian Process.
     Returns the selected Tinygp kernel for the given parameters.
@@ -50,11 +650,22 @@ def get_kernel(kernel_type, kernel_params):
     kernel_type: string
         The type of kernel to be used for the Gaussian Process
         To be selected from the kernels already implemented:
-        ["RN", "QPO", "QPO_plus_RN"]
+        ["RN", "QPO", "QPO_plus_RN", "PowL", "DoubPowL"]
 
     kernel_params: dict
         Dictionary containing the parameters for the kernel
         Should contain the parameters for the selected kernel
 
+    f_min: float
+        The minimum frequency of the time series.
+
+    f_max: float
+        The maximum frequency of the time series.
+
+    n_approx_components: int
+        The number of components to use to approximate the power law,
+        must be at least 2, default 20
+
+    approximate_with: string
+        The type of kernel to use to approximate the power law power spectra
+        Default is "SHO"
+
+    S_low, S_high: int
+        Low and high frequency extension factors for the approximation grid,
+        default 20
     """
 
     if not jax_avail:
@@ -86,6 +697,19 @@ def get_kernel(kernel_type, kernel_params):
             d=2 * jnp.pi * kernel_params["freq"],
         )
         return kernel
+    elif kernel_type == "PowL" or kernel_type == "DoubPowL":
+        if n_approx_components < 2:
+            raise ValueError("Number of approximation components must be at least 2")
+
+        kernel = _approximate_powerlaw(
+            kernel_type,
+            kernel_params,
+            f_min=f_min / S_low,
+            f_max=f_max * S_high,
+            n_approx_components=n_approx_components,
+            approximate_with=approximate_with,
+        )
+        return kernel
     else:
         raise ValueError("Kernel type not implemented")
 
@@ -365,6 +989,10 @@ def _get_kernel_params(kernel_type):
         return ["log_arn", "log_crn", "log_aqpo", "log_cqpo", "log_freq"]
     elif kernel_type == "QPO":
         return ["log_aqpo", "log_cqpo", "log_freq"]
+    elif kernel_type == "PowL":
+        return ["alpha_1", "log_f_bend", "alpha_2", "variance"]
+    elif kernel_type == "DoubPowL":
+        return ["alpha_1", "log_f_bend_1", "alpha_2", "log_f_bend_2", "alpha_3", "variance"]
     else:
         raise ValueError("Kernel type not implemented")
 
@@ -395,7 +1023,7 @@
         raise ValueError("Mean type not implemented")
 
 
-def get_gp_params(kernel_type, mean_type):
+def get_gp_params(kernel_type, mean_type, scale_errors=False, log_transform=False):
     """
     Generates a list of the parameters for the GP model based on the kernel and mean type.
     To be used to set the order of the parameters for `get_prior` and `get_likelihood` functions.
@@ -404,12 +1032,18 @@ def get_gp_params(kernel_type, mean_type):
     ----------
     kernel_type: string
         The type of kernel to be used for the Gaussian Process model:
-        ["RN", "QPO", "QPO_plus_RN"]
+        ["RN", "QPO", "QPO_plus_RN", "PowL", "DoubPowL"]
 
     mean_type: string
         The type of mean to be used for the Gaussian Process model:
         ["gaussian", "exponential", "constant", "skew_gaussian", "skew_exponential", "fred"]
 
+    scale_errors: bool, default False
+        Whether to include a scale parameter on the errors in the GP model
+
+    log_transform: bool, default False
+        Whether to take the log of the data to make it normally distributed.
+        This adds a parameter to model a shifted lognormal distribution, and
+        renames the mean parameter "log_A" to "A", as the mean can be negative.
+
     Returns
     -------
@@ -423,6 +1057,12 @@ def get_gp_params(kernel_type, mean_type):
     kernel_params = _get_kernel_params(kernel_type)
     mean_params = _get_mean_params(mean_type)
     kernel_params.extend(mean_params)
+    if scale_errors:
+        kernel_params.append("scale_err")
+    if log_transform:
+        kernel_params.append("log_shift")
+        if "log_A" in kernel_params:
+            kernel_params[kernel_params.index("log_A")] = "A"
     return kernel_params
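# --- Example (reviewer sketch, not part of the patch) -------------------------
# Effect of the new flags on the parameter list: scale_errors appends
# "scale_err", and log_transform appends "log_shift" and renames the mean
# amplitude "log_A" to "A" (the mean of the logged data can be negative).
from stingray.modeling.gpmodeling import get_gp_params

get_gp_params("PowL", "constant")
# -> ["alpha_1", "log_f_bend", "alpha_2", "variance", "log_A"]
get_gp_params("PowL", "constant", scale_errors=True, log_transform=True)
# -> ["alpha_1", "log_f_bend", "alpha_2", "variance", "A", "scale_err", "log_shift"]
# ------------------------------------------------------------------------------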
@@ -497,7 +1137,18 @@ def prior_model():
     return prior_model
 
 
-def get_log_likelihood(params_list, kernel_type, mean_type, times, counts, **kwargs):
+def get_log_likelihood(
+    params_list,
+    kernel_type,
+    mean_type,
+    times,
+    counts,
+    counts_err=None,
+    S_low=20,
+    S_high=20,
+    n_approx_components=20,
+    approximate_with="SHO",
+):
     """
     A log likelihood generator function based on given values.
     Makes a jaxns specific log likelihood function which takes in the
@@ -516,7 +1167,7 @@
 
     kernel_type:
         The type of kernel to be used in the model:
-        ["RN", "QPO", "QPO_plus_RN"]
+        ["RN", "QPO", "QPO_plus_RN", "PowL", "DoubPowL"]
 
     mean_type:
         The type of mean to be used in the model:
@@ -529,6 +1180,20 @@
     counts: np.array or jnp.array
         The photon counts array of the lightcurve
 
+    counts_err: np.array or jnp.array
+        The photon counts error array of the lightcurve
+
+    S_low, S_high: int
+        Low and high frequency extension factors for the power law
+        approximation grid, default 20
+
+    n_approx_components: int
+        The number of components to use to approximate the power law,
+        must be at least 2, default 20
+
+    approximate_with: string
+        The type of kernel to use to approximate the power law power spectra
+        Default is "SHO"
+
+    If ``params_list`` contains "log_shift", the data are modelled with a
+    shifted lognormal distribution; if it contains "scale_err", the error
+    bars are scaled by a free parameter.
+
     Returns
     -------
     The Jaxns specific log likelihood function.
@@ -539,6 +1204,7 @@
     if not can_make_gp:
         raise ImportError("Tinygp is required to make the GP model.")
 
+    f_min, f_max = 1 / (times[-1] - times[0]), 0.5 / jnp.min(jnp.diff(times))
 
     @jit
     def likelihood_model(*args):
@@ -548,10 +1214,38 @@ def likelihood_model(*args):
                 param_dict[params[4:]] = jnp.exp(args[i])
             else:
                 param_dict[params] = args[i]
-        kernel = get_kernel(kernel_type=kernel_type, kernel_params=param_dict)
+
+        kernel = get_kernel(
+            kernel_type=kernel_type,
+            kernel_params=param_dict,
+            f_min=f_min,
+            f_max=f_max,
+            n_approx_components=n_approx_components,
+            approximate_with=approximate_with,
+            S_low=S_low,
+            S_high=S_high,
+        )
         mean = get_mean(mean_type=mean_type, mean_params=param_dict)
-        gp = GaussianProcess(kernel, times, mean_value=mean(times))
-        return gp.log_probability(counts)
+        if "shift" in param_dict.keys():
+            # fit the log of the shifted data, propagating the errors accordingly
+            x = jnp.log(counts - param_dict["shift"])
+            if counts_err is not None:
+                xerr = jnp.divide(counts_err, counts - param_dict["shift"])
+        else:
+            x = counts
+            xerr = counts_err
+
+        if counts_err is None:
+            gp = GaussianProcess(kernel, times, mean_value=mean(times))
+        elif counts_err is not None and "scale_err" in param_dict.keys():
+            gp = GaussianProcess(
+                kernel,
+                times,
+                mean_value=mean(times),
+                diag=param_dict["scale_err"] * jnp.square(xerr),
+            )
+        else:
+            gp = GaussianProcess(kernel, times, mean_value=mean(times), diag=jnp.square(xerr))
+        return gp.log_probability(x)
 
     return likelihood_model
 
@@ -617,14 +1311,16 @@ def sample(self, prior_model=None, likelihood_model=None, max_samples=1e4, num_l
         nsmodel = Model(prior_model=self.prior_model, log_likelihood=self.log_likelihood_model)
         nsmodel.sanity_check(random.PRNGKey(10), S=100)
 
-        self.exact_ns = ExactNestedSampler(
-            nsmodel, num_live_points=num_live_points, max_samples=max_samples
+        self.exact_ns = DefaultNestedSampler(
+            nsmodel, num_live_points=num_live_points, max_samples=max_samples, verbose=True
         )
 
         termination_reason, state = self.exact_ns(
-            random.PRNGKey(42), term_cond=TerminationCondition(live_evidence_frac=1e-4)
+            random.PRNGKey(42), term_cond=TerminationCondition()
        )
-        self.results = self.exact_ns.to_results(state, termination_reason)
+        self.results = self.exact_ns.to_results(termination_reason, state)
         print("Simulation Complete")
 
     def get_evidence(self):
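# --- Example (reviewer sketch, not part of the patch) -------------------------
# End-to-end use of the updated API with the DefaultNestedSampler: build the
# parameter list, priors, prior model and likelihood, then sample. The prior
# choices and the synthetic light curve are illustrative assumptions.
import jax.numpy as jnp
import numpy as np
import tensorflow_probability.substrates.jax as tfp
from stingray import Lightcurve
from stingray.modeling.gpmodeling import (
    get_gp_params, get_prior, get_log_likelihood, GPResult,
)

tfpd = tfp.distributions
times = np.linspace(0.0, 1.0, 64)
counts = np.random.default_rng(1).poisson(50.0, 64).astype(float)
lc = Lightcurve(times, counts, dt=times[1] - times[0], skip_checks=True)

params = get_gp_params("PowL", "constant")
priors_dict = {
    "alpha_1": tfpd.Uniform(low=0.0, high=1.25),
    "log_f_bend": tfpd.Uniform(low=jnp.log(5.5e-3), high=jnp.log(0.5)),
    "alpha_2": tfpd.Uniform(low=1.5, high=4.0),
    "variance": tfpd.LogNormal(-0.2, 1.0),
    "log_A": tfpd.Normal(0.0, 2.0),
}
prior_model = get_prior(params, priors_dict)
likelihood = get_log_likelihood(params, "PowL", "constant", lc.time, lc.counts)

gpresult = GPResult(lc)
gpresult.sample(prior_model=prior_model, likelihood_model=likelihood)
print(gpresult.get_evidence())
# ------------------------------------------------------------------------------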
diff --git a/stingray/modeling/tests/test_gpmodeling.py b/stingray/modeling/tests/test_gpmodeling.py
index 56a2c9a2b..76f34fc76 100644
--- a/stingray/modeling/tests/test_gpmodeling.py
+++ b/stingray/modeling/tests/test_gpmodeling.py
@@ -3,6 +3,21 @@
 import numpy as np
 import matplotlib.pyplot as plt
 
+from stingray.modeling.gpmodeling import (
+    get_kernel,
+    get_mean,
+    get_gp_params,
+    _psd_model,
+    get_psd_and_approx,
+    run_prior_checks,
+    _get_coefficients_approximation,
+    get_prior,
+    get_log_likelihood,
+    GPResult,
+    get_priors_samples,
+)
+from stingray import Lightcurve
+
 try:
     import jax
     import jax.numpy as jnp
@@ -22,9 +37,6 @@
 except ImportError:
     _HAS_TINYGP = False
 
-from stingray.modeling.gpmodeling import get_kernel, get_mean, get_gp_params
-from stingray.modeling.gpmodeling import get_prior, get_log_likelihood, GPResult
-from stingray import Lightcurve
 
 try:
     import tensorflow_probability.substrates.jax as tfp
@@ -46,7 +58,7 @@ def clear_all_figs():
         plt.close(fig)
 
 
-@pytest.mark.xfail
 @pytest.mark.skipif(not _HAS_TINYGP, reason="tinygp not installed")
 class Testget_kernel(object):
     def setup_class(self):
@@ -87,6 +99,52 @@ def test_get_kernel_qpo(self):
             kernel_qpo(self.x, jnp.array([0.0])) == kernel_qpo_test(self.x, jnp.array([0.0]))
         ).all()
 
+    def test_get_kernel_powL(self):
+        t = jnp.arange(0, 100.05, 0.05)
+        f_min, f_max = 1 / (t[-1] - t[0]) / 20, 1 / (2 * (t[1] - t[0])) * 20
+        n_approx_components = 20
+        alpha_1, f_1, alpha_2, variance = 0.3, 0.05, 2.9, 0.2
+
+        kernel_params = {
+            "alpha_1": alpha_1,
+            "f_bend": f_1,
+            "alpha_2": alpha_2,
+            "variance": variance,
+        }
+
+        # build the reference kernel by hand, mirroring _approximate_powerlaw
+        spectral_points = jnp.geomspace(f_min, f_max, n_approx_components)
+        spectral_matrix = 1 / (
+            1 + jnp.power(jnp.atleast_2d(spectral_points).T / spectral_points, 4)
+        )
+        psd_values = _psd_model("PowL", kernel_params)(spectral_points)
+        psd_values /= psd_values[0]
+        spectral_coefficients = jnp.linalg.solve(spectral_matrix, psd_values)
+        amplitudes = (
+            spectral_coefficients
+            * spectral_points
+            * kernel_params["variance"]
+            / jnp.sum(spectral_coefficients * spectral_points)
+        )
+        kernel = amplitudes[0] * kernels.quasisep.SHO(
+            quality=1 / jnp.sqrt(2), omega=2 * jnp.pi * spectral_points[0]
+        )
+        for j in range(1, n_approx_components):
+            kernel += amplitudes[j] * kernels.quasisep.SHO(
+                quality=1 / jnp.sqrt(2), omega=2 * jnp.pi * spectral_points[j]
+            )
+
+        kernel_PowL_test = get_kernel(
+            "PowL",
+            kernel_params,
+            f_min=1 / (t[-1] - t[0]),
+            f_max=0.5 / (t[1] - t[0]),
+            n_approx_components=n_approx_components,
+        )
+        assert (
+            kernel(self.x, jnp.array([0.0])) == kernel_PowL_test(self.x, jnp.array([0.0]))
+        ).all()
+
     def test_value_error(self):
         with pytest.raises(ValueError, match="Kernel type not implemented"):
             get_kernel("periodic", self.kernel_params)
@@ -235,6 +293,129 @@ def test_get_qpo(self):
             "log_sig",
         ]
 
+    def test_get_params_PowL(self):
+        assert get_gp_params("PowL", "constant") == [
+            "alpha_1",
+            "log_f_bend",
+            "alpha_2",
+            "variance",
+            "log_A",
+        ]
+        assert get_gp_params("PowL", "constant", scale_errors=True) == [
+            "alpha_1",
+            "log_f_bend",
+            "alpha_2",
+            "variance",
+            "log_A",
+            "scale_err",
+        ]
+
+    def test_get_params_DoubPowL(self):
+        assert get_gp_params("DoubPowL", "constant") == [
+            "alpha_1",
+            "log_f_bend_1",
+            "alpha_2",
+            "log_f_bend_2",
+            "alpha_3",
+            "variance",
+            "log_A",
+        ]
+        assert get_gp_params("DoubPowL", "constant", scale_errors=True) == [
+            "alpha_1",
+            "log_f_bend_1",
+            "alpha_2",
+            "log_f_bend_2",
+            "alpha_3",
+            "variance",
+            "log_A",
+            "scale_err",
+        ]
+
+
+@pytest.mark.skipif(
+    not (_HAS_TINYGP and _HAS_TFP and _HAS_JAXNS), reason="tinygp, tfp or jaxns not installed"
+)
+class TestPSDapprox(object):
+    def setup_class(self):
+        self.t = np.linspace(0, 1, 10)
+        self.y = np.array([2.4, 3.5, 4.5, 5.6, 6.7, 7.8, 8.9, 9.0, 10.1, 11.2])
+        self.yerr = np.array([0.1, 0.3, 0.2, 0.1, 0.2, 0.1, 0.3, 0.2, 0.1, 0.2])
+
+        self.kernel_type = "PowL"
+        self.mean_type = "constant"
+        self.kernel_params = get_gp_params(
+            self.kernel_type, self.mean_type, scale_errors=True, log_transform=True
+        )
+        self.loglike = get_log_likelihood(
+            self.kernel_params, self.kernel_type, self.mean_type, self.t, self.y, self.yerr
+        )
+
+        min_f_1, max_f_1 = 5.5e-3, 0.5
+        muL = -0.2
+
+        self.priors = [
+            tfpd.Uniform(low=0.0, high=1.25),
+            tfpd.Uniform(low=jnp.log(min_f_1), high=jnp.log(max_f_1)),
+            tfpd.Uniform(low=1.5, high=4),
+            tfpd.LogNormal(muL, 1.0),
+            tfpd.Normal(0.0, 2),
+            # pass the Gamma parameters as jnp arrays to keep dtypes concrete under jax
+            tfpd.Gamma(jnp.array(2.0), jnp.array(2.0)),
+            tfpd.Uniform(low=jnp.log(1e-6), high=jnp.log(0.99 * 3)),
+        ]
+
+    def test_get_prior_samples(self):
+        loglike = get_log_likelihood(
+            self.kernel_params, self.kernel_type, self.mean_type, self.t, self.y, self.yerr
+        )
+
+        prior_samples = get_priors_samples(
+            jax.random.PRNGKey(0), self.kernel_params, self.priors, loglike, 10
+        )
+        assert len(prior_samples) == len(self.kernel_params)
+        for key in self.kernel_params:
+            assert prior_samples[key].shape == (10,)
+
+    def test_get_psd_and_approx(self):
+        f0, fM = 5.5e-3 / 20, 0.5 * 20
+        loglike = get_log_likelihood(
+            self.kernel_params, self.kernel_type, self.mean_type, self.t, self.y, self.yerr
+        )
+
+        prior_samples = get_priors_samples(
+            jax.random.PRNGKey(0), self.kernel_params, self.priors, loglike, 10
+        )
+
+        f, psd_models, psd_approx = get_psd_and_approx(
+            self.kernel_type, self.kernel_params, prior_samples, f0, fM, n_frequencies=200
+        )
+        assert len(f) == 200
+        assert psd_models.shape == (10, 200)
+        assert psd_approx.shape == (10, 200)
+
+    def test_run_prior_checks(self):
+        loglike = get_log_likelihood(
+            self.kernel_params, self.kernel_type, self.mean_type, self.t, self.y, self.yerr
+        )
+
+        fig1, fig2 = run_prior_checks(
+            self.kernel_type, self.kernel_params, self.priors, loglike, 5.5e-2, 0.5
+        )
+        assert plt.fignum_exists(1)
+        assert plt.fignum_exists(2)
+
+    def test__get_coefficients_approximation(self):
+        kernel_params = {"alpha_1": 0.3, "f_bend": 0.05, "alpha_2": 3.5, "variance": 0.2}
+        f_min, f_max = 1 / (self.t[-1] - self.t[0]) / 20, 1 / (2 * np.min(np.diff(self.t))) * 20
+        spectral_points, spectral_coefs = _get_coefficients_approximation(
+            self.kernel_type, kernel_params, f_min, f_max, 25
+        )
+        assert spectral_points.shape == (25,)
+        assert spectral_coefs.shape == (25,)
+
 
 @pytest.mark.xfail
 @pytest.mark.skipif(