jax.lax.while_loop + jax.grad in Importance Sampling Routine #30526
Hi, I'm implementing a simple importance sampling routine to estimate an integral. The idea is to repeatedly draw a fixed number of samples (e.g., 100 per iteration) and refine the estimate until the Monte Carlo error falls below a desired threshold. Since the number of steps needed to reach that threshold is not known in advance, I've used `jax.lax.while_loop` to express this adaptive behavior. Each loop iteration updates the running estimate and error by combining new samples with previous ones. This logic works correctly on its own, but I run into an issue when trying to differentiate the outer function with `jax.grad`: JAX raises a trace-time error because reverse-mode differentiation is not supported for `lax.while_loop` with a data-dependent trip count.
This becomes a problem because the loop is embedded within a larger differentiable computation. I would like to know: is there a recommended way to differentiate through an adaptive loop like this, or a standard pattern for restructuring it so that reverse-mode autodiff works?
I've included an MRE below. Thanks in advance.

```python
import functools as ft
from typing import Tuple, TypeAlias
import jax
from jax import nn as jnn, numpy as jnp, random as jrd
from jaxtyping import Array, PRNGKeyArray
StateT: TypeAlias = Tuple[
Array, # old monte-carlo-estimate
Array, # old error
Array, # old size
PRNGKeyArray, # old key
]
"""State of the Monte Carlo estimation process."""
@ft.partial(jax.jit, static_argnames=("n_samples",))
def _mvn_samples(
loc: Array, scale_tril: Array, n_samples: int, key: PRNGKeyArray
) -> Array:
"""Generate samples from a multivariate normal distribution using method from
`numpyro.distributions.MultivariateNormal`.
Parameters
----------
loc : Array
Mean vector of the multivariate normal distribution.
scale_tril : Array
Lower triangular matrix of the covariance matrix (Cholesky decomposition).
n_samples : int
Number of samples to generate.
key : PRNGKeyArray
JAX random key for sampling.
Returns
-------
Array
Samples drawn from the multivariate normal distribution.
"""
eps = jrd.normal(key, shape=(n_samples, *loc.shape))
samples = loc + jnp.squeeze(jnp.matmul(scale_tril, eps[..., jnp.newaxis]), axis=-1)
return samples
@jax.jit
def _monte_carlo_estimate_and_error(log_probs: Array, N: Array) -> Tuple[Array, Array]:
"""Computes the Monte Carlo estimate and error for the given log probabilities.
Parameters
----------
log_probs : Array
Log probabilities of the samples.
    N : Array
        Number of samples used for the estimate.
Returns
-------
Tuple[Array, Array]
Monte Carlo estimate and error.
"""
mask = ~jnp.isneginf(log_probs)
moment_1 = jnp.exp(jnn.logsumexp(log_probs, where=mask, axis=-1)) / N
moment_2 = jnp.exp(jnn.logsumexp(2.0 * log_probs, where=mask, axis=-1)) / N
error = jnp.sqrt((moment_2 - jnp.square(moment_1)) / (N - 1.0))
return moment_1, error
@jax.jit
def _combine_monte_carlo_estimates(
estimates_1: Array, estimates_2: Array, N_1: int, N_2: int
) -> Array:
r"""Combine two Monte Carlo estimates into a single estimate using the formula:
.. math::
\hat{\mu} = \frac{N_1 \hat{\mu}_1 + N_2 \hat{\mu}_2}{N_1 + N_2}
Parameters
----------
estimates_1 : Array
First Monte Carlo estimate :math:`\hat{\mu}_1`.
estimates_2 : Array
Second Monte Carlo estimate :math:`\hat{\mu}_2`.
N_1 : int
Number of samples used for the first estimate :math:`N_1`.
N_2 : int
Number of samples used for the second estimate :math:`N_2`.
Returns
-------
Array
Combined Monte Carlo estimate :math:`\hat{\mu}`.
"""
combined_estimate = (N_1 * estimates_1 + N_2 * estimates_2) / (N_1 + N_2)
return combined_estimate
@jax.jit
def _combine_monte_carlo_errors(
error_1: Array,
error_2: Array,
estimate_1: Array,
estimate_2: Array,
estimate_3: Array,
N_1: int,
N_2: int,
) -> Array:
r"""Combine two Monte Carlo errors into a single error estimate using the formula:
.. math::
\hat{\epsilon}=\sqrt{\frac{1}{N_3(N_3-1)}\sum_{k=1}^{2}\left\{N_k(N_k-1)\hat{\epsilon}_k^2+N_k\hat{\mu}^2_k\right\}-\frac{1}{N_3-1}\hat{\mu}^2}
where, :math:`N_3 = N_1 + N_2` is the total number of samples.
Parameters
----------
error_1 : Array
Error of the first Monte Carlo estimate :math:`\hat{\epsilon}_1`.
error_2 : Array
Error of the second Monte Carlo estimate :math:`\hat{\epsilon}_2`.
estimate_1 : Array
Estimate of the first Monte Carlo estimate :math:`\hat{\mu}_1`.
estimate_2 : Array
Estimate of the second Monte Carlo estimate :math:`\hat{\mu}_2`.
estimate_3 : Array
Estimate of the combined Monte Carlo estimate :math:`\hat{\mu}`.
N_1 : int
Number of samples used for the first estimate :math:`N_1`.
N_2 : int
Number of samples used for the second estimate :math:`N_2`.
Returns
-------
Array
Combined Monte Carlo error estimate :math:`\hat{\epsilon}`.
"""
N_3 = N_1 + N_2
sum_prob_sq_1 = N_1 * ((N_1 - 1.0) * jnp.square(error_1) + jnp.square(estimate_1))
sum_prob_sq_2 = N_2 * ((N_2 - 1.0) * jnp.square(error_2) + jnp.square(estimate_2))
combined_error_sq = -jnp.square(estimate_3) / (N_3 - 1.0)
combined_error_sq += (sum_prob_sq_1 + sum_prob_sq_2) / N_3 / (N_3 - 1.0)
combined_error = jnp.sqrt(combined_error_sq)
return combined_error
@jax.jit
def _error_fn(state: StateT) -> Array:
"""Check if the error in the Monte Carlo estimation is below a threshold.
Parameters
----------
state : StateT
The state of the Monte Carlo estimation.
Returns
-------
Array
        A boolean array that is True while the error exceeds the threshold (0.01),
        i.e. while the loop should keep running.
"""
_, error, _, _ = state
    # while_loop runs while cond_fun returns True, so continue while the error
    # is still above the target threshold. (The original less_equal check would
    # terminate immediately, since the error is initialized to 1.0.)
    return jnp.greater(error, 0.01)
def f(alpha: Array) -> Array:
@jax.jit
def scan_fn(carry: Array, rng_key: PRNGKeyArray) -> Tuple[Array, None]:
@jax.jit
def while_body_fn(state: StateT) -> StateT:
estimate_1, error_1, N_1, rng_key = state
N_2 = 100
rng_key, subkey = jrd.split(rng_key)
data = jax.random.uniform(subkey, (N_2,))
log_prob = alpha * data
            # The new batch contains N_2 samples, so normalize by N_2 (passing
            # N_1 here would divide by zero on the first iteration).
            estimate_2, error_2 = _monte_carlo_estimate_and_error(log_prob, N_2)
estimate_3 = _combine_monte_carlo_estimates(
estimate_1, estimate_2, N_1, N_2
)
error_3 = _combine_monte_carlo_errors(
error_1,
error_2,
estimate_1,
estimate_2,
estimate_3,
N_1,
N_2,
)
return estimate_3, error_3, N_1 + N_2, rng_key
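        # NOTE: this is the call that breaks under jax.grad; lax.while_loop
        # supports only forward-mode differentiation because its trip count
        # is data-dependent.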
log_likelihood, _, _, _ = jax.lax.while_loop(
_error_fn,
while_body_fn,
(jnp.zeros(()), jnp.ones(()), jnp.zeros(()), rng_key),
)
return carry + log_likelihood, None
rng_key: PRNGKeyArray = jrd.PRNGKey(0)
n_events: int = 10
keys = jrd.split(rng_key, (n_events,))
total_log_likelihood, _ = jax.lax.scan(
scan_fn, # type: ignore[arg-type]
jnp.zeros(()),
keys,
length=n_events,
)
return total_log_likelihood
print(jax.grad(f)(2.0))
```
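For what it's worth, the failure does not depend on the Monte Carlo details. Here is a stripped-down sketch of my own (not part of the routine above) that fails the same way:

```python
import jax
import jax.numpy as jnp

def g(x):
    # The trip count depends on the value of x, so it is unknown at trace time.
    return jax.lax.while_loop(lambda v: v < 10.0, lambda v: v * x, jnp.asarray(1.0))

# Fails with a trace-time error saying reverse-mode differentiation does not
# work for lax.while_loop; forward-mode (jax.jvp) would still go through.
jax.grad(g)(2.0)
```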
Replies: 1 comment 1 reply
You can use `equinox.internal.while_loop` instead, which is reverse-mode autodifferentiable. Equinox achieves this with a checkpointing scheme: the loop carry is saved at selected iterations and the remaining steps are recomputed during the backward pass.
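A minimal sketch of the swap inside `scan_fn`, assuming a recent Equinox version is installed; the `kind="checkpointed"` and `max_steps` keyword arguments reflect my reading of Equinox's API and should be checked against the version you have:

```python
import equinox.internal as eqxi

# Drop-in replacement for the jax.lax.while_loop call in scan_fn.
# kind="checkpointed" makes the loop reverse-mode differentiable by
# checkpointing the carry and recomputing loop segments on the backward pass.
log_likelihood, _, _, _ = eqxi.while_loop(
    _error_fn,
    while_body_fn,
    (jnp.zeros(()), jnp.ones(()), jnp.zeros(()), rng_key),
    kind="checkpointed",
    max_steps=4096,  # optional cap on iterations; tune to your problem
)
```

The trade-off is extra recomputation on the backward pass. If you can tolerate a hard upper bound on the number of iterations, a fixed-length `lax.scan` that masks out updates once the error target is reached is a pure-JAX alternative, which is what the JAX error message itself suggests.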