Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 77 additions & 1 deletion numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,80 @@ def icdf(self, value: ArrayLike) -> ArrayLike:
)


@jax.custom_jvp
def _beta_log_prob(value, concentration1, concentration0):
"""
Compute Beta log probability with custom gradients to handle edge cases.

When concentration1=1 and value=0, or concentration0=1 and value=1,
the standard formula involves log(0) * 0 which should be 0, but has
undefined gradients. We use custom_jvp to define proper gradients.
"""
return (
xlogy(concentration1 - 1.0, value)
+ xlogy(concentration0 - 1.0, 1.0 - value)
- betaln(concentration1, concentration0)
)


@_beta_log_prob.defjvp
def _beta_log_prob_jvp(primals, tangents):
"""Custom JVP for Beta log_prob handling edge cases at boundaries."""
value, concentration1, concentration0 = primals
value_dot, concentration1_dot, concentration0_dot = tangents
primal_out = _beta_log_prob(value, concentration1, concentration0)

# Gradient w.r.t. value - safe division, zero at edge cases
safe_val = jnp.where(value == 0.0, 1.0, value)
safe_one_minus = jnp.where(value == 1.0, 1.0, 1.0 - value)
grad_val = (concentration1 - 1.0) / safe_val - (
concentration0 - 1.0
) / safe_one_minus
grad_val = jnp.where(
((value == 0.0) & (concentration1 == 1.0))
| ((value == 1.0) & (concentration0 == 1.0)),
0.0,
grad_val,
)

# Gradients w.r.t. concentrations - safe log (0 instead of -inf)
dsum = digamma(concentration1 + concentration0)
grad_c1 = (
jnp.where(value == 0.0, 0.0, jnp.log(value)) - digamma(concentration1) + dsum
)
grad_c0 = (
jnp.where(value == 1.0, 0.0, jnp.log(1.0 - value))
- digamma(concentration0)
+ dsum
)

# Build tangent output - handle Zero tangents properly
from jax.interpreters import ad

def is_tangent_active(tangent):
"""Check if tangent is active (not Zero or float0)."""
if isinstance(tangent, ad.Zero):
return False
# Check for float0 dtype (float0 has itemsize 0)
if (
hasattr(tangent, "dtype")
and hasattr(tangent.dtype, "itemsize")
and tangent.dtype.itemsize == 0
):
return False
return True

tangent_out = 0.0
if is_tangent_active(value_dot):
tangent_out = tangent_out + grad_val * value_dot
if is_tangent_active(concentration1_dot):
tangent_out = tangent_out + grad_c1 * concentration1_dot
if is_tangent_active(concentration0_dot):
tangent_out = tangent_out + grad_c0 * concentration0_dot

return primal_out, tangent_out


