Skip to content

Commit 3c2c76b

Browse files
authored
doc(gh-2107): Improved documentation for Negative Binomial distributions (#2109)
* fix: add docstring to `validate_sample` wrapper * docs: add detailed mathematical descriptions and docstrings for Gamma distribution methods * docs: update `GammaPoisson` class docstrings with detailed mathematical descriptions * docs: enhance docstrings for `BetaNegativeBinomial class` with detailed mathematical descriptions * docs: enhance docstrings for Negative Binomial distribution and its parameterizations * fix: apply `functools.wraps` inside `wrapper` in `numpyro.distributions.utils.validate_sample` * fix: correction of docstring in `numpyro.distributions.truncated.DoublyTruncatedPowerLaw.log_prob`
1 parent ceac631 commit 3c2c76b

File tree

4 files changed

+195
-28
lines changed

4 files changed

+195
-28
lines changed

numpyro/distributions/conjugate.py

Lines changed: 123 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -120,35 +120,13 @@ class BetaNegativeBinomial(Distribution):
120120
overdispersed count data. It arises as the marginal distribution when integrating out
121121
the success probability in a negative binomial model with a beta prior.
122122
123-
.. math::
124-
125-
p &\sim \mathrm{Beta}(\alpha, \beta) \\
126-
X \mid p &\sim \mathrm{NegativeBinomial}(n, p)
127-
128-
The probability mass function is:
129-
130-
.. math::
131-
132-
P(X = k) = \binom{n + k - 1}{k} \frac{B(\alpha + k, \beta + n)}{B(\alpha, \beta)}
133-
134-
where :math:`B(\cdot, \cdot)` is the beta function.
135-
136123
:param numpy.ndarray concentration1: 1st concentration parameter (alpha) for the
137124
Beta distribution.
138125
:param numpy.ndarray concentration0: 2nd concentration parameter (beta) for the
139126
Beta distribution.
140127
:param numpy.ndarray n: positive number of successes parameter for the negative
141128
binomial distribution.
142129
143-
**Properties**
144-
145-
- **Mean**: :math:`\frac{n \cdot \alpha}{\beta - 1}` for :math:`\beta > 1`, else undefined.
146-
- **Variance**: for :math:`\beta > 2`, else undefined.
147-
148-
.. math::
149-
150-
\frac{n \cdot \alpha \cdot (n + \beta - 1) \cdot (\alpha + \beta - 1)}{(\beta - 1)^2 \cdot (\beta - 2)}
151-
152130
**References**
153131
154132
[1] https://en.wikipedia.org/wiki/Beta_negative_binomial_distribution
@@ -187,13 +165,36 @@ def __init__(
187165
def sample(
188166
self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = ()
189167
) -> ArrayLike:
168+
r"""If :math:`X \sim \mathrm{BetaNegativeBinomial}(\alpha, \beta, n)`, then the sampling
169+
procedure is:
170+
171+
.. math::
172+
\begin{align*}
173+
p &\sim \mathrm{Beta}(\alpha, \beta) \\
174+
X \mid p &\sim \mathrm{NegativeBinomial}(n, p)
175+
\end{align*}
176+
177+
It uses :class:`~numpyro.distributions.continuous.Beta` to generate samples
178+
from the Beta distribution and
179+
:class:`~numpyro.distributions.discrete.NegativeBinomialProbs` to generate samples
180+
from the Negative Binomial distribution.
181+
"""
190182
assert is_prng_key(key)
191183
key_beta, key_nb = random.split(key)
192184
probs = self._beta.sample(key_beta, sample_shape)
193185
return NegativeBinomialProbs(total_count=self.n, probs=probs).sample(key_nb)
194186

195187
@validate_sample
196188
def log_prob(self, value: ArrayLike) -> ArrayLike:
189+
r"""If :math:`X \sim \mathrm{BetaNegativeBinomial}(\alpha, \beta, n)`, then the log
190+
probability mass function is:
191+
192+
.. math::
193+
P(X = k) = \binom{n + k - 1}{k} \frac{B(\alpha + k, \beta + n)}{B(\alpha, \beta)}
194+
195+
To ensure differentiability, the binomial coefficient is computed using
196+
gamma functions.
197+
"""
197198
return (
198199
gammaln(self.n + value)
199200
- gammaln(self.n)
@@ -204,6 +205,14 @@ def log_prob(self, value: ArrayLike) -> ArrayLike:
204205

205206
@property
206207
def mean(self) -> ArrayLike:
208+
r"""If :math:`X \sim \mathrm{BetaNegativeBinomial}(\alpha, \beta, n)` and
209+
:math:`\beta > 1`, then the mean is:
210+
211+
.. math::
212+
\mathbb{E}[X] = \frac{n\alpha}{\beta - 1},
213+
214+
otherwise, it is undefined.
215+
"""
207216
return jnp.where(
208217
self.concentration0 > 1,
209218
self.n * self.concentration1 / (self.concentration0 - 1),
@@ -212,6 +221,15 @@ def mean(self) -> ArrayLike:
212221

213222
@property
214223
def variance(self) -> ArrayLike:
224+
r"""If :math:`X \sim \mathrm{BetaNegativeBinomial}(\alpha, \beta, n)` and
225+
:math:`\beta > 2`, then the variance is:
226+
227+
.. math::
228+
\mathrm{Var}[X] =
229+
\frac{n\alpha (n + \beta - 1)(\alpha + \beta - 1)}{(\beta - 1)^2 \cdot (\beta - 2)},
230+
231+
otherwise, it is undefined.
232+
"""
215233
alpha = self.concentration1
216234
beta = self.concentration0
217235
n = self.n
@@ -326,7 +344,7 @@ class GammaPoisson(Distribution):
326344
drawn from a :class:`~numpyro.distributions.Gamma` distribution.
327345
328346
:param numpy.ndarray concentration: shape parameter (alpha) of the Gamma distribution.
329-
:param numpy.ndarray rate: rate parameter (beta) for the Gamma distribution.
347+
:param numpy.ndarray rate: rate parameter (beta) for the Gamma distribution.
330348
"""
331349

332350
arg_constraints = {
@@ -352,13 +370,33 @@ def __init__(
352370
def sample(
353371
self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = ()
354372
) -> ArrayLike:
373+
r"""If :math:`X \sim \mathrm{GammaPoisson}(\alpha, \lambda)`, then the sampling
374+
procedure is:
375+
376+
.. math::
377+
\begin{align*}
378+
\theta &\sim \mathrm{Gamma}(\alpha, \lambda) \\
379+
X \mid \theta &\sim \mathrm{Poisson}(\theta)
380+
\end{align*}
381+
382+
It uses :class:`~numpyro.distributions.continuous.Gamma` to generate samples
383+
from the Gamma distribution and
384+
:class:`~numpyro.distributions.continuous.Poisson` to generate samples from the
385+
Poisson distribution.
386+
"""
355387
assert is_prng_key(key)
356388
key_gamma, key_poisson = random.split(key)
357389
rate = self._gamma.sample(key_gamma, sample_shape)
358390
return Poisson(rate).sample(key_poisson)
359391

360392
@validate_sample
361393
def log_prob(self, value: ArrayLike) -> ArrayLike:
394+
r"""If :math:`X \sim \mathrm{GammaPoisson}(\alpha, \lambda)`, then the log
395+
probability mass function is:
396+
397+
.. math::
398+
p_{X}(k) = \frac{\lambda^\alpha}{(\alpha + k)(1+\lambda)^{\alpha + k}\mathrm{B}(\alpha, k + 1)}
399+
"""
362400
post_value = self.concentration + value
363401
return (
364402
-betaln(self.concentration, value + 1)
@@ -369,13 +407,33 @@ def log_prob(self, value: ArrayLike) -> ArrayLike:
369407

370408
@property
371409
def mean(self) -> ArrayLike:
410+
r"""If :math:`X \sim \mathrm{GammaPoisson}(\alpha, \lambda)`, then the mean is:
411+
412+
.. math::
413+
\mathbb{E}[X] = \frac{\alpha}{\lambda}
414+
"""
372415
return self.concentration / self.rate
373416

374417
@property
375418
def variance(self) -> ArrayLike:
419+
r"""If :math:`X \sim \mathrm{GammaPoisson}(\alpha, \lambda)`, then the variance is:
420+
421+
.. math::
422+
\mathrm{Var}[X] = \frac{\alpha}{\lambda^2}(1 + \lambda)
423+
"""
376424
return self.concentration / jnp.square(self.rate) * (1 + self.rate)
377425

378426
def cdf(self, value: ArrayLike) -> ArrayLike:
427+
r"""If :math:`X \sim \mathrm{GammaPoisson}(\alpha, \lambda)`, then the cumulative
428+
distribution function is:
429+
430+
.. math::
431+
F_{X}(x) = \frac{1}{\mathrm{B}(\alpha, x + 1)}
432+
\int_{0}^{\frac{\lambda}{1 + \lambda}} t^{\alpha - 1} (1 - t)^{x} dt
433+
434+
which is the regularized incomplete beta function.
435+
This implementation uses :func:`~jax.scipy.special.betainc`.
436+
"""
379437
bt = betainc(self.concentration, value + 1.0, self.rate / (self.rate + 1.0))
380438
return bt
381439

@@ -386,7 +444,18 @@ def NegativeBinomial(
386444
logits: Optional[ArrayLike] = None,
387445
*,
388446
validate_args: Optional[bool] = None,
389-
):
447+
) -> GammaPoisson:
448+
"""Factory function for Negative Binomial distribution.
449+
450+
:param int total_count: Number of successful trials.
451+
:param Optional[ArrayLike] probs: Probability of success for each trial, by default None
452+
:param Optional[ArrayLike] logits: Log-odds of success for each trial, by default None
453+
:param Optional[bool] validate_args: Whether to validate the parameters, by default None
454+
:return: An instance of :class:`NegativeBinomialProbs` or
455+
:class:`NegativeBinomialLogits` depending on the provided parameters.
456+
:rtype: GammaPoisson
457+
:raises ValueError: If neither :code:`probs` nor :code:`logits` is specified.
458+
"""
390459
if probs is not None:
391460
return NegativeBinomialProbs(total_count, probs, validate_args=validate_args)
392461
elif logits is not None:
@@ -396,6 +465,14 @@ def NegativeBinomial(
396465

397466

398467
class NegativeBinomialProbs(GammaPoisson):
468+
r"""Negative Binomial distribution parameterized by :code:`total_count` (:math:`r`)
469+
and :code:`probs` (:math:`p`). It is implemented as a
470+
:math:`\displaystyle\mathrm{GammaPoisson}(n, \frac{1}{p} - 1)` distribution.
471+
472+
:param total_count: Number of successful trials (:math:`r`).
473+
:param probs: Probability of success for each trial (:math:`p`).
474+
"""
475+
399476
arg_constraints = {
400477
"total_count": constraints.positive,
401478
"probs": constraints.unit_interval,
@@ -416,6 +493,15 @@ def __init__(
416493

417494

418495
class NegativeBinomialLogits(GammaPoisson):
496+
r"""Negative Binomial distribution parameterized by :code:`total_count` (:math:`r`)
497+
and :code:`logits` (:math:`\displaystyle\mathrm{logits}(p)=\log \frac{p}{1-p}`). It
498+
is implemented as a :math:`\mathrm{GammaPoisson}(n, \exp(-\mathrm{logits}(p)))`
499+
distribution.
500+
501+
:param total_count: Number of successful trials.
502+
:param logits: Log-odds of success for each trial (:math:`\ln \frac{p}{1-p}`).
503+
"""
504+
419505
arg_constraints = {
420506
"total_count": constraints.positive,
421507
"logits": constraints.real,
@@ -436,6 +522,14 @@ def __init__(
436522

437523
@validate_sample
438524
def log_prob(self, value: ArrayLike) -> ArrayLike:
525+
r"""If :math:`X \sim \mathrm{NegativeBinomial}(r, \mathrm{logits}(p))`, then the log
526+
probability mass function is:
527+
528+
.. math::
529+
\ln P(X = k) = -r \ln(1+\exp(\mathrm{logits}(p)))
530+
- k \ln(1+\exp(-\mathrm{logits}(p)))
531+
- \ln\Gamma(1 + k) - \ln\Gamma(\alpha) + \ln\Gamma(k + \alpha)
532+
"""
439533
return -(
440534
self.total_count * nn.softplus(self.logits)
441535
+ value * nn.softplus(-self.logits)
@@ -444,8 +538,11 @@ def log_prob(self, value: ArrayLike) -> ArrayLike:
444538

445539

446540
class NegativeBinomial2(GammaPoisson):
447-
"""
448-
Another parameterization of GammaPoisson with `rate` is replaced by `mean`.
541+
r"""If :math:`X \sim \mathrm{NegativeBinomial2}(\mu, \alpha)`, then
542+
:math:`X \sim \mathrm{GammaPoisson}(\alpha, \frac{\alpha}{\mu})`.
543+
544+
:param numpy.ndarray mean: mean parameter (:math:`\mu`).
545+
:param numpy.ndarray concentration: concentration parameter (:math:`\alpha`).
449546
"""
450547

451548
arg_constraints = {

numpyro/distributions/continuous.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -572,6 +572,14 @@ def entropy(self) -> ArrayLike:
572572

573573

574574
class Gamma(Distribution):
575+
r"""Implementation of the `Gamma distribution <https://en.wikipedia.org/wiki/Gamma_distribution>`_,
576+
:math:`\mathrm{Gamma}(\alpha, \lambda)`, where, :math:`\alpha` is the concentration
577+
and :math:`\lambda` is the rate.
578+
579+
:param ArrayLike concentration: concentration parameter :math:`\alpha` (also known as shape parameter).
580+
:param ArrayLike rate: rate parameter :math:`\lambda` (inverse scale parameter).
581+
"""
582+
575583
arg_constraints = {
576584
"concentration": constraints.positive,
577585
"rate": constraints.positive,
@@ -595,12 +603,26 @@ def __init__(
595603
def sample(
596604
self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = ()
597605
) -> ArrayLike:
606+
r"""Method to generate samples :math:`X \sim \mathrm{Gamma}(\alpha, \lambda)`.
607+
It uses :func:`~jax.random.gamma` under the hood to generate samples.
608+
"""
598609
assert is_prng_key(key)
599610
shape = sample_shape + self.batch_shape + self.event_shape
600611
return random.gamma(key, self.concentration, shape=shape) / self.rate
601612

602613
@validate_sample
603614
def log_prob(self, value: ArrayLike) -> ArrayLike:
615+
r"""If :math:`X \sim \mathrm{Gamma}(\alpha, \lambda)`, then
616+
617+
.. math::
618+
619+
f_X(x\mid \alpha, \lambda) =
620+
\frac{\lambda^{\alpha} x^{\alpha - 1} e^{-\lambda x}}{\Gamma(\alpha)},
621+
\quad x > 0
622+
623+
It uses :func:`~jax.scipy.special.gammaln` to compute the logarithm of the
624+
gamma function.
625+
"""
604626
normalize_term = gammaln(self.concentration) - self.concentration * jnp.log(
605627
self.rate
606628
)
@@ -612,19 +634,58 @@ def log_prob(self, value: ArrayLike) -> ArrayLike:
612634

613635
@property
614636
def mean(self) -> ArrayLike:
637+
r"""If :math:`X \sim \mathrm{Gamma}(\alpha, \lambda)`, then
638+
639+
.. math::
640+
\mathbb{E}[X] = \frac{\alpha}{\lambda}
641+
"""
615642
return self.concentration / self.rate
616643

617644
@property
618645
def variance(self) -> ArrayLike:
646+
r"""If :math:`X \sim \mathrm{Gamma}(\alpha, \lambda)`, then
647+
648+
.. math::
649+
\mathrm{Var}[X] = \frac{\alpha}{\lambda^2}
650+
"""
619651
return self.concentration / jnp.power(self.rate, 2)
620652

621653
def cdf(self, x):
654+
r"""If :math:`X \sim \mathrm{Gamma}(\alpha, \lambda)`, then
655+
656+
.. math::
657+
F_X(x \mid \alpha, \lambda) = \frac{1}{\Gamma(\alpha)}
658+
\gamma\left(\alpha, \lambda x\right)
659+
660+
where, :math:`\gamma(\cdot,\cdot)` is the `lower incomplete gamma function <https://en.wikipedia.org/wiki/Incomplete_gamma_function>`_.
661+
This method uses the regularized incomplete gamma function,
662+
which is implemented as :func:`~jax.scipy.special.gammainc`.
663+
"""
622664
return gammainc(self.concentration, self.rate * x)
623665

624666
def icdf(self, q: ArrayLike) -> ArrayLike:
667+
r"""If :math:`X \sim \mathrm{Gamma}(\alpha, \lambda)`, then
668+
669+
.. math::
670+
F^{-1}_X(q \mid \alpha, \lambda) = \frac{1}{\lambda}
671+
\gamma^{-1}\left(\alpha, q \Gamma(\alpha)\right)
672+
673+
where, :math:`\gamma^{-1}(\cdot,\cdot)` is the inverse of the lower incomplete gamma function.
674+
This method uses the regularized incomplete gamma inverse function,
675+
which is implemented as :func:`~numpyro.distributions.util.gammaincinv`.
676+
"""
625677
return gammaincinv(self.concentration, q) / self.rate
626678

627679
def entropy(self) -> ArrayLike:
680+
r"""If :math:`X \sim \mathrm{Gamma}(\alpha, \lambda)`, then
681+
682+
.. math::
683+
H[X] = \alpha - \ln(\lambda) + \ln\Gamma(\alpha)
684+
+ (1 - \alpha) \psi(\alpha)
685+
686+
where, :math:`\psi(\cdot)` is the `digamma function <https://en.wikipedia.org/wiki/Digamma_function>`_.
687+
This method uses the digamma function, which is implemented as :func:`~jax.scipy.special.digamma`.
688+
"""
628689
return (
629690
self.concentration
630691
- jnp.log(self.rate)

numpyro/distributions/truncated.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -535,13 +535,19 @@ def support(self) -> ConstraintT:
535535
@validate_sample
536536
def log_prob(self, value: ArrayLike) -> ArrayLike:
537537
r"""Logarithmic probability distribution:
538+
538539
Z inequal minus one:
540+
539541
.. math::
540-
(x^\alpha) (\alpha + 1)/(b^(\alpha + 1) - a^(\alpha + 1))
542+
543+
\frac{(\alpha + 1)x^\alpha}{b^{\alpha + 1} - a^{\alpha + 1}}
541544
542545
Z equal minus one:
546+
543547
.. math::
544-
(x^\alpha)/(log(b) - log(a))
548+
549+
\frac{x^\alpha}{\log(b) - \log(a)}
550+
545551
Derivations are calculated by Wolfram Alpha via the Jacobian matrix accordingly.
546552
"""
547553

0 commit comments

Comments
 (0)