Skip to content

Commit 212bccd

Browse files
authored
Add Inverse Wishart Distribution (#2103)
* init * docs and simplify * rtol * feedback 1 * add comment * improve docstrings * feedback * reuse existing code * feedback flip operator * simplify flip operator
1 parent b0c1dab commit 212bccd

File tree

4 files changed

+447
-3
lines changed

4 files changed

+447
-3
lines changed

docs/source/distributions.rst

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,22 @@ InverseGamma
256256
:show-inheritance:
257257
:member-order: bysource
258258

259+
InverseWishart
260+
^^^^^^^^^^^^^^
261+
.. autoclass:: numpyro.distributions.continuous.InverseWishart
262+
:members:
263+
:undoc-members:
264+
:show-inheritance:
265+
:member-order: bysource
266+
267+
InverseWishartCholesky
268+
^^^^^^^^^^^^^^^^^^^^^^
269+
.. autoclass:: numpyro.distributions.continuous.InverseWishartCholesky
270+
:members:
271+
:undoc-members:
272+
:show-inheritance:
273+
:member-order: bysource
274+
259275
Kumaraswamy
260276
^^^^^^^^^^^
261277
.. autoclass:: numpyro.distributions.continuous.Kumaraswamy

numpyro/distributions/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@
3838
HalfCauchy,
3939
HalfNormal,
4040
InverseGamma,
41+
InverseWishart,
42+
InverseWishartCholesky,
4143
Kumaraswamy,
4244
Laplace,
4345
Levy,
@@ -169,6 +171,8 @@
169171
"ImproperUniform",
170172
"Independent",
171173
"InverseGamma",
174+
"InverseWishart",
175+
"InverseWishartCholesky",
172176
"Kumaraswamy",
173177
"Laplace",
174178
"LeftTruncatedDistribution",

numpyro/distributions/continuous.py

Lines changed: 298 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3265,6 +3265,304 @@ def infer_shapes(
32653265
return batch_shape, event_shape
32663266

32673267

3268+
class InverseWishart(TransformedDistribution):
3269+
r"""
3270+
Inverse Wishart distribution for covariance matrices.
3271+
3272+
The Inverse Wishart distribution is the conjugate prior for the covariance matrix
3273+
of a multivariate normal distribution. If :math:`\mathbf{X} \sim W^{-1}(\mathbf{\Psi}, \nu)`,
3274+
then :math:`\mathbf{X}^{-1} \sim W(\mathbf{\Psi}^{-1}, \nu)` (Wishart distribution).
3275+
3276+
.. math::
3277+
3278+
p(\mathbf{X} \mid \mathbf{\Psi}, \nu) =
3279+
\frac{|\mathbf{\Psi}|^{\nu/2}}{2^{\nu p/2} \Gamma_p(\nu/2)}
3280+
|\mathbf{X}|^{-(\nu + p + 1)/2}
3281+
\exp\left( -\frac{1}{2} \mathrm{tr}(\mathbf{\Psi} \mathbf{X}^{-1}) \right)
3282+
3283+
where :math:`p` is the dimension of the matrix, :math:`\nu > p - 1` is the degrees
3284+
of freedom, and :math:`\mathbf{\Psi}` is the positive definite scale matrix.
3285+
3286+
:param concentration: Degrees of freedom parameter (often denoted :math:`\nu`).
3287+
Must be greater than `p - 1` where `p` is the dimension of the scale matrix.
3288+
:param scale_matrix: Positive definite scale matrix :math:`\mathbf{\Psi}`, analogous
3289+
to the inverse rate of a :class:`Gamma` distribution.
3290+
:param rate_matrix: Inverse of the scale matrix, analogous to the rate of a
3291+
:class:`Gamma` distribution.
3292+
:param scale_tril: Cholesky decomposition of the scale matrix.
3293+
3294+
**Properties**
3295+
3296+
- **Mean**: :math:`\frac{\mathbf{\Psi}}{\nu - p - 1}` for :math:`\nu > p + 1`
3297+
- **Mode**: :math:`\frac{\mathbf{\Psi}}{\nu + p + 1}`
3298+
3299+
**References**
3300+
3301+
[1] https://en.wikipedia.org/wiki/Inverse-Wishart_distribution
3302+
"""
3303+
3304+
arg_constraints = {
3305+
"concentration": constraints.dependent(is_discrete=False, event_dim=0),
3306+
"scale_matrix": constraints.positive_definite,
3307+
"rate_matrix": constraints.positive_definite,
3308+
"scale_tril": constraints.lower_cholesky,
3309+
}
3310+
support = constraints.positive_definite
3311+
reparametrized_params = [
3312+
"scale_matrix",
3313+
"rate_matrix",
3314+
"scale_tril",
3315+
]
3316+
3317+
def __init__(
3318+
self,
3319+
concentration: ArrayLike,
3320+
scale_matrix: Optional[Array] = None,
3321+
rate_matrix: Optional[Array] = None,
3322+
scale_tril: Optional[Array] = None,
3323+
*,
3324+
validate_args: Optional[bool] = None,
3325+
) -> None:
3326+
base_dist = InverseWishartCholesky(
3327+
concentration,
3328+
scale_matrix,
3329+
rate_matrix,
3330+
scale_tril,
3331+
validate_args=validate_args,
3332+
)
3333+
super().__init__(
3334+
base_dist, CholeskyTransform().inv, validate_args=validate_args
3335+
)
3336+
3337+
@lazy_property
3338+
def concentration(self):
3339+
return self.base_dist.concentration
3340+
3341+
@lazy_property
3342+
def scale_matrix(self):
3343+
return self.base_dist.scale_matrix
3344+
3345+
@lazy_property
3346+
def rate_matrix(self):
3347+
return self.base_dist.rate_matrix
3348+
3349+
@lazy_property
3350+
def scale_tril(self):
3351+
return self.base_dist.scale_tril
3352+
3353+
@lazy_property
3354+
def mean(self) -> ArrayLike:
3355+
# Mean exists only when concentration > p + 1
3356+
p = self.scale_matrix.shape[-1]
3357+
return jnp.where(
3358+
self.concentration[..., None, None] > p + 1,
3359+
self.scale_matrix / (self.concentration[..., None, None] - p - 1),
3360+
jnp.full_like(self.scale_matrix, jnp.nan),
3361+
)
3362+
3363+
@lazy_property
3364+
def mode(self) -> ArrayLike:
3365+
p = self.scale_matrix.shape[-1]
3366+
return self.scale_matrix / (self.concentration[..., None, None] + p + 1)
3367+
3368+
@lazy_property
3369+
def variance(self) -> ArrayLike:
3370+
# Variance of entry (i,j) for nu > p + 3
3371+
# Var(X_ij) = (Psi_ij^2 + Psi_ii * Psi_jj) / ((nu - p - 1)^2 * (nu - p - 3))
3372+
p = self.scale_matrix.shape[-1]
3373+
nu = jnp.expand_dims(self.concentration, axis=(-1, -2))
3374+
psi = self.scale_matrix
3375+
denom = (nu - p - 1) ** 2 * (nu - p - 3)
3376+
psi_ii = jnp.diagonal(psi, axis1=-2, axis2=-1)[..., :, None]
3377+
psi_jj = jnp.diagonal(psi, axis1=-2, axis2=-1)[..., None, :]
3378+
var = (psi**2 + psi_ii * psi_jj) / denom
3379+
return jnp.where(nu > p + 3, var, jnp.full_like(var, jnp.nan))
3380+
3381+
@staticmethod
3382+
def infer_shapes(
3383+
concentration=(), scale_matrix=None, rate_matrix=None, scale_tril=None
3384+
):
3385+
return InverseWishartCholesky.infer_shapes(
3386+
concentration, scale_matrix, rate_matrix, scale_tril
3387+
)
3388+
3389+
3390+
class InverseWishartCholesky(Distribution):
3391+
r"""
3392+
Cholesky factor of an Inverse Wishart distribution for covariance matrices.
3393+
3394+
This distribution samples the Cholesky factor :math:`\mathbf{L}` such that
3395+
:math:`\mathbf{X} = \mathbf{L} \mathbf{L}^T \sim W^{-1}(\mathbf{\Psi}, \nu)`.
3396+
3397+
:param concentration: Degrees of freedom parameter (often denoted :math:`\nu`).
3398+
Must be greater than `p - 1` where `p` is the dimension of the scale matrix.
3399+
:param scale_matrix: Positive definite scale matrix :math:`\mathbf{\Psi}`, analogous
3400+
to the inverse rate of a :class:`Gamma` distribution.
3401+
:param rate_matrix: Inverse of the scale matrix, analogous to the rate of a
3402+
:class:`Gamma` distribution.
3403+
:param scale_tril: Cholesky decomposition of the scale matrix.
3404+
3405+
**References**
3406+
3407+
[1] https://en.wikipedia.org/wiki/Inverse-Wishart_distribution
3408+
"""
3409+
3410+
arg_constraints = {
3411+
"concentration": constraints.dependent(is_discrete=False, event_dim=0),
3412+
"scale_matrix": constraints.positive_definite,
3413+
"rate_matrix": constraints.positive_definite,
3414+
"scale_tril": constraints.lower_cholesky,
3415+
}
3416+
support = constraints.lower_cholesky
3417+
reparametrized_params = [
3418+
"scale_matrix",
3419+
"rate_matrix",
3420+
"scale_tril",
3421+
]
3422+
3423+
def __init__(
3424+
self,
3425+
concentration: ArrayLike,
3426+
scale_matrix: Optional[Array] = None,
3427+
rate_matrix: Optional[Array] = None,
3428+
scale_tril: Optional[Array] = None,
3429+
*,
3430+
validate_args: Optional[bool] = None,
3431+
) -> None:
3432+
assert_one_of(
3433+
scale_matrix=scale_matrix,
3434+
rate_matrix=rate_matrix,
3435+
scale_tril=scale_tril,
3436+
)
3437+
concentration = jnp.asarray(concentration)[..., None, None]
3438+
if scale_matrix is not None:
3439+
concentration, self.scale_matrix = promote_shapes(
3440+
concentration, scale_matrix
3441+
)
3442+
self.scale_tril = jnp.linalg.cholesky(self.scale_matrix)
3443+
elif rate_matrix is not None:
3444+
concentration, self.rate_matrix = promote_shapes(concentration, rate_matrix)
3445+
self.scale_tril = cholesky_of_inverse(self.rate_matrix)
3446+
elif scale_tril is not None:
3447+
concentration, self.scale_tril = promote_shapes(
3448+
concentration, jnp.asarray(scale_tril)
3449+
)
3450+
batch_shape = lax.broadcast_shapes(
3451+
jnp.shape(concentration)[:-2], jnp.shape(self.scale_tril)[:-2]
3452+
)
3453+
event_shape = jnp.shape(self.scale_tril)[-2:]
3454+
self.concentration = concentration[..., 0, 0]
3455+
super().__init__(
3456+
batch_shape=batch_shape,
3457+
event_shape=event_shape,
3458+
validate_args=validate_args,
3459+
)
3460+
3461+
@validate_sample
3462+
def log_prob(self, value: ArrayLike) -> ArrayLike:
3463+
# L = value (Cholesky factor), X = L @ L^T ~ InverseWishart(Psi, nu)
3464+
# log p(X) = (nu/2) log|Psi| - (nu*p/2) log(2) - log Gamma_p(nu/2)
3465+
# - ((nu+p+1)/2) log|X| - tr(Psi @ X^{-1}) / 2
3466+
# Trace trick: tr(Psi @ X^{-1}) = ||L^{-1} @ scale_tril||_F^2
3467+
x = solve_triangular(*jnp.broadcast_arrays(value, self.scale_tril), lower=True)
3468+
trace = jnp.square(x).sum(axis=(-1, -2))
3469+
3470+
p = value.shape[-1]
3471+
log_diag = jnp.log(jnp.diagonal(value, axis1=-2, axis2=-1))
3472+
return (
3473+
self.concentration * tri_logabsdet(self.scale_tril) # (nu/2) log|Psi|
3474+
+ p * (1 - self.concentration / 2) * jnp.log(2) # normalization
3475+
- multigammaln(self.concentration / 2, p)
3476+
+ jnp.sum(
3477+
(-self.concentration[..., None] - 1 - jnp.arange(p)) * log_diag,
3478+
axis=-1,
3479+
)
3480+
- trace / 2
3481+
)
3482+
3483+
@lazy_property
3484+
def scale_matrix(self):
3485+
return jnp.matmul(self.scale_tril, self.scale_tril.mT)
3486+
3487+
@lazy_property
3488+
def rate_matrix(self):
3489+
identity = jnp.broadcast_to(
3490+
jnp.eye(self.scale_tril.shape[-1]), self.scale_tril.shape
3491+
)
3492+
return cho_solve((self.scale_tril, True), identity)
3493+
3494+
def sample(
3495+
self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = ()
3496+
) -> ArrayLike:
3497+
assert is_prng_key(key)
3498+
# Sample from standard InverseWishartCholesky using Bartlett decomposition
3499+
# Ref: https://nbviewer.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril
3500+
rng_diag, rng_offdiag = random.split(key)
3501+
latent = jnp.zeros(sample_shape + self.batch_shape + self.event_shape)
3502+
p = self.event_shape[-1]
3503+
# Inverse Wishart Bartlett: nu - p + 1, nu - p + 2, ..., nu - 1, nu
3504+
i = jnp.arange(p)
3505+
latent = latent.at[..., i, i].set(
3506+
jnp.sqrt(
3507+
random.chisquare(
3508+
rng_diag,
3509+
self.concentration[..., None] + i - p + 1,
3510+
latent.shape[:-1],
3511+
)
3512+
)
3513+
)
3514+
i, j = jnp.tril_indices(p, -1)
3515+
latent = latent.at[..., i, j].set(
3516+
random.normal(rng_offdiag, latent.shape[:-2] + (i.size,))
3517+
)
3518+
# Get Cholesky of InverseWishart(I) by inverting latent
3519+
identity = jnp.broadcast_to(jnp.eye(p), latent.shape)
3520+
L_inv_std = solve_triangular(latent, identity, lower=True)
3521+
3522+
# Transform to InverseWishart(Psi): L = scale_tril @ L_inv_std
3523+
return jnp.matmul(self.scale_tril, L_inv_std)
3524+
3525+
@lazy_property
3526+
def mean(self) -> ArrayLike:
3527+
# Approximate: chol(E[X]) where E[X] = Psi / (nu - p - 1) for nu > p + 1
3528+
p = self.scale_tril.shape[-1]
3529+
mean_x = jnp.where(
3530+
self.concentration[..., None, None] > p + 1,
3531+
self.scale_matrix / (self.concentration[..., None, None] - p - 1),
3532+
jnp.full_like(self.scale_matrix, jnp.nan),
3533+
)
3534+
return jnp.linalg.cholesky(
3535+
jnp.where(jnp.isnan(mean_x), jnp.eye(p), mean_x)
3536+
) * jnp.where(
3537+
self.concentration[..., None, None] > p + 1,
3538+
jnp.ones_like(mean_x),
3539+
jnp.full_like(mean_x, jnp.nan),
3540+
)
3541+
3542+
@lazy_property
3543+
def variance(self) -> ArrayLike:
3544+
# Variance of Cholesky factor is complex; return NaN for now
3545+
return jnp.full(self.batch_shape + self.event_shape, jnp.nan)
3546+
3547+
@staticmethod
3548+
def infer_shapes(
3549+
concentration: tuple[int, ...] = (),
3550+
scale_matrix: Optional[tuple[int, ...]] = None,
3551+
rate_matrix: Optional[tuple[int, ...]] = None,
3552+
scale_tril: Optional[tuple[int, ...]] = None,
3553+
):
3554+
assert_one_of(
3555+
scale_matrix=scale_matrix,
3556+
rate_matrix=rate_matrix,
3557+
scale_tril=scale_tril,
3558+
)
3559+
for matrix in [scale_matrix, rate_matrix, scale_tril]:
3560+
if matrix is not None:
3561+
batch_shape = lax.broadcast_shapes(concentration, matrix[:-2])
3562+
event_shape = matrix[-2:]
3563+
return batch_shape, event_shape
3564+
3565+
32683566
class Levy(Distribution):
32693567
r"""Lévy distribution is a special case of Lévy alpha-stable distribution.
32703568
Its probability density function is given by,

0 commit comments

Comments
 (0)