In general, JIT compilation may reorder or fuse the operations in your function for efficiency, and this can slightly change its floating-point results. For details, see the JAX FAQ entry "jit changes the exact numerics of outputs". In this case, the likely culprit is that you're computing jnp.log of an exponentiated quantity (multivariate_normal.pdf). Exponentiating a large negative log-density loses precision or underflows to zero, so small rounding differences introduced by compilation get amplified when you take the log again.
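
To see why log(pdf) is fragile, here is a minimal sketch comparing jnp.log(multivariate_normal.pdf(...)) with multivariate_normal.logpdf(...). The mean, covariance, and test points are made-up illustration values, not taken from the original question; the second point is chosen far enough into the tail that exp() underflows even in float64.

```python
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

# Illustrative values only: a standard 2-D Gaussian.
mu = jnp.zeros(2)
sigma = jnp.eye(2)
# One point at the mean, one far in the tail.
x = jnp.array([[0.0, 0.0], [40.0, 40.0]])

# Round trip through the density: the log-density at [40, 40] is about -1601.8,
# so exp() underflows to 0.0 and log(0.0) gives -inf.
log_of_pdf = jnp.log(multivariate_normal.pdf(x, mean=mu, cov=sigma))

# Stays in log space the whole time, so the tail value remains finite.
direct_logpdf = multivariate_normal.logpdf(x, mean=mu, cov=sigma)

print(log_of_pdf)
print(direct_logpdf)
```

Near the mean both approaches agree; in the tail, log(pdf) collapses to -inf while logpdf returns roughly -1601.84 (that is, -1600 - log(2π)).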

You can get a better-behaved version of your function by never taking the log of an exponential in the first place, i.e. by calling logpdf instead of log(pdf):

import jax.numpy as jnp
import jax.scipy.stats as jstats

def jitless_log_likelihood(x, mu, sigma):
    return jnp.sum(jstats.multivariate_normal.logpdf(x, mean=mu, cov=sigma))  # stays in log space
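
As a quick sanity check, the sketch below evaluates this function both eagerly and under jax.jit. The data, mean, and covariance are hypothetical illustration values, not from the original question. With logpdf, the two results should agree to within ordinary floating-point tolerance, even if they are not bit-for-bit identical.

```python
import jax
import jax.numpy as jnp
import jax.scipy.stats as jstats

def jitless_log_likelihood(x, mu, sigma):
    # Sum of per-sample log-densities, computed directly in log space.
    return jnp.sum(jstats.multivariate_normal.logpdf(x, mean=mu, cov=sigma))

# Hypothetical test data: 100 draws from a 3-D Gaussian.
key = jax.random.PRNGKey(0)
mu = jnp.array([1.0, -2.0, 0.5])
a = jnp.array([[1.0, 0.2, 0.0], [0.0, 1.0, 0.3], [0.0, 0.0, 1.0]])
sigma = a @ a.T  # symmetric positive-definite covariance
x = jax.random.multivariate_normal(key, mu, sigma, shape=(100,))

eager = jitless_log_likelihood(x, mu, sigma)
jitted = jax.jit(jitless_log_likelihood)(x, mu, sigma)
print(eager, jitted)
```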

Answer selected by AlejandroBaron
Category
Q&A