Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 85 additions & 3 deletions src/discovery/solar.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,95 @@ def chromaticdelay(toas, freqs, t0, log10_Amp, log10_tau, idx):
return matrix.jnp.where(dt > 0.0, -1.0 * (10**log10_Amp) * matrix.jnp.exp(-dt / (10**log10_tau)) * invnormfreqs**idx, 0.0)

def make_chromaticdelay(psr, idx=None):
"""From enterprise_extensions: pre-calculate chromatic exponential-dip delay."""

"""Create a closure function for calculating chromatic exponential-dip delay.

This function acts as a factory. It pre-calculates TOA
and frequency-dependent terms from a pulsar object and returns a new
function (`decay`). This returned function computes the time delay
induced by a chromatic exponential event for specific event parameters.

The closure mechanism works as follows: The inner function `decay` retains
access to the variables `toadays`, `invnormfreqs`, `ln_10`, and
`ln_invnormfreqs` calculated in the scope of `make_chromaticdelay`, even
after `make_chromaticdelay` has finished executing. If `idx` is provided
to `make_chromaticdelay`, it is also fixed for the returned `decay`
function using `functools.partial`.

Parameters
----------
psr : discovery.Pulsar
Pulsar object.
idx : float, optional
The chromatic index defining the delay's radio-frequency dependence.
If `None` (default), the returned `decay` function will require `idx`
as an argument. If a float is provided, this value is fixed for the
returned `decay` function.

Returns
-------
Callable
A function `decay` that calculates the chromatic delay.
Its signature depends on whether `idx` was provided to
`make_chromaticdelay`:

- If `idx` was provided: `decay(t0, log10_Amp, log10_tau)`
- If `idx` was `None`: `decay(t0, log10_Amp, log10_tau, idx)`

The arguments for the returned `decay` function are:
- `t0` (float): The MJD of the exponential dip event.
- `log10_Amp` (float): Log10 of the amplitude of the signal (unitless).
The actual amplitude `A` is `10**log10_Amp`. The amplitude is
defined at the reference frequency (1400 MHz) and time `t0`.
- `log10_tau` (float): Log10 of the decay timescale `tau` (in days).
The actual timescale is `tau = 10**log10_tau`.
- `idx` (float, required only if `idx` was `None` in the outer function):
The chromatic index.

The `decay` function returns an array of delays (in seconds) corresponding
to each TOA/frequency pair in the `psr` object.

Notes
-----
Without the double `matrix.jnp.where` in this function, the gradients can return
`NaN`. This is because JAX still evaluates the gradient at the `False` positions
of the `where` [1]_, and for some pulsars `matrix.jnp.exp(-dt / 10**log10_tau)`
is large enough to overflow and become an `inf`! We don't care about the opposite
direction, because the underflow just becomes a zero.

The solution to the problem is the "double where" trick [1]_, which this
method implements. It is possible to instead move the `matrix.jnp.where`
inside the exponential and return `-inf` for all `dt < 0.0`, but in testing
this implementation in log space with two `where`'s is actually faster.

References
----------
.. [1] https://docs.jax.dev/en/latest/faq.html#gradients-contain-nan-where-using-where

"""
toadays, invnormfreqs = matrix.jnparray(psr.toas / const.day), matrix.jnparray(1400.0 / psr.freqs)

# Put quantities into log-space
ln_10 = matrix.jnp.log(10)
ln_invnormfreqs = matrix.jnp.log(invnormfreqs)

def decay(t0, log10_Amp, log10_tau, idx):

# Note that the usage of `jax.numpy.where` allows for the TOAs to be unordered.
# We only need to store the differences here.
dt = toadays - t0
return matrix.jnp.where(dt > 0.0, -1.0 * (10**log10_Amp) * matrix.jnp.exp(-dt / (10**log10_tau)) * invnormfreqs**idx, 0.0)

# Store this mask because we'll use it twice.
dt_mask = dt > 0.0

# Get the exponent for the exponential dip. Return 0 if the TOA is before the
# start time `t0`.
vals = matrix.jnp.where(
dt_mask,
ln_10 * log10_Amp - dt / (10**log10_tau) + idx * ln_invnormfreqs,
0.0,
)

return matrix.jnp.where(dt_mask, -1.0 * matrix.jnp.exp(vals), vals)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I realize it's a bit hypocritical, given our current lack of tests, but do you think you could add a unit test to make sure this doesn't happen? This seems like a prime example of where this would be needed. Perhaps you can create a tests/test_solar.py and just give the example in the PR and check that it's not infinite?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can do that!


if idx is not None:
decay = functools.partial(decay, idx=idx)
Expand Down