diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index fa05f6951..d8442794c 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -1909,18 +1909,26 @@ def variance(self): return jnp.broadcast_to(var, self.batch_shape) def cdf(self, value): + try: + from tensorflow_probability.substrates.jax.math import betainc as betainc_fn + except ImportError: + from jax.scipy.special import betainc as betainc_fn + # Ref: https://en.wikipedia.org/wiki/Student's_t-distribution#Related_distributions # X^2 ~ F(1, df) -> df / (df + X^2) ~ Beta(df/2, 0.5) scaled = (value - self.loc) / self.scale scaled_squared = scaled * scaled beta_value = self.df / (self.df + scaled_squared) + # when scaled < 0, returns 0.5 * Beta(df/2, 0.5).cdf(beta_value) # when scaled > 0, returns 1 - 0.5 * Beta(df/2, 0.5).cdf(beta_value) scaled_sign_half = 0.5 * jnp.sign(scaled) return ( 0.5 + scaled_sign_half - - 0.5 * jnp.sign(scaled) * betainc(0.5 * self.df, 0.5, beta_value) + - 0.5 + * jnp.sign(scaled) + * betainc_fn(0.5 * jnp.asarray(self.df), 0.5, jnp.asarray(beta_value)) ) def icdf(self, q): diff --git a/setup.py b/setup.py index f9ccf0901..e78170d1f 100644 --- a/setup.py +++ b/setup.py @@ -64,7 +64,7 @@ "jaxns==1.0.0", "optax>=0.0.6", "pyyaml", # flax dependency - "tensorflow_probability>=0.15.0", + "tensorflow_probability>=0.17.0", ], "examples": [ "arviz",