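Consider using the multivariate normal from TFP on JAX (TensorFlow Probability's JAX substrate). MultivariateNormalTriL accepts a precomputed Cholesky factor via scale_tril, so the factorization is done only once:
import jax
import jax.numpy as jnp
import tensorflow_probability.substrates.jax as tfp
tfd = tfp.distributions

mean = jnp.zeros(5)
# Factor the covariance once, up front
cov_chol = jnp.linalg.cholesky(jnp.eye(5))
dist = tfd.MultivariateNormalTriL(loc=mean, scale_tril=cov_chol)
x = dist.sample(seed=jax.random.PRNGKey(0))  # any PRNG key works here
dist.log_prob(x)  # reuses cov_chol; call with as many different x as needed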
I have code where I frequently call `jax.scipy.stats.multivariate_normal.logpdf(x, mean, cov)` with different values of `x` but with constant `mean` and `cov`. Obviously, it would be much more efficient if I could compute the Cholesky factorization of `cov` once up front rather than have `logpdf` recalculate it under the hood on every call. Maybe I'm missing it, but there doesn't seem to be a way to cache the factored matrix. Is this outside the scope of what JAX is trying to provide?
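If you'd rather not pull in TFP, the caching can also be sketched in plain JAX. This is a minimal sketch, not a JAX API: the helper name `make_mvn_logpdf` is hypothetical, and it assumes `cov` is symmetric positive definite. It factors `cov` once with `jnp.linalg.cholesky` (lower-triangular `L` with `cov = L @ L.T`), precomputes `log|cov| = 2 * sum(log(diag(L)))`, and returns a jitted closure that only performs a triangular solve per call to get the Mahalanobis term `||L^{-1}(x - mean)||^2`.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular
from jax.scipy.stats import multivariate_normal


def make_mvn_logpdf(mean, cov):
    """Hypothetical helper: factor cov once, return a jitted logpdf(x)."""
    d = mean.shape[-1]
    chol = jnp.linalg.cholesky(cov)  # lower-triangular L, cov = L @ L.T
    # log|cov| = 2 * sum(log(diag(L))), computed once
    log_det = 2.0 * jnp.sum(jnp.log(jnp.diag(chol)))
    const = -0.5 * (d * jnp.log(2.0 * jnp.pi) + log_det)

    @jax.jit
    def logpdf(x):
        # x has shape (d,) or (n, d); transpose so each column is one point
        diff = (x - mean).T
        # z = L^{-1} (x - mean); sum(z**2) is the Mahalanobis distance
        z = solve_triangular(chol, diff, lower=True)
        return const - 0.5 * jnp.sum(z**2, axis=0)

    return logpdf


# Usage: build once, evaluate for many different x
mean = jnp.arange(3.0)
a = jax.random.normal(jax.random.PRNGKey(0), (3, 3))
cov = a @ a.T + 3.0 * jnp.eye(3)  # well-conditioned SPD matrix
logpdf = make_mvn_logpdf(mean, cov)

xs = jax.random.normal(jax.random.PRNGKey(1), (4, 3))
print(logpdf(xs))                                # batch of 4 points
print(multivariate_normal.logpdf(xs, mean, cov))  # reference values
```

Because `chol` and the normalizing constant are captured by the closure, `jit` treats them as constants, so repeated calls do not refactor `cov`. Comparing against `jax.scipy.stats.multivariate_normal.logpdf` on the same inputs is a cheap sanity check.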