Commit 9d5d235

Draft of switch StudentT cdf to use tfp's betainc (#1475)
* Switch `StudentT` `cdf` to use tfp's `betainc`: JAX's `betainc` doesn't have gradients defined for all parameters, while tfp's does
* Cleanup in line with PR review
* Remove unneeded import as directed by linter
1 parent 3c199d2 commit 9d5d235
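
As an editor's illustration (not part of the commit): the point of the switch is that the `StudentT` CDF becomes differentiable with respect to `df` as well as the value. A minimal sketch, assuming numpyro with this change plus `tensorflow_probability>=0.17.0` installed:

```python
import jax
import numpyro.distributions as dist


# Gradient of the StudentT CDF w.r.t. the degrees of freedom. Per the commit
# message, jax.scipy.special.betainc does not define this gradient, while
# tfp's JAX-substrate betainc has gradients for all of its parameters.
def cdf_wrt_df(df):
    return dist.StudentT(df).cdf(1.5)


print(jax.grad(cdf_wrt_df)(3.0))
```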

2 files changed: +10 -2 lines changed

numpyro/distributions/continuous.py

Lines changed: 9 additions & 1 deletion
@@ -1909,18 +1909,26 @@ def variance(self):
         return jnp.broadcast_to(var, self.batch_shape)
 
     def cdf(self, value):
+        try:
+            from tensorflow_probability.substrates.jax.math import betainc as betainc_fn
+        except ImportError:
+            from jax.scipy.special import betainc as betainc_fn
+
         # Ref: https://en.wikipedia.org/wiki/Student's_t-distribution#Related_distributions
         # X^2 ~ F(1, df) -> df / (df + X^2) ~ Beta(df/2, 0.5)
         scaled = (value - self.loc) / self.scale
         scaled_squared = scaled * scaled
         beta_value = self.df / (self.df + scaled_squared)
+
         # when scaled < 0, returns 0.5 * Beta(df/2, 0.5).cdf(beta_value)
         # when scaled > 0, returns 1 - 0.5 * Beta(df/2, 0.5).cdf(beta_value)
         scaled_sign_half = 0.5 * jnp.sign(scaled)
         return (
             0.5
             + scaled_sign_half
-            - 0.5 * jnp.sign(scaled) * betainc(0.5 * self.df, 0.5, beta_value)
+            - 0.5
+            * jnp.sign(scaled)
+            * betainc_fn(0.5 * jnp.asarray(self.df), 0.5, jnp.asarray(beta_value))
         )
 
     def icdf(self, q):
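
The comments in the hunk above rely on the identity that if X ~ StudentT(df), then df / (df + X^2) ~ Beta(df/2, 1/2). A numerical cross-check of that identity against scipy (an editor's sketch, not part of the commit; assumes scipy is installed):

```python
import numpy as np
from scipy.special import betainc
from scipy.stats import t

# Rebuild the CDF exactly as in the diff above, using scipy's regularized
# incomplete beta function, and compare against scipy's reference t CDF.
df = 3.0
x = np.array([-2.0, -0.5, 0.0, 1.0, 2.5])
beta_value = df / (df + x * x)
cdf_via_beta = 0.5 + 0.5 * np.sign(x) - 0.5 * np.sign(x) * betainc(0.5 * df, 0.5, beta_value)
print(np.allclose(cdf_via_beta, t.cdf(x, df)))  # expected: True
```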

setup.py

Lines changed: 1 addition & 1 deletion
@@ -64,7 +64,7 @@
             "jaxns==1.0.0",
             "optax>=0.0.6",
             "pyyaml",  # flax dependency
-            "tensorflow_probability>=0.15.0",
+            "tensorflow_probability>=0.17.0",
         ],
         "examples": [
             "arviz",

0 commit comments