From 4e10c512a41ba8408cbf2af045c20e7aca78c7c2 Mon Sep 17 00:00:00 2001 From: Meesum Qazalbash Date: Thu, 18 Dec 2025 01:10:38 +0500 Subject: [PATCH 1/7] fix: add docstring to `validate_sample` wrapper --- numpyro/distributions/util.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/numpyro/distributions/util.py b/numpyro/distributions/util.py index f690bbecf..c902f3fef 100644 --- a/numpyro/distributions/util.py +++ b/numpyro/distributions/util.py @@ -780,6 +780,8 @@ def wrapper(self, *args, **kwargs): log_prob = jnp.where(mask, log_prob, -jnp.inf) return log_prob + wrapper.__doc__ = log_prob_fn.__doc__ + return wrapper From ec60ac13fdcd924cba9e965f09cc33eaf37d5739 Mon Sep 17 00:00:00 2001 From: Meesum Qazalbash Date: Thu, 18 Dec 2025 01:34:45 +0500 Subject: [PATCH 2/7] docs: add detailed mathematical descriptions and docstrings for Gamma distribution methods --- numpyro/distributions/continuous.py | 61 +++++++++++++++++++++++++++++ 1 file changed, 61 insertions(+) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 254a8e60d..f59a45ed1 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -572,6 +572,14 @@ def entropy(self) -> ArrayLike: class Gamma(Distribution): + r"""Implementation of the `Gamma distribution `_, + :math:`\mathrm{Gamma}(\alpha, \lambda)`, where, :math:`\alpha` is the concentration + and :math:`\lambda` is the rate. + + :param ArrayLike concentration: concentration parameter :math:`\alpha` (also known as shape parameter). + :param ArrayLike rate: rate parameter :math:`\lambda` (inverse scale parameter). + """ + arg_constraints = { "concentration": constraints.positive, "rate": constraints.positive, @@ -595,12 +603,26 @@ def __init__( def sample( self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () ) -> ArrayLike: + r"""Method to generate samples :math:`X \sim \mathrm{Gamma}(\alpha, \lambda)`. + It uses :func:`~jax.random.gamma` under the hood to generate samples. + """ assert is_prng_key(key) shape = sample_shape + self.batch_shape + self.event_shape return random.gamma(key, self.concentration, shape=shape) / self.rate @validate_sample def log_prob(self, value: ArrayLike) -> ArrayLike: + r"""If :math:`X \sim \mathrm{Gamma}(\alpha, \lambda)`, then + + .. math:: + + f_X(x\mid \alpha, \lambda) = + \frac{\lambda^{\alpha} x^{\alpha - 1} e^{-\lambda x}}{\Gamma(\alpha)}, + \quad x > 0 + + It uses :func:`~jax.scipy.special.gammaln` to compute the logarithm of the + gamma function. + """ normalize_term = gammaln(self.concentration) - self.concentration * jnp.log( self.rate ) @@ -612,19 +634,58 @@ def log_prob(self, value: ArrayLike) -> ArrayLike: @property def mean(self) -> ArrayLike: + r"""If :math:`X \sim \mathrm{Gamma}(\alpha, \lambda)`, then + + .. math:: + \mathbb{E}[X] = \frac{\alpha}{\lambda} + """ return self.concentration / self.rate @property def variance(self) -> ArrayLike: + r"""If :math:`X \sim \mathrm{Gamma}(\alpha, \lambda)`, then + + .. math:: + \mathrm{Var}[X] = \frac{\alpha}{\lambda^2} + """ return self.concentration / jnp.power(self.rate, 2) def cdf(self, x): + r"""If :math:`X \sim \mathrm{Gamma}(\alpha, \lambda)`, then + + .. math:: + F_X(x \mid \alpha, \lambda) = \frac{1}{\Gamma(\alpha)} + \gamma\left(\alpha, \lambda x\right) + + where, :math:`\gamma(\cdot,\cdot)` is the `lower incomplete gamma function `_. + This method uses regularized incomplete gamma function, + which is implemented as :func:`~jax.scipy.special.gammainc`. + """ return gammainc(self.concentration, self.rate * x) def icdf(self, q: ArrayLike) -> ArrayLike: + r"""If :math:`X \sim \mathrm{Gamma}(\alpha, \lambda)`, then + + .. math:: + F^{-1}_X(q \mid \alpha, \lambda) = \frac{1}{\lambda} + \gamma^{-1}\left(\alpha, q \Gamma(\alpha)\right) + + where, :math:`\gamma^{-1}(\cdot,\cdot)` is the inverse of the lower incomplete gamma function. + This method uses regularized incomplete gamma inverse function, + which is implemented as :func:`~numpyro.distributions.util.gammaincinv`. + """ return gammaincinv(self.concentration, q) / self.rate def entropy(self) -> ArrayLike: + r"""If :math:`X \sim \mathrm{Gamma}(\alpha, \lambda)`, then + + .. math:: + H[X] = \alpha - \ln(\lambda) + \ln\Gamma(\alpha) + + (1 - \alpha) \psi(\alpha) + + where, :math:`\psi(\cdot)` is the `digamma function `_. + This methods uses which is implemented as :func:`~jax.scipy.special.digamma`. + """ return ( self.concentration - jnp.log(self.rate) From bc651e016b8a44debb7b4827cb54a46dafa35e25 Mon Sep 17 00:00:00 2001 From: Meesum Qazalbash Date: Fri, 19 Dec 2025 21:43:13 +0500 Subject: [PATCH 3/7] docs: update `GammaPoisson` class docstrings with detailed mathematical descriptions --- numpyro/distributions/conjugate.py | 42 +++++++++++++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/numpyro/distributions/conjugate.py b/numpyro/distributions/conjugate.py index 932075eb1..a5f0e86a2 100644 --- a/numpyro/distributions/conjugate.py +++ b/numpyro/distributions/conjugate.py @@ -326,7 +326,7 @@ class GammaPoisson(Distribution): drawn from a :class:`~numpyro.distributions.Gamma` distribution. :param numpy.ndarray concentration: shape parameter (alpha) of the Gamma distribution. - :param numpy.ndarray rate: rate parameter (beta) for the Gamma distribution. + :param numpy.ndarray rate: rate parameter (rate) for the Gamma distribution. """ arg_constraints = { @@ -352,6 +352,20 @@ def __init__( def sample( self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () ) -> ArrayLike: + r"""If :math:`X \sim \mathrm{GammaPoisson}(\alpha, \lambda)`, then the sampling + procedure is: + + .. math:: + \begin{align*} + \theta &\sim \mathrm{Gamma}(\alpha, \lambda) \\ + X \mid \theta &\sim \mathrm{Poisson}(\theta) + \end{align*} + + It uses :class:`~numpyro.distributions.continuous.Gamma` to generate samples + from the Gamma distribution and + :class:`~numpyro.distributions.continuous.Poisson` to generate samples from the + Poisson distribution. + """ assert is_prng_key(key) key_gamma, key_poisson = random.split(key) rate = self._gamma.sample(key_gamma, sample_shape) @@ -359,6 +373,12 @@ def sample( @validate_sample def log_prob(self, value: ArrayLike) -> ArrayLike: + r"""If :math:`X \sim \mathrm{GammaPoisson}(\alpha, \lambda)`, then the log + probability mass function is: + + .. math:: + p_{X}(k) = \frac{\lambda^\alpha}{(\alpha + k)(1+\lambda)^{\alpha + k}\mathrm{B}(\alpha, k + 1)} + """ post_value = self.concentration + value return ( -betaln(self.concentration, value + 1) @@ -369,13 +389,33 @@ def log_prob(self, value: ArrayLike) -> ArrayLike: @property def mean(self) -> ArrayLike: + r"""If :math:`X \sim \mathrm{GammaPoisson}(\alpha, \lambda)`, then the mean is: + + .. math:: + \mathbb{E}[X] = \frac{\alpha}{\lambda} + """ return self.concentration / self.rate @property def variance(self) -> ArrayLike: + r"""If :math:`X \sim \mathrm{GammaPoisson}(\alpha, \lambda)`, then the variance is: + + .. math:: + \mathrm{Var}[X] = \frac{\alpha}{\lambda^2}(1 + \lambda) + """ return self.concentration / jnp.square(self.rate) * (1 + self.rate) def cdf(self, value: ArrayLike) -> ArrayLike: + r"""If :math:`X \sim \mathrm{GammaPoisson}(\alpha, \lambda)`, then the cumulative + distribution function is: + + .. math:: + F_{X}(x) = \frac{1}{\mathrm{B}(\alpha, x + 1)} + \int_{0}^{\frac{\lambda}{1 + \lambda}} t^{\alpha - 1} (1 - t)^{x} dt + + which is the regularized incomplete beta function. + This implementation uses :func:`~jax.scipy.special.betainc`. + """ bt = betainc(self.concentration, value + 1.0, self.rate / (self.rate + 1.0)) return bt From 4f5c11e3e55a604429eab16098ae99ec6326b63f Mon Sep 17 00:00:00 2001 From: Meesum Qazalbash Date: Fri, 19 Dec 2025 21:53:20 +0500 Subject: [PATCH 4/7] docs: enhance docstrings for `BetaNegativeBinomial class` with detailed mathematical descriptions --- numpyro/distributions/conjugate.py | 62 +++++++++++++++++++----------- 1 file changed, 40 insertions(+), 22 deletions(-) diff --git a/numpyro/distributions/conjugate.py b/numpyro/distributions/conjugate.py index a5f0e86a2..008f9035a 100644 --- a/numpyro/distributions/conjugate.py +++ b/numpyro/distributions/conjugate.py @@ -120,19 +120,6 @@ class BetaNegativeBinomial(Distribution): overdispersed count data. It arises as the marginal distribution when integrating out the success probability in a negative binomial model with a beta prior. - .. math:: - - p &\sim \mathrm{Beta}(\alpha, \beta) \\ - X \mid p &\sim \mathrm{NegativeBinomial}(n, p) - - The probability mass function is: - - .. math:: - - P(X = k) = \binom{n + k - 1}{k} \frac{B(\alpha + k, \beta + n)}{B(\alpha, \beta)} - - where :math:`B(\cdot, \cdot)` is the beta function. - :param numpy.ndarray concentration1: 1st concentration parameter (alpha) for the Beta distribution. :param numpy.ndarray concentration0: 2nd concentration parameter (beta) for the @@ -140,15 +127,6 @@ class BetaNegativeBinomial(Distribution): :param numpy.ndarray n: positive number of successes parameter for the negative binomial distribution. - **Properties** - - - **Mean**: :math:`\frac{n \cdot \alpha}{\beta - 1}` for :math:`\beta > 1`, else undefined. - - **Variance**: for :math:`\beta > 2`, else undefined. - - .. math:: - - \frac{n \cdot \alpha \cdot (n + \beta - 1) \cdot (\alpha + \beta - 1)}{(\beta - 1)^2 \cdot (\beta - 2)} - **References** [1] https://en.wikipedia.org/wiki/Beta_negative_binomial_distribution @@ -187,6 +165,20 @@ def __init__( def sample( self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () ) -> ArrayLike: + r"""If :math:`X \sim \mathrm{BetaNegativeBinomial}(\alpha, \beta, n)`, then the sampling + procedure is: + + .. math:: + \begin{align*} + p &\sim \mathrm{Beta}(\alpha, \beta) \\ + X \mid p &\sim \mathrm{NegativeBinomial}(n, p) + \end{align*} + + It uses :class:`~numpyro.distributions.continuous.Beta` to generate samples + from the Beta distribution and + :class:`~numpyro.distributions.discrete.NegativeBinomialProbs` to generate samples + from the Negative Binomial distribution. + """ assert is_prng_key(key) key_beta, key_nb = random.split(key) probs = self._beta.sample(key_beta, sample_shape) @@ -194,6 +186,15 @@ def sample( @validate_sample def log_prob(self, value: ArrayLike) -> ArrayLike: + r"""If :math:`X \sim \mathrm{BetaNegativeBinomial}(\alpha, \beta, n)`, then the log + probability mass function is: + + .. math:: + P(X = k) = \binom{n + k - 1}{k} \frac{B(\alpha + k, \beta + n)}{B(\alpha, \beta)} + + To ensure differentiability, the binomial coefficient is computed using + gamma functions. + """ return ( gammaln(self.n + value) - gammaln(self.n) @@ -204,6 +205,14 @@ def log_prob(self, value: ArrayLike) -> ArrayLike: @property def mean(self) -> ArrayLike: + r"""If :math:`X \sim \mathrm{BetaNegativeBinomial}(\alpha, \beta, n)` and + :math:`\beta > 1`, then the mean is: + + .. math:: + \mathbb{E}[X] = \frac{n\alpha}{\beta - 1}, + + otherwise, the its undefined. + """ return jnp.where( self.concentration0 > 1, self.n * self.concentration1 / (self.concentration0 - 1), @@ -212,6 +221,15 @@ def mean(self) -> ArrayLike: @property def variance(self) -> ArrayLike: + r"""If :math:`X \sim \mathrm{BetaNegativeBinomial}(\alpha, \beta, n)` and + :math:`\beta > 2`, then the variance is: + + .. math:: + \mathrm{Var}[X] = + \frac{n\alpha (n + \beta - 1)(\alpha + \beta - 1)}{(\beta - 1)^2 \cdot (\beta - 2)}, + + otherwise, the its undefined. + """ alpha = self.concentration1 beta = self.concentration0 n = self.n From c24f47b478c8b8c9d69f6d344f9230c6c9243403 Mon Sep 17 00:00:00 2001 From: Meesum Qazalbash Date: Fri, 19 Dec 2025 23:08:34 +0500 Subject: [PATCH 5/7] docs: enhance docstrings for Negative Binomial distribution and its parameterizations --- numpyro/distributions/conjugate.py | 45 ++++++++++++++++++++++++++++-- 1 file changed, 42 insertions(+), 3 deletions(-) diff --git a/numpyro/distributions/conjugate.py b/numpyro/distributions/conjugate.py index 008f9035a..fca86e000 100644 --- a/numpyro/distributions/conjugate.py +++ b/numpyro/distributions/conjugate.py @@ -444,7 +444,18 @@ def NegativeBinomial( logits: Optional[ArrayLike] = None, *, validate_args: Optional[bool] = None, -): +) -> GammaPoisson: + """Factory function for Negative Binomial distribution. + + :param int total_count: Number of successful trials. + :param Optional[ArrayLike] probs: Probability of success for each trial, by default None + :param Optional[ArrayLike] logits: Log-odds of success for each trial, by default None + :param Optional[bool] validate_args: Whether to validate the parameters, by default None + :return: An instance of :class:`NegativeBinomialProbs` or + :class:`NegativeBinomialLogits` depending on the provided parameters. + :rtype: GammaPoisson + :raises ValueError: If neither :code:`probs` nor :code:`logits` is specified. + """ if probs is not None: return NegativeBinomialProbs(total_count, probs, validate_args=validate_args) elif logits is not None: @@ -454,6 +465,14 @@ def NegativeBinomial( class NegativeBinomialProbs(GammaPoisson): + r"""Negative Binomial distribution parameterized by :code:`total_count` (:math:`r`) + and :code:`probs` (:math:`p`). It is implemented as a + :math:`\displaystyle\mathrm{GammaPoisson}(n, \frac{1}{p} - 1)` distribution. + + :param total_count: Number of successful trials (:math:`r`). + :param probs: Probability of success for each trial (:math:`p`). + """ + arg_constraints = { "total_count": constraints.positive, "probs": constraints.unit_interval, @@ -474,6 +493,15 @@ def __init__( class NegativeBinomialLogits(GammaPoisson): + r"""Negative Binomial distribution parameterized by :code:`total_count` (:math:`r`) + and :code:`logits` (:math:`\displaystyle\mathrm{logits}(p)=\log \frac{p}{1-p}`). It + is implemented as a :math:`\mathrm{GammaPoisson}(n, \exp(-\mathrm{logits}(p)))` + distribution. + + :param total_count: Number of successful trials. + :param logits: Log-odds of success for each trial (:math:`\ln \frac{p}{1-p}`). + """ + arg_constraints = { "total_count": constraints.positive, "logits": constraints.real, @@ -494,6 +522,14 @@ def __init__( @validate_sample def log_prob(self, value: ArrayLike) -> ArrayLike: + r"""If :math:`X \sim \mathrm{NegativeBinomial}(r, \mathrm{logits}(p))`, then the log + probability mass function is: + + .. math:: + \ln P(X = k) = -r \ln(1+\exp(\mathrm{logits}(p))) + - k \ln(1+\exp(-\mathrm{logits}(p))) + - \ln\Gamma(1 + k) - \ln\Gamma(\alpha) + \ln\Gamma(k + \alpha) + """ return -( self.total_count * nn.softplus(self.logits) + value * nn.softplus(-self.logits) @@ -502,8 +538,11 @@ def log_prob(self, value: ArrayLike) -> ArrayLike: class NegativeBinomial2(GammaPoisson): - """ - Another parameterization of GammaPoisson with `rate` is replaced by `mean`. + r"""If :math:`X \sim \mathrm{NegativeBinomial2}(\mu, \alpha)`, then + :math:`X \sim \mathrm{GammaPoisson}(\alpha, \frac{\alpha}{\mu})`. + + :param numpy.ndarray mean: mean parameter (:math:`\mu`). + :param numpy.ndarray concentration: concentration parameter (:math:`\alpha`). """ arg_constraints = { From 3388363f493ffc0cfa556ca3b1983457b51214df Mon Sep 17 00:00:00 2001 From: Meesum Qazalbash Date: Mon, 22 Dec 2025 18:36:55 +0500 Subject: [PATCH 6/7] fix: apply `functools.wraps` inside `wrapper` in `numpyro.distributions.utils.validate_sample` --- numpyro/distributions/util.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/numpyro/distributions/util.py b/numpyro/distributions/util.py index c902f3fef..835d789d1 100644 --- a/numpyro/distributions/util.py +++ b/numpyro/distributions/util.py @@ -1,7 +1,9 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 + from collections import namedtuple +import functools as ft from functools import partial, update_wrapper import math import warnings @@ -772,6 +774,7 @@ def __get__(self, instance, obj_type=None): def validate_sample(log_prob_fn): + @ft.wraps(log_prob_fn) def wrapper(self, *args, **kwargs): log_prob = log_prob_fn(self, *args, **kwargs) if self._validate_args: @@ -780,8 +783,6 @@ def wrapper(self, *args, **kwargs): log_prob = jnp.where(mask, log_prob, -jnp.inf) return log_prob - wrapper.__doc__ = log_prob_fn.__doc__ - return wrapper From 847d5b5ef67cb241a2ac77737c1e8fc720db4c20 Mon Sep 17 00:00:00 2001 From: Meesum Qazalbash Date: Mon, 22 Dec 2025 18:38:50 +0500 Subject: [PATCH 7/7] fix: correction of docstring in `numpyro.distributions.truncated.DoublyTruncatedPowerLaw.log_prob` --- numpyro/distributions/truncated.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/numpyro/distributions/truncated.py b/numpyro/distributions/truncated.py index 30ee4a49b..b0e133e39 100644 --- a/numpyro/distributions/truncated.py +++ b/numpyro/distributions/truncated.py @@ -535,13 +535,19 @@ def support(self) -> ConstraintT: @validate_sample def log_prob(self, value: ArrayLike) -> ArrayLike: r"""Logarithmic probability distribution: + Z inequal minus one: + .. math:: - (x^\alpha) (\alpha + 1)/(b^(\alpha + 1) - a^(\alpha + 1)) + + \frac{(\alpha + 1)x^\alpha}{b^{\alpha + 1} - a^{\alpha + 1}} Z equal minus one: + .. math:: - (x^\alpha)/(log(b) - log(a)) + + \frac{x^\alpha}{\log(b) - \log(a)} + Derivations are calculated by Wolfram Alpha via the Jacobian matrix accordingly. """