diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 686c6ebde..1a3285147 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -179,6 +179,80 @@ def icdf(self, value: ArrayLike) -> ArrayLike: ) +@jax.custom_jvp +def _beta_log_prob(value, concentration1, concentration0): + """ + Compute Beta log probability with custom gradients to handle edge cases. + + When concentration1=1 and value=0, or concentration0=1 and value=1, + the standard formula involves log(0) * 0 which should be 0, but has + undefined gradients. We use custom_jvp to define proper gradients. + """ + return ( + xlogy(concentration1 - 1.0, value) + + xlogy(concentration0 - 1.0, 1.0 - value) + - betaln(concentration1, concentration0) + ) + + +@_beta_log_prob.defjvp +def _beta_log_prob_jvp(primals, tangents): + """Custom JVP for Beta log_prob handling edge cases at boundaries.""" + value, concentration1, concentration0 = primals + value_dot, concentration1_dot, concentration0_dot = tangents + primal_out = _beta_log_prob(value, concentration1, concentration0) + + # Gradient w.r.t. value - safe division, zero at edge cases + safe_val = jnp.where(value == 0.0, 1.0, value) + safe_one_minus = jnp.where(value == 1.0, 1.0, 1.0 - value) + grad_val = (concentration1 - 1.0) / safe_val - ( + concentration0 - 1.0 + ) / safe_one_minus + grad_val = jnp.where( + ((value == 0.0) & (concentration1 == 1.0)) + | ((value == 1.0) & (concentration0 == 1.0)), + 0.0, + grad_val, + ) + + # Gradients w.r.t. concentrations - safe log (0 instead of -inf) + dsum = digamma(concentration1 + concentration0) + grad_c1 = ( + jnp.where(value == 0.0, 0.0, jnp.log(value)) - digamma(concentration1) + dsum + ) + grad_c0 = ( + jnp.where(value == 1.0, 0.0, jnp.log(1.0 - value)) + - digamma(concentration0) + + dsum + ) + + # Build tangent output - handle Zero tangents properly + from jax.interpreters import ad + + def is_tangent_active(tangent): + """Check if tangent is active (not Zero or float0).""" + if isinstance(tangent, ad.Zero): + return False + # Check for float0 dtype (float0 has itemsize 0) + if ( + hasattr(tangent, "dtype") + and hasattr(tangent.dtype, "itemsize") + and tangent.dtype.itemsize == 0 + ): + return False + return True + + tangent_out = 0.0 + if is_tangent_active(value_dot): + tangent_out = tangent_out + grad_val * value_dot + if is_tangent_active(concentration1_dot): + tangent_out = tangent_out + grad_c1 * concentration1_dot + if is_tangent_active(concentration0_dot): + tangent_out = tangent_out + grad_c0 * concentration0_dot + + return primal_out, tangent_out + + class Beta(Distribution): arg_constraints = { "concentration1": constraints.positive, @@ -216,7 +290,9 @@ def sample( @validate_sample def log_prob(self, value: ArrayLike) -> ArrayLike: - return self._dirichlet.log_prob(jnp.stack([value, 1.0 - value], -1)) + # Compute Beta log_prob directly using the formula with custom gradients + # to handle edge cases where concentration=1 and value is at boundary + return _beta_log_prob(value, self.concentration1, self.concentration0) @property def mean(self) -> ArrayLike: diff --git a/test/test_distributions.py b/test/test_distributions.py index 7edc19226..ace1839b0 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -4486,3 +4486,131 @@ def test_interval_censored_validate_sample( censored_dist.log_prob(value) else: censored_dist.log_prob(value) # Should not raise + + +@pytest.mark.parametrize( + argnames="concentration1,concentration0,value", + argvalues=[ + (1.0, 8.0, 0.0), + (8.0, 1.0, 1.0), + ], + ids=["Beta(1,8) at x=0", "Beta(8,1) at x=1"], +) +def test_beta_logprob_edge_cases(concentration1, concentration0, value): + """Test Beta distribution with concentration=1 gives finite log probability at boundary.""" + beta_dist = dist.Beta(concentration1, concentration0) + log_prob = beta_dist.log_prob(value) + + assert not jnp.isnan(log_prob), ( + f"Beta({concentration1},{concentration0}).log_prob({value}) should not be NaN" + ) + assert jnp.isfinite(log_prob), ( + f"Beta({concentration1},{concentration0}).log_prob({value}) should be finite" + ) + + +def test_beta_logprob_edge_case_consistency_small_values(): + """Test that edge case values are consistent with small deviation values.""" + beta_dist = dist.Beta(1.0, 8.0) + beta_dist2 = dist.Beta(8.0, 1.0) + + # At boundary + log_prob_at_zero = beta_dist.log_prob(0.0) + log_prob_at_one = beta_dist2.log_prob(1.0) + + # Very close to boundary + small_value = 1e-10 + log_prob_small = beta_dist.log_prob(small_value) + log_prob_close_to_one = beta_dist2.log_prob(1.0 - small_value) + + # Edge case values should be close to small deviation values + assert jnp.abs(log_prob_at_zero - log_prob_small) < 1e-5 + assert jnp.abs(log_prob_at_one - log_prob_close_to_one) < 1e-5 + + +def test_beta_logprob_edge_case_non_boundary_values(): + """Test that Beta with concentration=1 still works for non-boundary values.""" + beta_dist = dist.Beta(1.0, 8.0) + beta_dist2 = dist.Beta(8.0, 1.0) + + assert jnp.isfinite(beta_dist.log_prob(0.5)) + assert jnp.isfinite(beta_dist2.log_prob(0.5)) + + +def test_beta_logprob_boundary_non_edge_cases(): + """Test that non-edge cases (concentration > 1) still give -inf at boundaries.""" + beta_dist3 = dist.Beta(2.0, 8.0) + beta_dist4 = dist.Beta(8.0, 2.0) + + assert jnp.isneginf(beta_dist3.log_prob(0.0)) + assert jnp.isneginf(beta_dist4.log_prob(1.0)) + + +@pytest.mark.parametrize( + argnames="concentration1,concentration0,value,grad_param,grad_value", + argvalues=[ + (1.0, 8.0, 0.0, "value", 0.0), + (8.0, 1.0, 1.0, "value", 1.0), + (1.0, 8.0, 0.0, "concentration1", 1.0), + (1.0, 8.0, 0.0, "concentration0", 8.0), + (8.0, 1.0, 1.0, "concentration1", 8.0), + (8.0, 1.0, 1.0, "concentration0", 1.0), + ], + ids=[ + "Beta(1,8) at x=0", + "Beta(8,1) at x=1", + "Beta(1,8) at concentration1=1", + "Beta(1,8) at concentration0=8", + "Beta(8,1) at concentration1=8", + "Beta(8,1) at concentration0=1", + ], +) +def test_beta_gradient_edge_cases_single_param( + concentration1, concentration0, value, grad_param, grad_value +): + """Test that gradients w.r.t. individual parameters are finite at edge cases.""" + if grad_param == "value": + + def log_prob_fn(x): + return dist.Beta(concentration1, concentration0).log_prob(x) + + grad = jax.grad(log_prob_fn)(value) + elif grad_param == "concentration1": + + def log_prob_fn(c1): + return dist.Beta(c1, concentration0).log_prob(value) + + grad = jax.grad(log_prob_fn)(grad_value) + else: # concentration0 + + def log_prob_fn(c0): + return dist.Beta(concentration1, c0).log_prob(value) + + grad = jax.grad(log_prob_fn)(grad_value) + + assert jnp.isfinite(grad), ( + f"Gradient w.r.t. {grad_param} for Beta({concentration1},{concentration0}) " + f"at x={value} should be finite" + ) + + +@pytest.mark.parametrize( + argnames="concentration1,concentration0,value", + argvalues=[ + (1.0, 8.0, 0.0), + (8.0, 1.0, 1.0), + ], + ids=["Beta(1,8) at x=0", "Beta(8,1) at x=1"], +) +def test_beta_gradient_edge_cases_all_params(concentration1, concentration0, value): + """Test that all gradients are finite when computed simultaneously at edge cases.""" + + def log_prob_fn(params): + c1, c0, v = params + return dist.Beta(c1, c0).log_prob(v) + + grads = jax.grad(log_prob_fn)(jnp.array([concentration1, concentration0, value])) + assert jnp.all(jnp.isfinite(grads)), ( + f"All gradients for Beta({concentration1},{concentration0}) at x={value} " + f"should be finite" + )