@@ -3265,6 +3265,304 @@ def infer_shapes(
32653265 return batch_shape , event_shape
32663266
32673267
3268+ class InverseWishart (TransformedDistribution ):
3269+ r"""
3270+ Inverse Wishart distribution for covariance matrices.
3271+
3272+ The Inverse Wishart distribution is the conjugate prior for the covariance matrix
3273+ of a multivariate normal distribution. If :math:`\mathbf{X} \sim W^{-1}(\mathbf{\Psi}, \nu)`,
3274+ then :math:`\mathbf{X}^{-1} \sim W(\mathbf{\Psi}^{-1}, \nu)` (Wishart distribution).
3275+
3276+ .. math::
3277+
3278+ p(\mathbf{X} \mid \mathbf{\Psi}, \nu) =
3279+ \frac{|\mathbf{\Psi}|^{\nu/2}}{2^{\nu p/2} \Gamma_p(\nu/2)}
3280+ |\mathbf{X}|^{-(\nu + p + 1)/2}
3281+ \exp\left( -\frac{1}{2} \mathrm{tr}(\mathbf{\Psi} \mathbf{X}^{-1}) \right)
3282+
3283+ where :math:`p` is the dimension of the matrix, :math:`\nu > p - 1` is the degrees
3284+ of freedom, and :math:`\mathbf{\Psi}` is the positive definite scale matrix.
3285+
3286+ :param concentration: Degrees of freedom parameter (often denoted :math:`\nu`).
3287+ Must be greater than `p - 1` where `p` is the dimension of the scale matrix.
3288+ :param scale_matrix: Positive definite scale matrix :math:`\mathbf{\Psi}`, analogous
3289+ to the inverse rate of a :class:`Gamma` distribution.
3290+ :param rate_matrix: Inverse of the scale matrix, analogous to the rate of a
3291+ :class:`Gamma` distribution.
3292+ :param scale_tril: Cholesky decomposition of the scale matrix.
3293+
3294+ **Properties**
3295+
3296+ - **Mean**: :math:`\frac{\mathbf{\Psi}}{\nu - p - 1}` for :math:`\nu > p + 1`
3297+ - **Mode**: :math:`\frac{\mathbf{\Psi}}{\nu + p + 1}`
3298+
3299+ **References**
3300+
3301+ [1] https://en.wikipedia.org/wiki/Inverse-Wishart_distribution
3302+ """
3303+
3304+ arg_constraints = {
3305+ "concentration" : constraints .dependent (is_discrete = False , event_dim = 0 ),
3306+ "scale_matrix" : constraints .positive_definite ,
3307+ "rate_matrix" : constraints .positive_definite ,
3308+ "scale_tril" : constraints .lower_cholesky ,
3309+ }
3310+ support = constraints .positive_definite
3311+ reparametrized_params = [
3312+ "scale_matrix" ,
3313+ "rate_matrix" ,
3314+ "scale_tril" ,
3315+ ]
3316+
3317+ def __init__ (
3318+ self ,
3319+ concentration : ArrayLike ,
3320+ scale_matrix : Optional [Array ] = None ,
3321+ rate_matrix : Optional [Array ] = None ,
3322+ scale_tril : Optional [Array ] = None ,
3323+ * ,
3324+ validate_args : Optional [bool ] = None ,
3325+ ) -> None :
3326+ base_dist = InverseWishartCholesky (
3327+ concentration ,
3328+ scale_matrix ,
3329+ rate_matrix ,
3330+ scale_tril ,
3331+ validate_args = validate_args ,
3332+ )
3333+ super ().__init__ (
3334+ base_dist , CholeskyTransform ().inv , validate_args = validate_args
3335+ )
3336+
3337+ @lazy_property
3338+ def concentration (self ):
3339+ return self .base_dist .concentration
3340+
3341+ @lazy_property
3342+ def scale_matrix (self ):
3343+ return self .base_dist .scale_matrix
3344+
3345+ @lazy_property
3346+ def rate_matrix (self ):
3347+ return self .base_dist .rate_matrix
3348+
3349+ @lazy_property
3350+ def scale_tril (self ):
3351+ return self .base_dist .scale_tril
3352+
3353+ @lazy_property
3354+ def mean (self ) -> ArrayLike :
3355+ # Mean exists only when concentration > p + 1
3356+ p = self .scale_matrix .shape [- 1 ]
3357+ return jnp .where (
3358+ self .concentration [..., None , None ] > p + 1 ,
3359+ self .scale_matrix / (self .concentration [..., None , None ] - p - 1 ),
3360+ jnp .full_like (self .scale_matrix , jnp .nan ),
3361+ )
3362+
3363+ @lazy_property
3364+ def mode (self ) -> ArrayLike :
3365+ p = self .scale_matrix .shape [- 1 ]
3366+ return self .scale_matrix / (self .concentration [..., None , None ] + p + 1 )
3367+
3368+ @lazy_property
3369+ def variance (self ) -> ArrayLike :
3370+ # Variance of entry (i,j) for nu > p + 3
3371+ # Var(X_ij) = (Psi_ij^2 + Psi_ii * Psi_jj) / ((nu - p - 1)^2 * (nu - p - 3))
3372+ p = self .scale_matrix .shape [- 1 ]
3373+ nu = jnp .expand_dims (self .concentration , axis = (- 1 , - 2 ))
3374+ psi = self .scale_matrix
3375+ denom = (nu - p - 1 ) ** 2 * (nu - p - 3 )
3376+ psi_ii = jnp .diagonal (psi , axis1 = - 2 , axis2 = - 1 )[..., :, None ]
3377+ psi_jj = jnp .diagonal (psi , axis1 = - 2 , axis2 = - 1 )[..., None , :]
3378+ var = (psi ** 2 + psi_ii * psi_jj ) / denom
3379+ return jnp .where (nu > p + 3 , var , jnp .full_like (var , jnp .nan ))
3380+
3381+ @staticmethod
3382+ def infer_shapes (
3383+ concentration = (), scale_matrix = None , rate_matrix = None , scale_tril = None
3384+ ):
3385+ return InverseWishartCholesky .infer_shapes (
3386+ concentration , scale_matrix , rate_matrix , scale_tril
3387+ )
3388+
3389+
3390+ class InverseWishartCholesky (Distribution ):
3391+ r"""
3392+ Cholesky factor of an Inverse Wishart distribution for covariance matrices.
3393+
3394+ This distribution samples the Cholesky factor :math:`\mathbf{L}` such that
3395+ :math:`\mathbf{X} = \mathbf{L} \mathbf{L}^T \sim W^{-1}(\mathbf{\Psi}, \nu)`.
3396+
3397+ :param concentration: Degrees of freedom parameter (often denoted :math:`\nu`).
3398+ Must be greater than `p - 1` where `p` is the dimension of the scale matrix.
3399+ :param scale_matrix: Positive definite scale matrix :math:`\mathbf{\Psi}`, analogous
3400+ to the inverse rate of a :class:`Gamma` distribution.
3401+ :param rate_matrix: Inverse of the scale matrix, analogous to the rate of a
3402+ :class:`Gamma` distribution.
3403+ :param scale_tril: Cholesky decomposition of the scale matrix.
3404+
3405+ **References**
3406+
3407+ [1] https://en.wikipedia.org/wiki/Inverse-Wishart_distribution
3408+ """
3409+
3410+ arg_constraints = {
3411+ "concentration" : constraints .dependent (is_discrete = False , event_dim = 0 ),
3412+ "scale_matrix" : constraints .positive_definite ,
3413+ "rate_matrix" : constraints .positive_definite ,
3414+ "scale_tril" : constraints .lower_cholesky ,
3415+ }
3416+ support = constraints .lower_cholesky
3417+ reparametrized_params = [
3418+ "scale_matrix" ,
3419+ "rate_matrix" ,
3420+ "scale_tril" ,
3421+ ]
3422+
3423+ def __init__ (
3424+ self ,
3425+ concentration : ArrayLike ,
3426+ scale_matrix : Optional [Array ] = None ,
3427+ rate_matrix : Optional [Array ] = None ,
3428+ scale_tril : Optional [Array ] = None ,
3429+ * ,
3430+ validate_args : Optional [bool ] = None ,
3431+ ) -> None :
3432+ assert_one_of (
3433+ scale_matrix = scale_matrix ,
3434+ rate_matrix = rate_matrix ,
3435+ scale_tril = scale_tril ,
3436+ )
3437+ concentration = jnp .asarray (concentration )[..., None , None ]
3438+ if scale_matrix is not None :
3439+ concentration , self .scale_matrix = promote_shapes (
3440+ concentration , scale_matrix
3441+ )
3442+ self .scale_tril = jnp .linalg .cholesky (self .scale_matrix )
3443+ elif rate_matrix is not None :
3444+ concentration , self .rate_matrix = promote_shapes (concentration , rate_matrix )
3445+ self .scale_tril = cholesky_of_inverse (self .rate_matrix )
3446+ elif scale_tril is not None :
3447+ concentration , self .scale_tril = promote_shapes (
3448+ concentration , jnp .asarray (scale_tril )
3449+ )
3450+ batch_shape = lax .broadcast_shapes (
3451+ jnp .shape (concentration )[:- 2 ], jnp .shape (self .scale_tril )[:- 2 ]
3452+ )
3453+ event_shape = jnp .shape (self .scale_tril )[- 2 :]
3454+ self .concentration = concentration [..., 0 , 0 ]
3455+ super ().__init__ (
3456+ batch_shape = batch_shape ,
3457+ event_shape = event_shape ,
3458+ validate_args = validate_args ,
3459+ )
3460+
3461+ @validate_sample
3462+ def log_prob (self , value : ArrayLike ) -> ArrayLike :
3463+ # L = value (Cholesky factor), X = L @ L^T ~ InverseWishart(Psi, nu)
3464+ # log p(X) = (nu/2) log|Psi| - (nu*p/2) log(2) - log Gamma_p(nu/2)
3465+ # - ((nu+p+1)/2) log|X| - tr(Psi @ X^{-1}) / 2
3466+ # Trace trick: tr(Psi @ X^{-1}) = ||L^{-1} @ scale_tril||_F^2
3467+ x = solve_triangular (* jnp .broadcast_arrays (value , self .scale_tril ), lower = True )
3468+ trace = jnp .square (x ).sum (axis = (- 1 , - 2 ))
3469+
3470+ p = value .shape [- 1 ]
3471+ log_diag = jnp .log (jnp .diagonal (value , axis1 = - 2 , axis2 = - 1 ))
3472+ return (
3473+ self .concentration * tri_logabsdet (self .scale_tril ) # (nu/2) log|Psi|
3474+ + p * (1 - self .concentration / 2 ) * jnp .log (2 ) # normalization
3475+ - multigammaln (self .concentration / 2 , p )
3476+ + jnp .sum (
3477+ (- self .concentration [..., None ] - 1 - jnp .arange (p )) * log_diag ,
3478+ axis = - 1 ,
3479+ )
3480+ - trace / 2
3481+ )
3482+
3483+ @lazy_property
3484+ def scale_matrix (self ):
3485+ return jnp .matmul (self .scale_tril , self .scale_tril .mT )
3486+
3487+ @lazy_property
3488+ def rate_matrix (self ):
3489+ identity = jnp .broadcast_to (
3490+ jnp .eye (self .scale_tril .shape [- 1 ]), self .scale_tril .shape
3491+ )
3492+ return cho_solve ((self .scale_tril , True ), identity )
3493+
3494+ def sample (
3495+ self , key : jax .dtypes .prng_key , sample_shape : tuple [int , ...] = ()
3496+ ) -> ArrayLike :
3497+ assert is_prng_key (key )
3498+ # Sample from standard InverseWishartCholesky using Bartlett decomposition
3499+ # Ref: https://nbviewer.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril
3500+ rng_diag , rng_offdiag = random .split (key )
3501+ latent = jnp .zeros (sample_shape + self .batch_shape + self .event_shape )
3502+ p = self .event_shape [- 1 ]
3503+ # Inverse Wishart Bartlett: nu - p + 1, nu - p + 2, ..., nu - 1, nu
3504+ i = jnp .arange (p )
3505+ latent = latent .at [..., i , i ].set (
3506+ jnp .sqrt (
3507+ random .chisquare (
3508+ rng_diag ,
3509+ self .concentration [..., None ] + i - p + 1 ,
3510+ latent .shape [:- 1 ],
3511+ )
3512+ )
3513+ )
3514+ i , j = jnp .tril_indices (p , - 1 )
3515+ latent = latent .at [..., i , j ].set (
3516+ random .normal (rng_offdiag , latent .shape [:- 2 ] + (i .size ,))
3517+ )
3518+ # Get Cholesky of InverseWishart(I) by inverting latent
3519+ identity = jnp .broadcast_to (jnp .eye (p ), latent .shape )
3520+ L_inv_std = solve_triangular (latent , identity , lower = True )
3521+
3522+ # Transform to InverseWishart(Psi): L = scale_tril @ L_inv_std
3523+ return jnp .matmul (self .scale_tril , L_inv_std )
3524+
3525+ @lazy_property
3526+ def mean (self ) -> ArrayLike :
3527+ # Approximate: chol(E[X]) where E[X] = Psi / (nu - p - 1) for nu > p + 1
3528+ p = self .scale_tril .shape [- 1 ]
3529+ mean_x = jnp .where (
3530+ self .concentration [..., None , None ] > p + 1 ,
3531+ self .scale_matrix / (self .concentration [..., None , None ] - p - 1 ),
3532+ jnp .full_like (self .scale_matrix , jnp .nan ),
3533+ )
3534+ return jnp .linalg .cholesky (
3535+ jnp .where (jnp .isnan (mean_x ), jnp .eye (p ), mean_x )
3536+ ) * jnp .where (
3537+ self .concentration [..., None , None ] > p + 1 ,
3538+ jnp .ones_like (mean_x ),
3539+ jnp .full_like (mean_x , jnp .nan ),
3540+ )
3541+
3542+ @lazy_property
3543+ def variance (self ) -> ArrayLike :
3544+ # Variance of Cholesky factor is complex; return NaN for now
3545+ return jnp .full (self .batch_shape + self .event_shape , jnp .nan )
3546+
3547+ @staticmethod
3548+ def infer_shapes (
3549+ concentration : tuple [int , ...] = (),
3550+ scale_matrix : Optional [tuple [int , ...]] = None ,
3551+ rate_matrix : Optional [tuple [int , ...]] = None ,
3552+ scale_tril : Optional [tuple [int , ...]] = None ,
3553+ ):
3554+ assert_one_of (
3555+ scale_matrix = scale_matrix ,
3556+ rate_matrix = rate_matrix ,
3557+ scale_tril = scale_tril ,
3558+ )
3559+ for matrix in [scale_matrix , rate_matrix , scale_tril ]:
3560+ if matrix is not None :
3561+ batch_shape = lax .broadcast_shapes (concentration , matrix [:- 2 ])
3562+ event_shape = matrix [- 2 :]
3563+ return batch_shape , event_shape
3564+
3565+
32683566class Levy (Distribution ):
32693567 r"""Lévy distribution is a special case of Lévy alpha-stable distribution.
32703568 Its probability density function is given by,
0 commit comments