@@ -545,6 +545,27 @@ class LKJ(TransformedDistribution):
545545 When ``concentration < 1``, the distribution favors samples with small determinant. This is
546546 useful when we know a priori that some underlying variables are correlated.
547547
548+ Sample code for using LKJ in the context of a multivariate normal sample::
549+
550+ def model(y): # y has dimension N x d
551+ d = y.shape[1]
552+ N = y.shape[0]
553+ # Vector of variances for each of the d variables
554+ theta = numpyro.sample("theta", dist.HalfCauchy(jnp.ones(d)))
555+
556+ concentration = jnp.ones(1) # Implies a uniform distribution over correlation matrices
557+ corr_mat = numpyro.sample("corr_mat", dist.LKJ(d, concentration))
558+ sigma = jnp.sqrt(theta)
559+ # we can also use a faster formula `cov_mat = jnp.outer(sigma, sigma) * corr_mat`
560+ cov_mat = jnp.matmul(jnp.matmul(jnp.diag(sigma), corr_mat), jnp.diag(sigma))
561+
562+ # Vector of expectations
563+ mu = jnp.zeros(d)
564+
565+ with numpyro.plate("observations", N):
566+ obs = numpyro.sample("obs", dist.MultivariateNormal(mu, covariance_matrix=cov_mat), obs=y)
567+ return obs
568+
548569 :param int dimension: dimension of the matrices
549570 :param ndarray concentration: concentration/shape parameter of the
550571 distribution (often referred to as eta)
@@ -606,6 +627,28 @@ class LKJCholesky(Distribution):
606627 (hence small determinant). This is useful when we know a priori that some underlying
607628 variables are correlated.
608629
630+ Sample code for using LKJCholesky in the context of a multivariate normal sample::
631+
632+ def model(y): # y has dimension N x d
633+ d = y.shape[1]
634+ N = y.shape[0]
635+ # Vector of variances for each of the d variables
636+ theta = numpyro.sample("theta", dist.HalfCauchy(jnp.ones(d)))
637+ # Lower cholesky factor of a correlation matrix
638+ concentration = jnp.ones(1) # Implies a uniform distribution over correlation matrices
639+ L_omega = numpyro.sample("L_omega", dist.LKJCholesky(d, concentration))
640+ # Lower cholesky factor of the covariance matrix
641+ sigma = jnp.sqrt(theta)
642+ # we can also use a faster formula `L_Omega = sigma[..., None] * L_omega`
643+ L_Omega = jnp.matmul(jnp.diag(sigma), L_omega)
644+
645+ # Vector of expectations
646+ mu = jnp.zeros(d)
647+
648+ with numpyro.plate("observations", N):
649+ obs = numpyro.sample("obs", dist.MultivariateNormal(mu, scale_tril=L_Omega), obs=y)
650+ return obs
651+
609652 :param int dimension: dimension of the matrices
610653 :param ndarray concentration: concentration/shape parameter of the
611654 distribution (often referred to as eta)
0 commit comments