Skip to content
149 changes: 123 additions & 26 deletions numpyro/distributions/conjugate.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,35 +120,13 @@ class BetaNegativeBinomial(Distribution):
overdispersed count data. It arises as the marginal distribution when integrating out
the success probability in a negative binomial model with a beta prior.

.. math::

p &\sim \mathrm{Beta}(\alpha, \beta) \\
X \mid p &\sim \mathrm{NegativeBinomial}(n, p)

The probability mass function is:

.. math::

P(X = k) = \binom{n + k - 1}{k} \frac{B(\alpha + k, \beta + n)}{B(\alpha, \beta)}

where :math:`B(\cdot, \cdot)` is the beta function.

:param numpy.ndarray concentration1: 1st concentration parameter (alpha) for the
Beta distribution.
:param numpy.ndarray concentration0: 2nd concentration parameter (beta) for the
Beta distribution.
:param numpy.ndarray n: positive number of successes parameter for the negative
binomial distribution.

**Properties**

- **Mean**: :math:`\frac{n \cdot \alpha}{\beta - 1}` for :math:`\beta > 1`, else undefined.
- **Variance**: for :math:`\beta > 2`, else undefined.

.. math::

\frac{n \cdot \alpha \cdot (n + \beta - 1) \cdot (\alpha + \beta - 1)}{(\beta - 1)^2 \cdot (\beta - 2)}

**References**

