@@ -120,35 +120,13 @@ class BetaNegativeBinomial(Distribution):
    overdispersed count data. It arises as the marginal distribution when integrating out
    the success probability in a negative binomial model with a beta prior.

-    .. math::
-
-        p &\sim \mathrm{Beta}(\alpha, \beta) \\
-        X \mid p &\sim \mathrm{NegativeBinomial}(n, p)
-
-    The probability mass function is:
-
-    .. math::
-
-        P(X = k) = \binom{n + k - 1}{k} \frac{B(\alpha + k, \beta + n)}{B(\alpha, \beta)}
-
-    where :math:`B(\cdot, \cdot)` is the beta function.
-
    :param numpy.ndarray concentration1: 1st concentration parameter (alpha) for the
        Beta distribution.
    :param numpy.ndarray concentration0: 2nd concentration parameter (beta) for the
        Beta distribution.
    :param numpy.ndarray n: positive number of successes parameter for the negative
        binomial distribution.

-    **Properties**
-
-    - **Mean**: :math:`\frac{n \cdot \alpha}{\beta - 1}` for :math:`\beta > 1`, else undefined.
-    - **Variance**: for :math:`\beta > 2`, else undefined.
-
-    .. math::
-
-        \frac{n \cdot \alpha \cdot (n + \beta - 1) \cdot (\alpha + \beta - 1)}{(\beta - 1)^2 \cdot (\beta - 2)}
-
    **References**

    [1] https://en.wikipedia.org/wiki/Beta_negative_binomial_distribution
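
For orientation, a short usage sketch of the class as specified above. It assumes `BetaNegativeBinomial` is exported from `numpyro.distributions`; the import path and keyword names are taken from this diff, not verified against a released version:

```python
import jax
import numpyro.distributions as dist  # assumed export location for BetaNegativeBinomial

# alpha = 2, beta = 3, n = 10, per the parameter list above
d = dist.BetaNegativeBinomial(concentration1=2.0, concentration0=3.0, n=10.0)
key = jax.random.PRNGKey(0)
samples = d.sample(key, (1_000,))  # compound Beta -> NegativeBinomial draws
log_p = d.log_prob(samples)        # closed-form pmf; see log_prob below
print(d.mean)                      # n * alpha / (beta - 1) = 10.0 here
```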
@@ -187,13 +165,36 @@ def __init__(
    def sample(
        self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = ()
    ) -> ArrayLike:
+        r"""If :math:`X \sim \mathrm{BetaNegativeBinomial}(\alpha, \beta, n)`, then the
+        sampling procedure is:
+
+        .. math::
+            \begin{align*}
+                p &\sim \mathrm{Beta}(\alpha, \beta) \\
+                X \mid p &\sim \mathrm{NegativeBinomial}(n, p)
+            \end{align*}
+
+        It uses :class:`~numpyro.distributions.continuous.Beta` to generate samples
+        from the Beta distribution and
+        :class:`~numpyro.distributions.conjugate.NegativeBinomialProbs` to generate
+        samples from the Negative Binomial distribution.
+        """
        assert is_prng_key(key)
        key_beta, key_nb = random.split(key)
        probs = self._beta.sample(key_beta, sample_shape)
        return NegativeBinomialProbs(total_count=self.n, probs=probs).sample(key_nb)

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
+        r"""If :math:`X \sim \mathrm{BetaNegativeBinomial}(\alpha, \beta, n)`, then the
+        probability mass function is:
+
+        .. math::
+            P(X = k) = \binom{n + k - 1}{k} \frac{B(\alpha + k, \beta + n)}{B(\alpha, \beta)}
+
+        To ensure differentiability, the binomial coefficient is computed using
+        gamma functions.
+        """
        return (
            gammaln(self.n + value)
            - gammaln(self.n)
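
As a sanity check on the pmf in the docstring, one can recompute it with `scipy.special` and confirm it normalizes. The helper `bnb_log_pmf` below is a hypothetical re-derivation for illustration, not library code:

```python
import numpy as np
from scipy.special import betaln, gammaln

def bnb_log_pmf(k, alpha, beta, n):
    # log C(n + k - 1, k) via gamma functions, as the docstring notes
    log_binom = gammaln(n + k) - gammaln(n) - gammaln(k + 1)
    # log [ B(alpha + k, beta + n) / B(alpha, beta) ]
    return log_binom + betaln(alpha + k, beta + n) - betaln(alpha, beta)

k = np.arange(2_000)
print(np.exp(bnb_log_pmf(k, alpha=2.0, beta=3.0, n=10.0)).sum())  # ~ 1.0
```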
@@ -204,6 +205,14 @@ def log_prob(self, value: ArrayLike) -> ArrayLike:

    @property
    def mean(self) -> ArrayLike:
+        r"""If :math:`X \sim \mathrm{BetaNegativeBinomial}(\alpha, \beta, n)` and
+        :math:`\beta > 1`, then the mean is:
+
+        .. math::
+            \mathbb{E}[X] = \frac{n\alpha}{\beta - 1};
+
+        otherwise, it is undefined.
+        """
        return jnp.where(
            self.concentration0 > 1,
            self.n * self.concentration1 / (self.concentration0 - 1),
@@ -212,6 +221,15 @@ def mean(self) -> ArrayLike:

    @property
    def variance(self) -> ArrayLike:
+        r"""If :math:`X \sim \mathrm{BetaNegativeBinomial}(\alpha, \beta, n)` and
+        :math:`\beta > 2`, then the variance is:
+
+        .. math::
+            \mathrm{Var}[X] =
+                \frac{n\alpha (n + \beta - 1)(\alpha + \beta - 1)}{(\beta - 1)^2 \cdot (\beta - 2)};
+
+        otherwise, it is undefined.
+        """
        alpha = self.concentration1
        beta = self.concentration0
        n = self.n
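
A quick Monte Carlo check of the two closed forms (same unreleased-class caveat as the sketch above; beta = 4 here so both moments exist):

```python
import jax
import numpyro.distributions as dist

d = dist.BetaNegativeBinomial(concentration1=2.0, concentration0=4.0, n=10.0)
x = d.sample(jax.random.PRNGKey(0), (200_000,))
print(x.mean(), d.mean)     # both ~ n * alpha / (beta - 1) = 6.67
print(x.var(), d.variance)  # both ~ 72.2 by the formula above
```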
@@ -326,7 +344,7 @@ class GammaPoisson(Distribution):
    drawn from a :class:`~numpyro.distributions.Gamma` distribution.

    :param numpy.ndarray concentration: shape parameter (alpha) of the Gamma distribution.
-    :param numpy.ndarray rate: rate parameter (beta) for the Gamma distribution.
+    :param numpy.ndarray rate: rate parameter (lambda) for the Gamma distribution.
    """

    arg_constraints = {
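
The compound construction can be reproduced by hand with the public `Gamma` and `Poisson` distributions, which is essentially what `sample` below does. A minimal sketch:

```python
import jax
import numpyro.distributions as dist

key_gamma, key_poisson, key_direct = jax.random.split(jax.random.PRNGKey(0), 3)

# Explicit Gamma -> Poisson chain ...
theta = dist.Gamma(3.0, 0.5).sample(key_gamma, (100_000,))
x_manual = dist.Poisson(theta).sample(key_poisson)

# ... versus the compound distribution directly
x = dist.GammaPoisson(3.0, 0.5).sample(key_direct, (100_000,))
print(x_manual.mean(), x.mean())  # both ~ alpha / lambda = 6.0
```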
@@ -352,13 +370,33 @@ def __init__(
    def sample(
        self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = ()
    ) -> ArrayLike:
+        r"""If :math:`X \sim \mathrm{GammaPoisson}(\alpha, \lambda)`, then the sampling
+        procedure is:
+
+        .. math::
+            \begin{align*}
+                \theta &\sim \mathrm{Gamma}(\alpha, \lambda) \\
+                X \mid \theta &\sim \mathrm{Poisson}(\theta)
+            \end{align*}
+
+        It uses :class:`~numpyro.distributions.continuous.Gamma` to generate samples
+        from the Gamma distribution and
+        :class:`~numpyro.distributions.discrete.Poisson` to generate samples from the
+        Poisson distribution.
+        """
        assert is_prng_key(key)
        key_gamma, key_poisson = random.split(key)
        rate = self._gamma.sample(key_gamma, sample_shape)
        return Poisson(rate).sample(key_poisson)

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
+        r"""If :math:`X \sim \mathrm{GammaPoisson}(\alpha, \lambda)`, then the
+        probability mass function is:
+
+        .. math::
+            p_{X}(k) = \frac{\lambda^\alpha}{(\alpha + k)(1+\lambda)^{\alpha + k}\mathrm{B}(\alpha, k + 1)}
+        """
        post_value = self.concentration + value
        return (
            -betaln(self.concentration, value + 1)
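
This is the standard negative binomial pmf under the substitution p = lambda / (1 + lambda), which can be cross-checked against SciPy (a sketch, assuming SciPy is available alongside JAX):

```python
import jax.numpy as jnp
import numpyro.distributions as dist
from scipy.stats import nbinom

alpha, lam = 3.0, 0.5
k = jnp.arange(20)
ours = dist.GammaPoisson(alpha, lam).log_prob(k)
ref = nbinom(alpha, lam / (1.0 + lam)).logpmf(k)  # NB with p = lambda / (1 + lambda)
print(jnp.allclose(ours, jnp.asarray(ref), atol=1e-5))  # expect True
```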
@@ -369,13 +407,33 @@ def log_prob(self, value: ArrayLike) -> ArrayLike:

    @property
    def mean(self) -> ArrayLike:
+        r"""If :math:`X \sim \mathrm{GammaPoisson}(\alpha, \lambda)`, then the mean is:
+
+        .. math::
+            \mathbb{E}[X] = \frac{\alpha}{\lambda}
+        """
        return self.concentration / self.rate

    @property
    def variance(self) -> ArrayLike:
+        r"""If :math:`X \sim \mathrm{GammaPoisson}(\alpha, \lambda)`, then the variance is:
+
+        .. math::
+            \mathrm{Var}[X] = \frac{\alpha}{\lambda^2}(1 + \lambda)
+        """
        return self.concentration / jnp.square(self.rate) * (1 + self.rate)

    def cdf(self, value: ArrayLike) -> ArrayLike:
+        r"""If :math:`X \sim \mathrm{GammaPoisson}(\alpha, \lambda)`, then the cumulative
+        distribution function is:
+
+        .. math::
+            F_{X}(x) = \frac{1}{\mathrm{B}(\alpha, x + 1)}
+                \int_{0}^{\frac{\lambda}{1 + \lambda}} t^{\alpha - 1} (1 - t)^{x} dt,
+
+        i.e. the regularized incomplete beta function
+        :math:`I_{\lambda/(1+\lambda)}(\alpha, x + 1)`.
+        This implementation uses :func:`~jax.scipy.special.betainc`.
+        """
        bt = betainc(self.concentration, value + 1.0, self.rate / (self.rate + 1.0))
        return bt

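
And the cdf agrees with a cumulative sum of the pmf (a quick sketch):

```python
import jax.numpy as jnp
import numpyro.distributions as dist

d = dist.GammaPoisson(3.0, 0.5)
k = jnp.arange(30)
pmf_cumsum = jnp.cumsum(jnp.exp(d.log_prob(k)))
print(jnp.allclose(d.cdf(k), pmf_cumsum, atol=1e-5))  # expect True
```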
@@ -386,7 +444,18 @@ def NegativeBinomial(
    logits: Optional[ArrayLike] = None,
    *,
    validate_args: Optional[bool] = None,
-):
+) -> GammaPoisson:
+    """Factory function for the Negative Binomial distribution.
+
+    :param int total_count: Number of successful trials.
+    :param Optional[ArrayLike] probs: Probability of success for each trial; defaults to None.
+    :param Optional[ArrayLike] logits: Log-odds of success for each trial; defaults to None.
+    :param Optional[bool] validate_args: Whether to validate the parameters; defaults to None.
+    :return: An instance of :class:`NegativeBinomialProbs` or
+        :class:`NegativeBinomialLogits`, depending on which parameter is provided.
+    :rtype: GammaPoisson
+    :raises ValueError: If neither :code:`probs` nor :code:`logits` is specified.
+    """
    if probs is not None:
        return NegativeBinomialProbs(total_count, probs, validate_args=validate_args)
    elif logits is not None:
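
Factory usage (a sketch): exactly one of `probs` or `logits` selects the concrete subclass, and passing neither raises.

```python
import numpyro.distributions as dist

d1 = dist.NegativeBinomial(total_count=10, probs=0.25)
d2 = dist.NegativeBinomial(total_count=10, logits=0.0)
print(type(d1).__name__)  # NegativeBinomialProbs
print(type(d2).__name__)  # NegativeBinomialLogits
# dist.NegativeBinomial(total_count=10)  # would raise ValueError
```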
@@ -396,6 +465,14 @@ def NegativeBinomial(


class NegativeBinomialProbs(GammaPoisson):
+    r"""Negative Binomial distribution parameterized by :code:`total_count` (:math:`r`)
+    and :code:`probs` (:math:`p`). It is implemented as a
+    :math:`\displaystyle\mathrm{GammaPoisson}(r, \frac{1}{p} - 1)` distribution.
+
+    :param total_count: Number of successful trials (:math:`r`).
+    :param probs: Probability of success for each trial (:math:`p`).
+    """
+
    arg_constraints = {
        "total_count": constraints.positive,
        "probs": constraints.unit_interval,
@@ -416,6 +493,15 @@ def __init__(


class NegativeBinomialLogits(GammaPoisson):
+    r"""Negative Binomial distribution parameterized by :code:`total_count` (:math:`r`)
+    and :code:`logits` (:math:`\mathrm{logits}(p) = \log \frac{p}{1-p}`). It is
+    implemented as a :math:`\mathrm{GammaPoisson}(r, \exp(-\mathrm{logits}(p)))`
+    distribution.
+
+    :param total_count: Number of successful trials (:math:`r`).
+    :param logits: Log-odds of success for each trial (:math:`\log \frac{p}{1-p}`).
+    """
+
    arg_constraints = {
        "total_count": constraints.positive,
        "logits": constraints.real,
@@ -436,6 +522,14 @@ def __init__(

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
+        r"""If :math:`X \sim \mathrm{NegativeBinomial}(r, \mathrm{logits}(p))`, then the log
+        probability mass function is:
+
+        .. math::
+            \ln P(X = k) = -r \ln(1+\exp(\mathrm{logits}(p)))
+                - k \ln(1+\exp(-\mathrm{logits}(p)))
+                - \ln\Gamma(1 + k) - \ln\Gamma(r) + \ln\Gamma(k + r)
+        """
        return -(
            self.total_count * nn.softplus(self.logits)
            + value * nn.softplus(-self.logits)
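
Since 1/p - 1 = exp(-logits(p)), the `probs` and `logits` parameterizations map to the same GammaPoisson and should agree on `log_prob`; a quick check (a sketch):

```python
import jax.numpy as jnp
import numpyro.distributions as dist

p, r = 0.3, 5.0
k = jnp.arange(15)
lp_probs = dist.NegativeBinomialProbs(total_count=r, probs=p).log_prob(k)
lp_logits = dist.NegativeBinomialLogits(
    total_count=r, logits=jnp.log(p / (1.0 - p))
).log_prob(k)
print(jnp.allclose(lp_probs, lp_logits, atol=1e-5))  # expect True
```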
@@ -444,8 +538,11 @@ def log_prob(self, value: ArrayLike) -> ArrayLike:


class NegativeBinomial2(GammaPoisson):
-    """
-    Another parameterization of GammaPoisson with `rate` is replaced by `mean`.
+    r"""If :math:`X \sim \mathrm{NegativeBinomial2}(\mu, \alpha)`, then
+    :math:`X \sim \mathrm{GammaPoisson}(\alpha, \frac{\alpha}{\mu})`.
+
+    :param numpy.ndarray mean: mean parameter (:math:`\mu`).
+    :param numpy.ndarray concentration: concentration parameter (:math:`\alpha`).
    """

    arg_constraints = {
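
With rate alpha/mu, the GammaPoisson moments reduce to the familiar mean/overdispersion form E[X] = mu and Var[X] = mu + mu^2 / alpha; a quick check (a sketch):

```python
import numpyro.distributions as dist

mu, alpha = 4.0, 2.0
d = dist.NegativeBinomial2(mean=mu, concentration=alpha)
print(d.mean)      # mu = 4.0
print(d.variance)  # mu + mu**2 / alpha = 12.0
```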