diff --git a/src/discovery/deterministic.py b/src/discovery/deterministic.py index cdac215..ac969ff 100644 --- a/src/discovery/deterministic.py +++ b/src/discovery/deterministic.py @@ -277,4 +277,245 @@ def fourier_binary(f, df, mintoa, pos, log10_h0, log10_f0, ra, sindec, cosinc, p if not pulsarterm: fourier_binary = functools.partial(fourier_binary, phi_psr=jnp.nan) - return fourier_binary \ No newline at end of file + return fourier_binary + + +def chromatic_exponential(psr, fref=1400.0): + r""" + Factory function for chromatic exponential delay model. + + Creates a delay function that models chromatic exponential events (e.g., glitches, + state changes) with frequency-dependent amplitude scaling. + + Parameters + ---------- + psr : Pulsar + Pulsar object containing toas and freqs attributes + fref : float, optional + Reference frequency in MHz for normalization (default: 1400.0) + + Returns + ------- + delay : callable + Function with signature (t0, log10_Amp, log10_tau, sign_param, alpha) -> ndarray + Computes chromatic exponential delay: + + .. math:: + + \Delta(t) = \pm A_0 \exp\left(-\frac{t - t_0}{\tau}\right) \left(\frac{f_{\text{ref}}}{f}\right)^\alpha H(t - t_0) + + where :math:`H(t - t_0)` is the Heaviside step function. + """ + toas, fnorm = matrix.jnparray(psr.toas / const.day), matrix.jnparray(fref / psr.freqs) + + def delay(t0, log10_Amp, log10_tau, sign_param, alpha): + r""" + Compute chromatic exponential delay. + + .. math:: + + \Delta(t) = \pm A_0 \exp\left(-\frac{t - t_0}{\tau}\right) \left(\frac{f_{\text{ref}}}{f}\right)^\alpha H(t - t_0) + + Parameters + ---------- + t0 : float + Event epoch :math:`t_0` in days (MJD) + log10_Amp : float + Log10 of amplitude :math:`A_0` in seconds + log10_tau : float + Log10 of exponential decay timescale :math:`\tau` in days + sign_param : float + Sign of the delay (positive or negative) + alpha : float + Chromatic index :math:`\alpha` (spectral index for frequency dependence) + + Returns + ------- + delay : ndarray + Array of timing residuals :math:`\Delta(t)` in seconds with shape matching psr.toas + """ + return jnp.sign(sign_param) * 10**log10_Amp * jnp.exp(- (toas - t0) / 10**log10_tau) * fnorm**alpha * jnp.heaviside(toas - t0, 1.0) + + delay.__name__ = "chromatic_exponential_delay" + return delay + + +def chromatic_annual(psr, fref=1400.0): + r""" + Factory function for chromatic annual delay model. + + Creates a delay function that models chromatic annual sinusoidal variations + (e.g., annual DM variations) with frequency-dependent amplitude scaling. + + Parameters + ---------- + psr : Pulsar + Pulsar object containing toas and freqs attributes + fref : float, optional + Reference frequency in MHz for normalization (default: 1400.0) + + Returns + ------- + delay : callable + Function with signature (log10_Amp, phase, alpha) -> ndarray + Computes chromatic annual delay: + + .. math:: + + \Delta(t) = A_0 \sin(2\pi f_{\text{yr}} t + \phi) \left(\frac{f_{\text{ref}}}{f}\right)^\alpha + + where :math:`f_{\text{yr}}` is the annual frequency (1/year). + """ + toas, fnorm = matrix.jnparray(psr.toas), matrix.jnparray(fref / psr.freqs) + + def delay(log10_Amp, phase, alpha): + r""" + Compute chromatic annual delay. + + .. math:: + + \Delta(t) = A_0 \sin(2\pi f_{\text{yr}} t + \phi) \left(\frac{f_{\text{ref}}}{f}\right)^\alpha + + Parameters + ---------- + log10_Amp : float + Log10 of amplitude :math:`A_0` in seconds + phase : float + Phase offset :math:`\phi` in radians + alpha : float + Chromatic index :math:`\alpha` (spectral index for frequency dependence) + + Returns + ------- + delay : ndarray + Array of timing residuals :math:`\Delta(t)` in seconds with shape matching psr.toas + """ + return 10**log10_Amp * jnp.sin(2*jnp.pi * const.fyr * toas + phase) * fnorm**alpha + + delay.__name__ = "chromatic_annual_delay" + return delay + + +def chromatic_gaussian(psr, fref=1400.0): + r""" + Factory function for chromatic Gaussian delay model. + + Creates a delay function that models chromatic Gaussian events (e.g., transient + DM variations, localized events) with frequency-dependent amplitude scaling. + + Parameters + ---------- + psr : Pulsar + Pulsar object containing toas and freqs attributes + fref : float, optional + Reference frequency in MHz for normalization (default: 1400.0) + + Returns + ------- + delay : callable + Function with signature (t0, log10_Amp, log10_sigma, sign_param, alpha) -> ndarray + Computes chromatic Gaussian delay: + + .. math:: + + \Delta(t) = \pm A_0 \exp\left(-\frac{(t - t_0)^2}{2\sigma^2}\right) \left(\frac{f_{\text{ref}}}{f}\right)^\alpha + """ + toas, fnorm = matrix.jnparray(psr.toas / const.day), matrix.jnparray(fref / psr.freqs) + + def delay(t0, log10_Amp, log10_sigma, sign_param, alpha): + r""" + Compute chromatic Gaussian delay. + + .. math:: + + \Delta(t) = \pm A_0 \exp\left(-\frac{(t - t_0)^2}{2\sigma^2}\right) \left(\frac{f_{\text{ref}}}{f}\right)^\alpha + + Parameters + ---------- + t0 : float + Event epoch :math:`t_0` in days (MJD) + log10_Amp : float + Log10 of amplitude :math:`A_0` in seconds + log10_sigma : float + Log10 of Gaussian width :math:`\sigma` in days + sign_param : float + Sign of the delay (positive or negative) + alpha : float + Chromatic index :math:`\alpha` (spectral index for frequency dependence) + + Returns + ------- + delay : ndarray + Array of timing residuals :math:`\Delta(t)` in seconds with shape matching psr.toas + """ + return jnp.sign(sign_param) * 10**log10_Amp * jnp.exp(-(toas - t0)**2 / (2 * (10**log10_sigma)**2)) * fnorm**alpha + + delay.__name__ = "chromatic_gaussian_delay" + return delay + + +def orthometric_shapiro(psr, binphase): + r""" + Factory function for orthometric Shapiro delay model. + + Creates a delay function that models Shapiro delay in binary pulsars using + the orthometric parameterization from Freire & Wex (2010). + + Parameters + ---------- + psr : Pulsar + Pulsar object containing toas attribute + binphase : array-like + Binary orbital phase :math:`\Phi` at each TOA (same shape as psr.toas) + + Returns + ------- + delay : callable + Function with signature (h3, stig) -> ndarray + Computes orthometric Shapiro delay (Equation 29 in Freire & Wex 2010): + + .. math:: + + \Delta_s = -\frac{2 h_3}{\zeta^3} \log(1 + \zeta^2 - 2 \zeta \sin\Phi) + + Raises + ------ + ValueError + If binphase shape does not match psr.toas shape + + References + ---------- + Freire, P. C. C., & Wex, N. (2010). The orthometric parametrization of the + Shapiro delay and an improved test of general relativity with binary pulsars. + MNRAS, 409(1), 199-212. + """ + toas, binphase = matrix.jnparray(psr.toas / const.day), matrix.jnparray(binphase) + if not np.shape(binphase) == np.shape(toas): + raise ValueError("Input binphase must have the same shape as toas") + + def delay(h3, stig): + r""" + Compute orthometric Shapiro delay. + + Implements Equation (29) from Freire & Wex (2010): + + .. math:: + + \Delta_s = -\frac{2 h_3}{\zeta^3} \log(1 + \zeta^2 - 2 \zeta \sin\Phi) + + Parameters + ---------- + h3 : float + Orthometric amplitude parameter :math:`h_3` (related to companion mass and inclination) + stig : float + Orthometric shape parameter :math:`\zeta` (related to orbital inclination) + + Returns + ------- + delay : ndarray + Shapiro timing delay :math:`\Delta_s` in seconds with shape matching psr.toas + """ + return -(2.0 * h3 / stig**3) * jnp.log(1 + stig**2 - 2 * stig * jnp.sin(binphase)) + + delay.__name__ = "orthometric_shapiro_delay" + return delay \ No newline at end of file diff --git a/src/discovery/prior.py b/src/discovery/prior.py index 25ce6e2..c3bb69a 100644 --- a/src/discovery/prior.py +++ b/src/discovery/prior.py @@ -24,6 +24,8 @@ def logpriorfunc(params): "(.*_)?red_noise_log10_A.*": [-20, -11], # deprecated "(.*_)?red_noise_gamma.*": [0, 7], # deprecated "(.*_)?red_noise_log10_fb": [-9, -6], + "(.*_)?sw_gp_log10_A": [-10, -2], + "(.*_)?sw_gp_gamma": [0, 4], "crn_log10_A.*": [-18, -11], "crn_gamma.*": [0, 7], "crn_log10_fb": [-9, -6], @@ -46,7 +48,22 @@ def logpriorfunc(params): "cw_log10_f0": [-9.0, -7.0], "cw_log10_h0": [-18.0, -11.0], "cw_phi_earth": [0., 2*np.pi], - "(.*_)?cw_phi_psr": [0., 2*np.pi] + "(.*_)?cw_phi_psr": [0., 2*np.pi], + "(.*_)?chrom_exp_t0": [50000, 65000], + "(.*_)?chrom_exp_log10_Amp": [-10, -4], + "(.*_)?chrom_exp_log10_tau": [0, 4], + "(.*_)?chrom_exp_sign_param": [-1, 1], + "(.*_)?chrom_exp_alpha": [0, 7], + "(.*_)?chrom_1yr_log10_Amp": [-10, -4], + "(.*_)?chrom_1yr_phase": [0, 2 * np.pi], + "(.*_)?chrom_1yr_alpha": [0, 7], + "(.*_)?chrom_gauss_t0": [50000, 65000], + "(.*_)?chrom_gauss_log10_Amp": [-10, -4], + "(.*_)?chrom_gauss_log10_sigma": [0, 4], + "(.*_)?chrom_gauss_sign_param": [-1, 1], + "(.*_)?chrom_gauss_alpha": [0, 7], + "(.*_)?h3": [0.0, 10**-5], + "(.*_)?stig": [0.0, 1.0] } def getprior_uniform(par, priordict={}): diff --git a/src/discovery/solar.py b/src/discovery/solar.py index 3764c39..bf4a438 100644 --- a/src/discovery/solar.py +++ b/src/discovery/solar.py @@ -1,5 +1,6 @@ import numpy as np -import functools +import inspect +import jax.numpy as jnp from . import const from . import matrix @@ -7,16 +8,6 @@ AU_light_sec = const.AU / const.c # 1 AU in light seconds AU_pc = const.AU / const.pc # 1 AU in parsecs (for DM normalization) -def make_solardmfourierbasis(psr, components, T=None): - """ - for a fourier basis for the solar wind gp. this assumes - that you include a deterministic solar wind delay, which samples - n_earth, and here these are stochastic fluctuations on top of that. - """ - f, df, fmat = ds.fourierbasis(psr, components, T) - shape = ds.make_solardm(psr) - return f, df, fmat * shape(1.)[:, None] - def theta_impact(psr): """From enterprise_extensions: use the attributes of an Enterprise Pulsar object to calculate the solar impact angle. @@ -47,23 +38,181 @@ def solardm(n_earth): return solardm +def _dm_solar_close(n_earth, r_earth): + return (n_earth * AU_light_sec * AU_pc / r_earth) + + +def _dm_solar(n_earth, theta, r_earth): + return ((np.pi - theta) * + (n_earth * AU_light_sec * AU_pc + / (r_earth * np.sin(theta)))) + +def dm_solar(n_earth, theta, r_earth): + """ + Calculate dispersion measure from a 1/r^2 solar wind density model. + + This function computes the integrated column density of free electrons + along the line of sight through the solar wind, assuming a spherically + symmetric 1/r^2 density profile. The calculation uses different approximations + depending on the solar elongation angle to avoid numerical issues. + + Parameters + ---------- + n_earth : float or ndarray + Solar wind proton/electron density at Earth's orbit (cm^-3). + theta : float or ndarray + Solar elongation angle between the Sun and the line of sight to the + pulsar (radians). theta = 0 corresponds to the Sun directly in the + line of sight, theta = pi/2 is at right angles. + r_earth : float or ndarray + Distance from Earth to Sun (light seconds). + + Returns + ------- + float or ndarray + Dispersion measure contribution from solar wind (pc cm^-3). + + Notes + ----- + For small elongation angles (pi - theta < 1e-5), the function uses a + close approach approximation to avoid numerical instabilities. Otherwise, + it uses the full integral formula from the 1/r^2 density model. + + References + ---------- + .. [1] You, X. P., Hobbs, G., Coles, W. A., et al. 2007, MNRAS, 378, 493 + "Dispersion measure variations and their effect on precision pulsar timing" + https://doi.org/10.1111/j.1365-2966.2007.11617.x + """ + return matrix.jnp.where(np.pi - theta >= 1e-5, + _dm_solar(n_earth, theta, r_earth), + _dm_solar_close(n_earth, r_earth)) + +def fourierbasis_solar_dm(psr, + components, + T=None): + """ + Construct a Fourier design matrix for solar wind dispersion measure variations. + + + Parameters + ---------- + psr : :class:`pulsar.Pulsar` + Discovery Pulsar object containing TOAs, frequencies, and solar system + ephemeris information. + components : int + Number of Fourier components to include in the model. + T : float, optional + Total timespan of the data in seconds. If None, will be computed from + the pulsar's TOAs. + + Returns + ------- + f : ndarray + Sampling frequencies for the Fourier components (Hz). + df : float + Frequency spacing between components (Hz). + F : ndarray + Solar wind DM-variation Fourier design matrix with shape (n_toas, 2*components), + where each TOA is weighted by the frequency-dependent solar wind DM delays. + + Notes + ----- + This function is adapted from enterprise_extensions. The design matrix is + constructed by first obtaining a standard Fourier basis, then scaling each + TOA by the solar wind DM signature computed from the 1/r^2 solar wind density + model. + + Examples + -------- + Create a Gaussian process model for solar wind DM variations using a powerlaw prior: + + >>> from discovery import solar, signals + >>> # Create a solar wind DM GP with 30 Fourier components + >>> gp = signals.makegp_fourier( + ... psr, + ... signals.powerlaw, + ... components=30, + ... fourierbasis=solar.fourierbasis_solar_dm, + ... name='solar_wind_dm' + ... ) + """ + + # Lazy import to avoid circular dependency + from .signals import fourierbasis + + # get base Fourier design matrix and frequencies + f, df, fmat = fourierbasis(psr, components, T) + theta, R_earth, _, _ = theta_impact(psr) + dm_sol_wind = dm_solar(1.0, theta, R_earth) + dt_DM = dm_sol_wind * 4.148808e3 / (psr.freqs**2) + + return f, df, fmat * dt_DM[:, None] + +def makegp_timedomain_solar_dm(psr, covariance, dt=1.0, common=[], name='timedomain_sw_gp'): + """ + Construct a time-domain Gaussian process for solar wind dispersion measure variations. + + This function builds a GP model for solar wind-induced DM variations by combining + a covariance function in the time domain with a model for the solar wind geometry. + The TOAs are quantized into time bins, and the GP is constructed using the time separations + between bins weighted by the solar wind DM signature. + + Parameters + ---------- + psr : :class:`pulsar.Pulsar` + Discovery Pulsar object containing TOAs, frequencies, and solar system + ephemeris information. + covariance : callable + Function that returns the time domain autocorrelation for a given + separation (tau). Must have signature `covariance(tau, *params)` where + tau is the time separation array. + dt : float, optional + Time bin width in seconds for quantizing TOAs. Default is 1.0. + common : list, optional + List of parameter names that should be treated as common (shared) across + pulsars rather than pulsar-specific. Default is []. + name : str, optional + Base name for the GP parameters. Used as prefix for parameter naming. + Default is 'timedomain_sw_gp'. + + Returns + ------- + :class:`matrix.VariableGP` + A matrix.VariableGP object containing the noise covariance matrix (as a + NoiseMatrix2D_var) and the design matrix (Umat) that maps the GP + to the TOA residuals via solar wind DM delays. See :class:`matrix.VariableGP` + for details. + + Notes + ----- + The design matrix Umat maps the low-rank GP (evaluated at quantized TOAs) + to the full TOA residuals, scaled by the frequency-dependent solar wind + DM signature. + """ + # Lazy import to avoid circular dependency + from .signals import quantize -def chromaticdelay(toas, freqs, t0, log10_Amp, log10_tau, idx): - toadays, invnormfreqs = toas / const.day, 1400.0 / freqs - dt = toadays - t0 + argspec = inspect.getfullargspec(covariance) + argmap = [(arg if arg in common else f'{name}_{arg}' if f'{name}_{arg}' in common else f'{psr.name}_{name}_{arg}') + for arg in argspec.args if arg not in ['tau']] - return matrix.jnp.where(dt > 0.0, -1.0 * (10**log10_Amp) * matrix.jnp.exp(-dt / (10**log10_tau)) * invnormfreqs**idx, 0.0) + # get solar wind ingredients + theta, R_earth, _, _ = theta_impact(psr) + dm_sol_wind = dm_solar(1.0, theta, R_earth) + dt_DM = dm_sol_wind * 4.148808e3 / (psr.freqs**2) -def make_chromaticdelay(psr, idx=None): - """From enterprise_extensions: pre-calculate chromatic exponential-dip delay.""" + bins = quantize(psr.toas, dt) + Umat = np.vstack([bins == i for i in range(bins.max() + 1)]).T.astype('d') + Umat = Umat * dt_DM[:, None] + toas = psr.toas @ Umat / Umat.sum(axis=0) - toadays, invnormfreqs = matrix.jnparray(psr.toas / const.day), matrix.jnparray(1400.0 / psr.freqs) + get_tmat = covariance + tau = jnp.abs(toas[:, jnp.newaxis] - toas[jnp.newaxis, :]) - def decay(t0, log10_Amp, log10_tau, idx): - dt = toadays - t0 - return matrix.jnp.where(dt > 0.0, -1.0 * (10**log10_Amp) * matrix.jnp.exp(-dt / (10**log10_tau)) * invnormfreqs**idx, 0.0) + def getphi(params): + return get_tmat(tau, *[params[arg] for arg in argmap]) + getphi.params = argmap - if idx is not None: - decay = functools.partial(decay, idx=idx) + return matrix.VariableGP(matrix.NoiseMatrix2D_var(getphi), Umat) - return decay diff --git a/tests/test_deterministic.py b/tests/test_deterministic.py new file mode 100644 index 0000000..8638f6f --- /dev/null +++ b/tests/test_deterministic.py @@ -0,0 +1,259 @@ +#!/usr/bin/env python3 +"""Tests for discovery.deterministic module""" + +import pytest +import numpy as np + +import jax +jax.config.update('jax_enable_x64', True) +import jax.numpy as jnp + +from discovery import deterministic + + +class MockPsr: + """Mock pulsar object for testing.""" + def __init__(self, toas=None, freqs=None): + # Default TOAs in MJ seconds (Modified Julian seconds, not MJD) + # 55000 MJD = 55000 * 86400 MJ seconds + self.toas = toas if toas is not None else np.array([55000.0 * 86400, 55001.0 * 86400, 55002.0 * 86400]) + self.freqs = freqs if freqs is not None else np.array([2800.0, 2800.0, 2800.0]) + + +class TestChromaticExponential: + """Tests for chromatic_exponential delay function.""" + + def test_chromatic_exponential_returns_callable(self): + """Test that chromatic_exponential returns a callable function.""" + psr = MockPsr() + delay_func = deterministic.chromatic_exponential(psr, fref=1400.0) + + assert callable(delay_func) + assert delay_func.__name__ == "chromatic_exponential_delay" + + def test_chromatic_exponential_output_shape(self): + """Test that output has correct shape.""" + psr = MockPsr() + delay_func = deterministic.chromatic_exponential(psr, fref=1400.0) + + result = delay_func(t0=55001.0, log10_Amp=-6.0, log10_tau=1.0, + sign_param=1.0, alpha=2.0) + + assert result.shape == psr.toas.shape + + def test_chromatic_exponential_heaviside(self): + """Test that delay is zero before t0 (Heaviside step).""" + psr = MockPsr() + delay_func = deterministic.chromatic_exponential(psr, fref=1400.0) + + result = delay_func(t0=55000.5, log10_Amp=-6.0, log10_tau=1.0, + sign_param=1.0, alpha=2.0) + + # Before t0 (first element), delay should be zero due to Heaviside + assert result[0] == 0.0 + # After t0, delay should be non-zero + assert result[1] != 0.0 + assert result[2] != 0.0 + + def test_chromatic_exponential_sign(self): + """Test that sign_param correctly changes sign of delay.""" + psr = MockPsr() + delay_func = deterministic.chromatic_exponential(psr, fref=1400.0) + + result_pos = delay_func(t0=55000.5, log10_Amp=-6.0, log10_tau=1.0, + sign_param=1.0, alpha=2.0) + result_neg = delay_func(t0=55000.5, log10_Amp=-6.0, log10_tau=1.0, + sign_param=-1.0, alpha=2.0) + + np.testing.assert_array_almost_equal(result_pos, -result_neg) + + def test_chromatic_exponential_frequency_dependence(self): + """Test frequency-dependent scaling with alpha parameter.""" + psr = MockPsr(toas=np.array([55001.0 * 86400, 55001.0 * 86400]), + freqs=np.array([2800.0, 1400.0])) + delay_func = deterministic.chromatic_exponential(psr, fref=1400.0) + + result = delay_func(t0=55000.0, log10_Amp=-6.0, log10_tau=1.0, + sign_param=1.0, alpha=2.0) + + # At alpha=2, delay should scale as (fref/f)^2 + # freq[0]=2800: (1400/2800)^2 = 0.25 + # freq[1]=1400: (1400/1400)^2 = 1.0 + # So result[1] should be 4x result[0] + # Check both values are finite first + assert np.isfinite(result[0]) + assert np.isfinite(result[1]) + ratio = result[1] / result[0] + assert np.abs(ratio - 4.0) < 1e-10 + + +class TestChromaticAnnual: + """Tests for chromatic_annual delay function.""" + + def test_chromatic_annual_returns_callable(self): + """Test that chromatic_annual returns a callable function.""" + psr = MockPsr() + delay_func = deterministic.chromatic_annual(psr, fref=1400.0) + + assert callable(delay_func) + assert delay_func.__name__ == "chromatic_annual_delay" + + def test_chromatic_annual_output_shape(self): + """Test that output has correct shape.""" + psr = MockPsr() + delay_func = deterministic.chromatic_annual(psr, fref=1400.0) + + result = delay_func(log10_Amp=-6.0, phase=0.0, alpha=2.0) + + assert result.shape == psr.toas.shape + + def test_chromatic_annual_sinusoidal(self): + """Test that delay follows sinusoidal pattern.""" + psr = MockPsr(toas=np.array([55000.0 * 86400]), + freqs=np.array([2800.0])) + delay_func = deterministic.chromatic_annual(psr, fref=1400.0) + + # At phase=0, should give sin(2*pi*f_yr*t) + # At phase=pi/2, should give sin(2*pi*f_yr*t + pi/2) = cos(2*pi*f_yr*t) + result_0 = delay_func(log10_Amp=-6.0, phase=0.0, alpha=2.0) + result_90 = delay_func(log10_Amp=-6.0, phase=np.pi/2, alpha=2.0) + + # These should be different (unless by chance the time gives sin=0) + # Just check they're both valid numbers + assert np.isfinite(result_0[0]) + assert np.isfinite(result_90[0]) + + def test_chromatic_annual_frequency_dependence(self): + """Test frequency-dependent scaling with alpha parameter.""" + psr = MockPsr(toas=np.array([55001.0 * 86400, 55001.0 * 86400]), + freqs=np.array([2800.0, 1400.0])) + delay_func = deterministic.chromatic_annual(psr, fref=1400.0) + + result = delay_func(log10_Amp=-6.0, phase=0.0, alpha=2.0) + + # At alpha=2, delay should scale as (fref/f)^2 + # freq[0]=2800: (1400/2800)^2 = 0.25 + # freq[1]=1400: (1400/1400)^2 = 1.0 + # So result[1] should be 4x result[0] + ratio = result[1] / result[0] + assert np.abs(ratio - 4.0) < 1e-10 + + +class TestChromaticGaussian: + """Tests for chromatic_gaussian delay function.""" + + def test_chromatic_gaussian_returns_callable(self): + """Test that chromatic_gaussian returns a callable function.""" + psr = MockPsr() + delay_func = deterministic.chromatic_gaussian(psr, fref=1400.0) + + assert callable(delay_func) + assert delay_func.__name__ == "chromatic_gaussian_delay" + + def test_chromatic_gaussian_output_shape(self): + """Test that output has correct shape.""" + psr = MockPsr() + delay_func = deterministic.chromatic_gaussian(psr, fref=1400.0) + + result = delay_func(t0=55001.0, log10_Amp=-6.0, log10_sigma=1.0, + sign_param=1.0, alpha=2.0) + + assert result.shape == psr.toas.shape + + def test_chromatic_gaussian_peak_at_t0(self): + """Test that Gaussian peaks at t0.""" + psr = MockPsr() + delay_func = deterministic.chromatic_gaussian(psr, fref=1400.0) + + result = delay_func(t0=55001.0, log10_Amp=-6.0, log10_sigma=0.5, + sign_param=1.0, alpha=2.0) + + # Peak should be at t0 (middle element) + assert np.abs(result[1]) > np.abs(result[0]) + assert np.abs(result[1]) > np.abs(result[2]) + + def test_chromatic_gaussian_sign(self): + """Test that sign_param correctly changes sign of delay.""" + psr = MockPsr() + delay_func = deterministic.chromatic_gaussian(psr, fref=1400.0) + + result_pos = delay_func(t0=55001.0, log10_Amp=-6.0, log10_sigma=1.0, + sign_param=1.0, alpha=2.0) + result_neg = delay_func(t0=55001.0, log10_Amp=-6.0, log10_sigma=1.0, + sign_param=-1.0, alpha=2.0) + + np.testing.assert_array_almost_equal(result_pos, -result_neg) + + def test_chromatic_gaussian_frequency_dependence(self): + """Test frequency-dependent scaling with alpha parameter.""" + psr = MockPsr(toas=np.array([55001.0 * 86400, 55001.0 * 86400]), + freqs=np.array([2800.0, 1400.0])) + delay_func = deterministic.chromatic_gaussian(psr, fref=1400.0) + + result = delay_func(t0=55001.0, log10_Amp=-6.0, log10_sigma=1.0, + sign_param=1.0, alpha=2.0) + + # At alpha=2, delay should scale as (fref/f)^2 + # freq[0]=2800: (1400/2800)^2 = 0.25 + # freq[1]=1400: (1400/1400)^2 = 1.0 + # So result[1] should be 4x result[0] + ratio = result[1] / result[0] + assert np.abs(ratio - 4.0) < 1e-10 + + +class TestOrthometricShapiro: + """Tests for orthometric_shapiro delay function.""" + + def test_orthometric_shapiro_returns_callable(self): + """Test that orthometric_shapiro returns a callable function.""" + psr = MockPsr() + binphase = np.array([0.0, np.pi/2, np.pi]) + delay_func = deterministic.orthometric_shapiro(psr, binphase) + + assert callable(delay_func) + assert delay_func.__name__ == "orthometric_shapiro_delay" + + def test_orthometric_shapiro_output_shape(self): + """Test that output has correct shape.""" + psr = MockPsr() + binphase = np.array([0.0, np.pi/2, np.pi]) + delay_func = deterministic.orthometric_shapiro(psr, binphase) + + result = delay_func(h3=1e-7, stig=0.9) + + assert result.shape == psr.toas.shape + + def test_orthometric_shapiro_binphase_shape_mismatch(self): + """Test that ValueError is raised when binphase shape doesn't match toas.""" + psr = MockPsr() + binphase = np.array([0.0, np.pi/2]) # Wrong shape + + with pytest.raises(ValueError, match="binphase must have the same shape"): + deterministic.orthometric_shapiro(psr, binphase) + + def test_orthometric_shapiro_values_finite(self): + """Test that Shapiro delay produces finite values.""" + psr = MockPsr() + binphase = np.array([0.0, np.pi/4, np.pi/2]) + delay_func = deterministic.orthometric_shapiro(psr, binphase) + + result = delay_func(h3=1e-7, stig=0.5) + + assert np.all(np.isfinite(result)) + + def test_orthometric_shapiro_equation_form(self): + """Test that delay follows the expected equation form from Freire & Wex 2010.""" + psr = MockPsr(toas=np.array([55000.0 * 86400]), freqs=np.array([2800.0])) + binphase = np.array([0.0]) # sin(0) = 0 + delay_func = deterministic.orthometric_shapiro(psr, binphase) + + h3 = 1e-7 + stig = 0.5 + + result = delay_func(h3=h3, stig=stig) + + # At binphase=0, sin(binphase)=0, so: + # Delta_s = -(2*h3/stig^3) * log(1 + stig^2) + expected = -(2.0 * h3 / stig**3) * np.log(1 + stig**2) + + np.testing.assert_almost_equal(result[0], expected, decimal=15) diff --git a/tests/test_solar.py b/tests/test_solar.py new file mode 100644 index 0000000..7587279 --- /dev/null +++ b/tests/test_solar.py @@ -0,0 +1,357 @@ +#!/usr/bin/env python3 +"""Tests for discovery.solar module""" + +import pytest +import numpy as np + +import jax +jax.config.update('jax_enable_x64', True) +import jax.numpy as jnp + +from discovery import solar, matrix, const + + +class MockPsr: + """Mock pulsar object for testing solar wind functions.""" + + def __init__(self, toas=None, freqs=None, name='J0000+0000'): + # Default TOAs in seconds (MJD * 86400) + self.toas = toas if toas is not None else np.array([ + 55000.0 * 86400, 55001.0 * 86400, 55002.0 * 86400 + ]) + self.freqs = freqs if freqs is not None else np.array([1400.0, 1400.0, 1400.0]) + self.name = name + + # Create mock solar system ephemeris + # planetssb shape: (n_toas, n_planets, 6) - we need Earth (index 2) + # sunssb shape: (n_toas, 6) + n_toas = len(self.toas) + + # Simple geometry: Earth at 1 AU in x-direction, Sun at origin + self.planetssb = np.zeros((n_toas, 10, 6)) + # Earth position at 1 AU in x-direction (in light-seconds) + au_light_sec = const.AU / const.c + self.planetssb[:, 2, 0] = au_light_sec # x-position + self.planetssb[:, 2, 1] = 0.0 # y-position + self.planetssb[:, 2, 2] = 0.0 # z-position + + # Sun at origin + self.sunssb = np.zeros((n_toas, 6)) + + # Pulsar position (unit vector pointing in z-direction) + self.pos = np.array([0.0, 0.0, 1.0]) + # Replicate for each TOA + self.pos_t = np.tile(self.pos, (n_toas, 1)) + + +class TestThetaImpact: + """Tests for theta_impact function.""" + + def test_theta_impact_returns_four_values(self): + """Test that theta_impact returns four values.""" + psr = MockPsr() + result = solar.theta_impact(psr) + + assert len(result) == 4 + theta, R_earth, b, z_earth = result + assert theta.shape == (len(psr.toas),) + assert R_earth.shape == (len(psr.toas),) + assert b.shape == (len(psr.toas),) + assert z_earth.shape == (len(psr.toas),) + + def test_theta_impact_perpendicular_geometry(self): + """Test theta_impact with perpendicular geometry (pulsar at 90 deg from Sun).""" + psr = MockPsr() + theta, R_earth, b, z_earth = solar.theta_impact(psr) + + # With pulsar in z-direction and Earth in x-direction from Sun, + # theta should be pi/2 (90 degrees) + np.testing.assert_allclose(theta, np.pi / 2, rtol=1e-6) + + # R_earth should be approximately 1 AU in light-seconds + au_light_sec = const.AU / const.c + np.testing.assert_allclose(R_earth, au_light_sec, rtol=1e-6) + + def test_theta_impact_positive_values(self): + """Test that R_earth and b are positive.""" + psr = MockPsr() + theta, R_earth, b, z_earth = solar.theta_impact(psr) + + assert np.all(R_earth > 0) + assert np.all(b >= 0) + assert np.all(theta >= 0) + assert np.all(theta <= np.pi) + + +class TestDmSolar: + """Tests for dm_solar and related functions.""" + + def test_dm_solar_returns_correct_shape_scalar(self): + """Test that dm_solar returns correct shape for scalar inputs.""" + n_earth = 5.0 + theta = np.pi / 2 + r_earth = const.AU / const.c + + result = solar.dm_solar(n_earth, theta, r_earth) + assert np.isscalar(result) or result.shape == () + + def test_dm_solar_positive(self): + """Test that dm_solar returns positive values for arrays.""" + n_earth = 5.0 + theta = np.linspace(0.1, np.pi - 0.1, 10) + r_earth = const.AU / const.c + + result = solar.dm_solar(n_earth, theta, r_earth) + assert result.shape == theta.shape + assert np.all(result > 0) + + def test_dm_solar_scales_with_density(self): + """Test that dm_solar scales linearly with electron density.""" + theta = np.pi / 2 + r_earth = const.AU / const.c + + dm1 = solar.dm_solar(5.0, theta, r_earth) + dm2 = solar.dm_solar(10.0, theta, r_earth) + + np.testing.assert_allclose(dm2 / dm1, 2.0, rtol=1e-10) + + def test_dm_solar_close_approach(self): + """Test dm_solar uses close approach approximation near pi.""" + n_earth = 5.0 + r_earth = const.AU / const.c + + # Test at threshold (pi - theta = 1e-5) + theta_close = np.pi - 1e-6 # Should use close approximation + theta_far = np.pi - 1e-4 # Should use regular formula + + result_close = solar.dm_solar(n_earth, theta_close, r_earth) + result_far = solar.dm_solar(n_earth, theta_far, r_earth) + + # Both should give positive finite values + assert np.isfinite(result_close) + assert np.isfinite(result_far) + assert result_close > 0 + assert result_far > 0 + + def test_dm_solar_continuous_at_boundary(self): + """Test that dm_solar is continuous at the boundary between approximations.""" + n_earth = 5.0 + r_earth = const.AU / const.c + + # Test near the boundary (pi - theta = 1e-5) + theta_just_below = np.pi - 1e-5 - 1e-7 + theta_just_above = np.pi - 1e-5 + 1e-7 + + result_below = solar.dm_solar(n_earth, theta_just_below, r_earth) + result_above = solar.dm_solar(n_earth, theta_just_above, r_earth) + + # Results should be very close (within 1%) + np.testing.assert_allclose(result_below, result_above, rtol=1e-2) + + +class TestMakeSolardm: + """Tests for make_solardm function.""" + + def test_make_solardm_returns_callable(self): + """Test that make_solardm returns a callable function.""" + psr = MockPsr() + solardm_func = solar.make_solardm(psr) + + assert callable(solardm_func) + + def test_make_solardm_output_shape(self): + """Test that the returned function produces correct output shape.""" + psr = MockPsr() + solardm_func = solar.make_solardm(psr) + + n_earth = 5.0 + result = solardm_func(n_earth) + + assert result.shape == psr.toas.shape + + def test_make_solardm_scales_linearly(self): + """Test that output scales linearly with n_earth.""" + psr = MockPsr() + solardm_func = solar.make_solardm(psr) + + result1 = solardm_func(5.0) + result2 = solardm_func(10.0) + + np.testing.assert_allclose(result2 / result1, 2.0, rtol=1e-10) + + def test_make_solardm_frequency_dependence(self): + """Test frequency-dependent scaling (proportional to 1/f^2).""" + psr = MockPsr( + freqs=np.array([1400.0, 2800.0, 700.0]) + ) + solardm_func = solar.make_solardm(psr) + + result = solardm_func(5.0) + + # Ratio of delays should scale as (f1/f2)^2 + # delay at 700 MHz should be 4x delay at 1400 MHz + # This is approximate due to geometry factors + assert result[2] > result[0] # Lower frequency has larger delay + + +class TestFourierbasisSolarDm: + """Tests for fourierbasis_solar_dm function.""" + + def test_fourierbasis_solar_dm_output_shapes(self): + """Test that fourierbasis_solar_dm returns three values with correct shapes.""" + psr = MockPsr() + components = 10 + + result = solar.fourierbasis_solar_dm(psr, components) + assert len(result) == 3 + + f, df, fmat = result + + # f should have length 2*components (repeated for sin/cos pairs) + assert len(f) == 2 * components + # df should be array of length 2*components + assert len(df) == 2 * components + # fmat should have shape (n_toas, 2*components) + assert fmat.shape == (len(psr.toas), 2 * components) + + +class TestMakegpTimedomainSolarDm: + """Tests for makegp_timedomain_solar_dm function.""" + + def test_makegp_timedomain_solar_dm_returns_variablegp(self): + """Test that function returns a VariableGP object.""" + psr = MockPsr() + + # Simple covariance function + def simple_cov(tau, log10_sigma, log10_ell): + sigma = 10**log10_sigma + ell = 10**log10_ell + return sigma**2 * jnp.exp(-tau / ell) + + result = solar.makegp_timedomain_solar_dm(psr, simple_cov, dt=86400.0) + + assert isinstance(result, matrix.VariableGP) + + def test_makegp_timedomain_solar_dm_with_dt(self): + """Test with custom time bin width.""" + psr = MockPsr() + + def simple_cov(tau, log10_sigma): + return 10**(2 * log10_sigma) * jnp.ones_like(tau) + + result = solar.makegp_timedomain_solar_dm(psr, simple_cov, dt=43200.0) # 12 hours + + assert isinstance(result, matrix.VariableGP) + + def test_makegp_timedomain_solar_dm_parameter_naming(self): + """Test that parameter names are generated correctly.""" + psr = MockPsr(name='J1234+5678') + + def simple_cov(tau, log10_sigma, log10_ell): + return 10**(2 * log10_sigma) * jnp.exp(-tau / 10**log10_ell) + + result = solar.makegp_timedomain_solar_dm(psr, simple_cov, name='sw_dm') + + # Check that the covariance matrix has params attribute + assert hasattr(result.Phi, 'params') + params = result.Phi.params + + # Should have pulsar-specific parameter names + assert any('J1234+5678' in p for p in params) + assert any('sw_dm' in p for p in params) + + def test_makegp_timedomain_solar_dm_common_parameters(self): + """Test with common (shared) parameters.""" + psr = MockPsr(name='J1234+5678') + + def simple_cov(tau, log10_sigma, log10_ell): + return 10**(2 * log10_sigma) * jnp.exp(-tau / 10**log10_ell) + + result = solar.makegp_timedomain_solar_dm( + psr, simple_cov, common=['log10_ell'], name='sw_dm' + ) + + params = result.Phi.params + + # log10_ell should be common (not pulsar-specific) + assert 'log10_ell' in params + # log10_sigma should be pulsar-specific + assert any('log10_sigma' in p and 'J1234+5678' in p for p in params) + + def test_makegp_timedomain_solar_dm_covariance_evaluation(self): + """Test that the covariance function can be evaluated.""" + psr = MockPsr(name='J1234+5678') + + def exponential_cov(tau, log10_sigma, log10_ell): + return 10**(2 * log10_sigma) * jnp.exp(-tau / 10**log10_ell) + + result = solar.makegp_timedomain_solar_dm(psr, exponential_cov, dt=86400.0) + + # Create test parameters + test_params = { + 'J1234+5678_timedomain_sw_gp_log10_sigma': -6.0, + 'J1234+5678_timedomain_sw_gp_log10_ell': 1.5, + } + + # Evaluate the covariance through the GP structure + # The Phi object should have getN method + cov_matrix = result.Phi.getN(test_params) + + # Check that we get a matrix + assert cov_matrix.ndim == 2 + # Should be square matrix + assert cov_matrix.shape[0] == cov_matrix.shape[1] + # Should be positive on diagonal + assert np.all(np.diag(cov_matrix) > 0) + + +class TestIntegration: + """Integration tests combining multiple functions.""" + + def test_solar_wind_pipeline(self): + """Test the complete solar wind modeling pipeline.""" + # Create a mock pulsar with multiple TOAs + psr = MockPsr( + toas=np.linspace(55000.0 * 86400, 55100.0 * 86400, 50), + freqs=np.full(50, 1400.0) + ) + + # Calculate solar geometry + theta, R_earth, b, z_earth = solar.theta_impact(psr) + assert theta.shape == (50,) + + # Calculate DM contribution + dm = solar.dm_solar(5.0, theta, R_earth) + assert dm.shape == (50,) + assert np.all(dm > 0) + + # Create solar DM function + solardm_func = solar.make_solardm(psr) + dm_delays = solardm_func(5.0) + assert dm_delays.shape == (50,) + + def test_gp_construction_pipeline(self): + """Test GP construction with solar wind geometry.""" + psr = MockPsr( + toas=np.linspace(55000.0 * 86400, 55010.0 * 86400, 20), + freqs=np.full(20, 1400.0) + ) + + # Create time-domain GP + def exponential_cov(tau, log10_sigma, log10_ell): + return 10**(2 * log10_sigma) * jnp.exp(-tau / 10**log10_ell) + + gp = solar.makegp_timedomain_solar_dm(psr, exponential_cov, dt=86400.0) + + # Check GP structure + assert isinstance(gp, matrix.VariableGP) + assert hasattr(gp, 'Phi') # Covariance matrix + assert hasattr(gp, 'F') # Basis matrix + + # Basis should have correct shape + # (n_toas, n_bins) where n_bins depends on quantization + assert gp.F.shape[0] == len(psr.toas) + + +if __name__ == '__main__': + pytest.main([__file__, '-v'])