[1] https://en.wikipedia.org/wiki/Beta_negative_binomial_distribution
Expand Down Expand Up @@ -187,13 +165,36 @@ def __init__(
def sample(
    self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = ()
) -> ArrayLike:
    r"""Draw samples from :math:`\mathrm{BetaNegativeBinomial}(\alpha, \beta, n)`
    via the two-stage hierarchy:

    .. math::
        \begin{align*}
        p &\sim \mathrm{Beta}(\alpha, \beta) \\
        X \mid p &\sim \mathrm{NegativeBinomial}(n, p)
        \end{align*}

    The success probability is drawn with
    :class:`~numpyro.distributions.continuous.Beta` and the count is then drawn
    with :class:`~numpyro.distributions.discrete.NegativeBinomialProbs`.
    """
    assert is_prng_key(key)
    # Independent subkeys for the two sampling stages.
    beta_key, nb_key = random.split(key)
    success_probs = self._beta.sample(beta_key, sample_shape)
    nb_dist = NegativeBinomialProbs(total_count=self.n, probs=success_probs)
    return nb_dist.sample(nb_key)

@validate_sample
def log_prob(self, value: ArrayLike) -> ArrayLike:
r"""If :math:`X \sim \mathrm{BetaNegativeBinomial}(\alpha, \beta, n)`, then the log
probability mass function is:

.. math::
P(X = k) = \binom{n + k - 1}{k} \frac{B(\alpha + k, \beta + n)}{B(\alpha, \beta)}

To ensure differentiability, the binomial coefficient is computed using
gamma functions.
"""
return (
gammaln(self.n + value)
- gammaln(self.n)
Expand All @@ -204,6 +205,14 @@ def log_prob(self, value: ArrayLike) -> ArrayLike:

@property
def mean(self) -> ArrayLike:
r"""If :math:`X \sim \mathrm{BetaNegativeBinomial}(\alpha, \beta, n)` and
:math:`\beta > 1`, then the mean is:

.. math::
\mathbb{E}[X] = \frac{n\alpha}{\beta - 1},

otherwise, it is undefined.
"""
return jnp.where(
self.concentration0 > 1,
self.n * self.concentration1 / (self.concentration0 - 1),
Expand All @@ -212,6 +221,15 @@ def mean(self) -> ArrayLike:

@property
def variance(self) -> ArrayLike:
r"""If :math:`X \sim \mathrm{BetaNegativeBinomial}(\alpha, \beta, n)` and
:math:`\beta > 2`, then the variance is:

.. math::
\mathrm{Var}[X] =
\frac{n\alpha (n + \beta - 1)(\alpha + \beta - 1)}{(\beta - 1)^2 \cdot (\beta - 2)},

otherwise, it is undefined.
"""
alpha = self.concentration1
beta = self.concentration0
n = self.n
Expand Down Expand Up @@ -326,7 +344,7 @@ class GammaPoisson(Distribution):
drawn from a :class:`~numpyro.distributions.Gamma` distribution.

:param numpy.ndarray concentration: shape parameter (alpha) of the Gamma distribution.
:param numpy.ndarray rate: rate parameter (beta) for the Gamma distribution.
:param numpy.ndarray rate: rate parameter (lambda) of the Gamma distribution.
"""

arg_constraints = {
Expand All @@ -352,13 +370,33 @@ def __init__(
def sample(
    self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = ()
) -> ArrayLike:
    r"""If :math:`X \sim \mathrm{GammaPoisson}(\alpha, \lambda)`, then the sampling
    procedure is:

    .. math::
        \begin{align*}
        \theta &\sim \mathrm{Gamma}(\alpha, \lambda) \\
        X \mid \theta &\sim \mathrm{Poisson}(\theta)
        \end{align*}

    It uses :class:`~numpyro.distributions.continuous.Gamma` to generate samples
    from the Gamma distribution and
    :class:`~numpyro.distributions.discrete.Poisson` to generate samples from the
    Poisson distribution.

    :param key: PRNG key driving both sampling stages.
    :param sample_shape: additional leading shape for the draw.
    :return: integer-valued samples of shape ``sample_shape + batch_shape``.
    """
    assert is_prng_key(key)
    key_gamma, key_poisson = random.split(key)
    rate = self._gamma.sample(key_gamma, sample_shape)
    return Poisson(rate).sample(key_poisson)

@validate_sample
def log_prob(self, value: ArrayLike) -> ArrayLike:
r"""If :math:`X \sim \mathrm{GammaPoisson}(\alpha, \lambda)`, then the log
probability mass function is:

.. math::
p_{X}(k) = \frac{\lambda^\alpha}{(\alpha + k)(1+\lambda)^{\alpha + k}\mathrm{B}(\alpha, k + 1)}
"""
post_value = self.concentration + value
return (
-betaln(self.concentration, value + 1)
Expand All @@ -369,13 +407,33 @@ def log_prob(self, value: ArrayLike) -> ArrayLike:

@property
def mean(self) -> ArrayLike:
r"""If :math:`X \sim \mathrm{GammaPoisson}(\alpha, \lambda)`, then the mean is:

.. math::
\mathbb{E}[X] = \frac{\alpha}{\lambda}
"""
return self.concentration / self.rate

@property
def variance(self) -> ArrayLike:
r"""If :math:`X \sim \mathrm{GammaPoisson}(\alpha, \lambda)`, then the variance is:

.. math::
\mathrm{Var}[X] = \frac{\alpha}{\lambda^2}(1 + \lambda)
"""
return self.concentration / jnp.square(self.rate) * (1 + self.rate)

def cdf(self, value: ArrayLike) -> ArrayLike:
r"""If :math:`X \sim \mathrm{GammaPoisson}(\alpha, \lambda)`, then the cumulative
distribution function is:

.. math::
F_{X}(x) = \frac{1}{\mathrm{B}(\alpha, x + 1)}
\int_{0}^{\frac{\lambda}{1 + \lambda}} t^{\alpha - 1} (1 - t)^{x} dt

which is the regularized incomplete beta function.
This implementation uses :func:`~jax.scipy.special.betainc`.
"""
bt = betainc(self.concentration, value + 1.0, self.rate / (self.rate + 1.0))
return bt

Expand All @@ -386,7 +444,18 @@ def NegativeBinomial(
logits: Optional[ArrayLike] = None,
*,
validate_args: Optional[bool] = None,
):
) -> GammaPoisson:
"""Factory function for Negative Binomial distribution.

:param int total_count: Number of successful trials.
:param Optional[ArrayLike] probs: Probability of success for each trial, by default None
:param Optional[ArrayLike] logits: Log-odds of success for each trial, by default None
:param Optional[bool] validate_args: Whether to validate the parameters, by default None
:return: An instance of :class:`NegativeBinomialProbs` or
:class:`NegativeBinomialLogits` depending on the provided parameters.
:rtype: GammaPoisson
:raises ValueError: If neither :code:`probs` nor :code:`logits` is specified.
"""
if probs is not None:
return NegativeBinomialProbs(total_count, probs, validate_args=validate_args)
elif logits is not None:
Expand All @@ -396,6 +465,14 @@ def NegativeBinomial(


class NegativeBinomialProbs(GammaPoisson):
r"""Negative Binomial distribution parameterized by :code:`total_count` (:math:`r`)
and :code:`probs` (:math:`p`). It is implemented as a
:math:`\displaystyle\mathrm{GammaPoisson}(r, \frac{1}{p} - 1)` distribution.

:param total_count: Number of successful trials (:math:`r`).
:param probs: Probability of success for each trial (:math:`p`).
"""

arg_constraints = {
"total_count": constraints.positive,
"probs": constraints.unit_interval,
Expand All @@ -416,6 +493,15 @@ def __init__(


class NegativeBinomialLogits(GammaPoisson):
r"""Negative Binomial distribution parameterized by :code:`total_count` (:math:`r`)
and :code:`logits` (:math:`\displaystyle\mathrm{logits}(p)=\log \frac{p}{1-p}`). It
is implemented as a :math:`\mathrm{GammaPoisson}(r, \exp(-\mathrm{logits}(p)))`
distribution.

:param total_count: Number of successful trials.
:param logits: Log-odds of success for each trial (:math:`\ln \frac{p}{1-p}`).
"""

arg_constraints = {
"total_count": constraints.positive,
"logits": constraints.real,
Expand All @@ -436,6 +522,14 @@ def __init__(

@validate_sample
def log_prob(self, value: ArrayLike) -> ArrayLike:
r"""If :math:`X \sim \mathrm{NegativeBinomial}(r, \mathrm{logits}(p))`, then the log
probability mass function is:

.. math::
\ln P(X = k) = -r \ln(1+\exp(\mathrm{logits}(p)))
- k \ln(1+\exp(-\mathrm{logits}(p)))
- \ln\Gamma(1 + k) - \ln\Gamma(r) + \ln\Gamma(k + r)
"""
return -(
self.total_count * nn.softplus(self.logits)
+ value * nn.softplus(-self.logits)
Expand All @@ -444,8 +538,11 @@ def log_prob(self, value: ArrayLike) -> ArrayLike:


class NegativeBinomial2(GammaPoisson):
"""
Another parameterization of GammaPoisson with `rate` is replaced by `mean`.
r"""If :math:`X \sim \mathrm{NegativeBinomial2}(\mu, \alpha)`, then
:math:`X \sim \mathrm{GammaPoisson}(\alpha, \frac{\alpha}{\mu})`.

:param numpy.ndarray mean: mean parameter (:math:`\mu`).
:param numpy.ndarray concentration: concentration parameter (:math:`\alpha`).
"""

arg_constraints = {
Expand Down
61 changes: 61 additions & 0 deletions numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,6 +572,14 @@ def entropy(self) -> ArrayLike:


class Gamma(Distribution):
r"""Implementation of the `Gamma distribution <https://en.wikipedia.org/wiki/Gamma_distribution>`_,
:math:`\mathrm{Gamma}(\alpha, \lambda)`, where, :math:`\alpha` is the concentration
and :math:`\lambda` is the rate.

:param ArrayLike concentration: concentration parameter :math:`\alpha` (also known as shape parameter).
:param ArrayLike rate: rate parameter :math:`\lambda` (inverse scale parameter).
"""

arg_constraints = {
"concentration": constraints.positive,
"rate": constraints.positive,
Expand All @@ -595,12 +603,26 @@ def __init__(
def sample(
    self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = ()
) -> ArrayLike:
    r"""Draw samples :math:`X \sim \mathrm{Gamma}(\alpha, \lambda)` using
    :func:`~jax.random.gamma`, rescaling the unit-rate draws by the rate.
    """
    assert is_prng_key(key)
    out_shape = sample_shape + self.batch_shape + self.event_shape
    # random.gamma samples at unit rate; divide to apply the rate parameter.
    unit_rate_draws = random.gamma(key, self.concentration, shape=out_shape)
    return unit_rate_draws / self.rate

@validate_sample
def log_prob(self, value: ArrayLike) -> ArrayLike:
r"""If :math:`X \sim \mathrm{Gamma}(\alpha, \lambda)`, then

.. math::

f_X(x\mid \alpha, \lambda) =
\frac{\lambda^{\alpha} x^{\alpha - 1} e^{-\lambda x}}{\Gamma(\alpha)},
\quad x > 0

It uses :func:`~jax.scipy.special.gammaln` to compute the logarithm of the
gamma function.
"""
normalize_term = gammaln(self.concentration) - self.concentration * jnp.log(
self.rate
)
Expand All @@ -612,19 +634,58 @@ def log_prob(self, value: ArrayLike) -> ArrayLike:

@property
def mean(self) -> ArrayLike:
r"""If :math:`X \sim \mathrm{Gamma}(\alpha, \lambda)`, then

.. math::
\mathbb{E}[X] = \frac{\alpha}{\lambda}
"""
return self.concentration / self.rate

@property
def variance(self) -> ArrayLike:
r"""If :math:`X \sim \mathrm{Gamma}(\alpha, \lambda)`, then

.. math::
\mathrm{Var}[X] = \frac{\alpha}{\lambda^2}
"""
return self.concentration / jnp.power(self.rate, 2)

def cdf(self, x):
r"""If :math:`X \sim \mathrm{Gamma}(\alpha, \lambda)`, then

.. math::
F_X(x \mid \alpha, \lambda) = \frac{1}{\Gamma(\alpha)}
\gamma\left(\alpha, \lambda x\right)

where, :math:`\gamma(\cdot,\cdot)` is the `lower incomplete gamma function <https://en.wikipedia.org/wiki/Incomplete_gamma_function>`_.
This method uses regularized incomplete gamma function,
which is implemented as :func:`~jax.scipy.special.gammainc`.
"""
return gammainc(self.concentration, self.rate * x)

def icdf(self, q: ArrayLike) -> ArrayLike:
r"""If :math:`X \sim \mathrm{Gamma}(\alpha, \lambda)`, then

.. math::
F^{-1}_X(q \mid \alpha, \lambda) = \frac{1}{\lambda}
\gamma^{-1}\left(\alpha, q \Gamma(\alpha)\right)

where, :math:`\gamma^{-1}(\cdot,\cdot)` is the inverse of the lower incomplete gamma function.
This method uses regularized incomplete gamma inverse function,
which is implemented as :func:`~numpyro.distributions.util.gammaincinv`.
"""
return gammaincinv(self.concentration, q) / self.rate

def entropy(self) -> ArrayLike:
r"""If :math:`X \sim \mathrm{Gamma}(\alpha, \lambda)`, then

.. math::
H[X] = \alpha - \ln(\lambda) + \ln\Gamma(\alpha)
+ (1 - \alpha) \psi(\alpha)

where, :math:`\psi(\cdot)` is the `digamma function <https://en.wikipedia.org/wiki/Digamma_function>`_.
This method uses :func:`~jax.scipy.special.digamma`.
"""
return (
self.concentration
- jnp.log(self.rate)
Expand Down
2 changes: 2 additions & 0 deletions numpyro/distributions/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,6 +780,8 @@ def wrapper(self, *args, **kwargs):
log_prob = jnp.where(mask, log_prob, -jnp.inf)
return log_prob

wrapper.__doc__ = log_prob_fn.__doc__

return wrapper


Expand Down
Loading