class Beta(Distribution):
arg_constraints = {
"concentration1": constraints.positive,
Expand Down Expand Up @@ -216,7 +290,9 @@ def sample(

@validate_sample
def log_prob(self, value: ArrayLike) -> ArrayLike:
    """
    Evaluate the Beta log density at ``value``.

    :param value: point(s) at which to evaluate the log density.
    :return: result of ``_beta_log_prob`` for ``value`` and this
        distribution's ``concentration1`` / ``concentration0``.
    """
    # Compute Beta log_prob directly using the formula with custom gradients
    # to handle edge cases where concentration=1 and value is at boundary.
    # The former Dirichlet-based return that preceded this line made it
    # unreachable, so it has been removed.
    return _beta_log_prob(value, self.concentration1, self.concentration0)

@property
def mean(self) -> ArrayLike:
Expand Down
128 changes: 128 additions & 0 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4486,3 +4486,131 @@ def test_interval_censored_validate_sample(
censored_dist.log_prob(value)
else:
censored_dist.log_prob(value) # Should not raise


@pytest.mark.parametrize(
    "concentration1,concentration0,value",
    [
        (1.0, 8.0, 0.0),
        (8.0, 1.0, 1.0),
    ],
    ids=["Beta(1,8) at x=0", "Beta(8,1) at x=1"],
)
def test_beta_logprob_edge_cases(concentration1, concentration0, value):
    """Beta with a unit concentration has a finite, non-NaN log_prob at the boundary."""
    result = dist.Beta(concentration1, concentration0).log_prob(value)

    is_nan = jnp.isnan(result)
    assert not is_nan, (
        f"Beta({concentration1},{concentration0}).log_prob({value}) should not be NaN"
    )

    is_finite = jnp.isfinite(result)
    assert is_finite, (
        f"Beta({concentration1},{concentration0}).log_prob({value}) should be finite"
    )


def test_beta_logprob_edge_case_consistency_small_values():
    """Boundary log_prob values agree with values a tiny step inside the support."""
    step = 1e-10
    tolerance = 1e-5

    # (distribution, boundary point, point very close to that boundary)
    scenarios = (
        (dist.Beta(1.0, 8.0), 0.0, step),
        (dist.Beta(8.0, 1.0), 1.0, 1.0 - step),
    )

    for distribution, boundary, nearby in scenarios:
        at_boundary = distribution.log_prob(boundary)
        near_boundary = distribution.log_prob(nearby)
        # Edge case values should be close to small deviation values
        assert jnp.abs(at_boundary - near_boundary) < tolerance


def test_beta_logprob_edge_case_non_boundary_values():
    """Beta with a unit concentration is still finite at an interior point."""
    interior_point = 0.5
    for concentrations in ((1.0, 8.0), (8.0, 1.0)):
        interior_log_prob = dist.Beta(*concentrations).log_prob(interior_point)
        assert jnp.isfinite(interior_log_prob)


def test_beta_logprob_boundary_non_edge_cases():
    """With the relevant concentration above 1, the boundary log_prob is still -inf."""
    # (concentration1, concentration0, boundary point with zero density)
    zero_density_cases = (
        (2.0, 8.0, 0.0),
        (8.0, 2.0, 1.0),
    )
    for c1, c0, boundary in zero_density_cases:
        boundary_log_prob = dist.Beta(c1, c0).log_prob(boundary)
        assert jnp.isneginf(boundary_log_prob)


@pytest.mark.parametrize(
    "concentration1,concentration0,value,grad_param,grad_value",
    [
        (1.0, 8.0, 0.0, "value", 0.0),
        (8.0, 1.0, 1.0, "value", 1.0),
        (1.0, 8.0, 0.0, "concentration1", 1.0),
        (1.0, 8.0, 0.0, "concentration0", 8.0),
        (8.0, 1.0, 1.0, "concentration1", 8.0),
        (8.0, 1.0, 1.0, "concentration0", 1.0),
    ],
    ids=[
        "Beta(1,8) at x=0",
        "Beta(8,1) at x=1",
        "Beta(1,8) at concentration1=1",
        "Beta(1,8) at concentration0=8",
        "Beta(8,1) at concentration1=8",
        "Beta(8,1) at concentration0=1",
    ],
)
def test_beta_gradient_edge_cases_single_param(
    concentration1, concentration0, value, grad_param, grad_value
):
    """Each single-argument gradient of log_prob is finite at the edge cases."""

    def wrt_value(x):
        return dist.Beta(concentration1, concentration0).log_prob(x)

    def wrt_concentration1(c1):
        return dist.Beta(c1, concentration0).log_prob(value)

    def wrt_concentration0(c0):
        return dist.Beta(concentration1, c0).log_prob(value)

    # Map each named argument to (function of that argument, evaluation point).
    # Any other name falls back to concentration0, as in an if/elif/else chain.
    targets = {
        "value": (wrt_value, value),
        "concentration1": (wrt_concentration1, grad_value),
    }
    log_prob_fn, point = targets.get(grad_param, (wrt_concentration0, grad_value))

    grad = jax.grad(log_prob_fn)(point)

    assert jnp.isfinite(grad), (
        f"Gradient w.r.t. {grad_param} for Beta({concentration1},{concentration0}) "
        f"at x={value} should be finite"
    )


@pytest.mark.parametrize(
    "concentration1,concentration0,value",
    [
        (1.0, 8.0, 0.0),
        (8.0, 1.0, 1.0),
    ],
    ids=["Beta(1,8) at x=0", "Beta(8,1) at x=1"],
)
def test_beta_gradient_edge_cases_all_params(concentration1, concentration0, value):
    """The joint gradient over (concentration1, concentration0, value) is finite."""

    def packed_log_prob(theta):
        # theta packs the three arguments in the order (c1, c0, value).
        alpha, beta, x = theta
        return dist.Beta(alpha, beta).log_prob(x)

    point = jnp.array([concentration1, concentration0, value])
    grads = jax.grad(packed_log_prob)(point)

    every_component_finite = jnp.all(jnp.isfinite(grads))
    assert every_component_finite, (
        f"All gradients for Beta({concentration1},{concentration0}) at x={value} "
        f"should be finite"
    )