Why does my function yield `-inf` without `@jit`, but not with it? #12132
-
On my CPU, the following code yields `-inf` when the function is not jitted, but a finite value when it is. Why is this happening?
-
In the jitted version, some operations got fused or eliminated. Consider the following code:

```python
import jax
import jax.numpy as jnp
import jax.scipy.stats as jstats

# _optimization_barrier is a private API; its location has moved between versions
try:
    from jax._src.ad_checkpoint import _optimization_barrier
except ImportError:
    from jax._src.lax.control_flow.remat_impl import _optimization_barrier

opt = jax.lib.xla_extension.HloPrintOptions.short_parsable()
opt.print_large_constants = False

def f(x, mu, sigma):
    val = jstats.multivariate_normal.pdf(x, mean=mu, cov=sigma)
    val = _optimization_barrier(val)  # (*)
    return jnp.log(val)

key = jax.random.PRNGKey(0)
M = 500
x = jax.random.normal(key, (10, M))
mu = jnp.zeros((M,))
sigma = jnp.identity(M)

print(jax.jit(f).lower(x, mu, sigma).compile().compiler_ir()[0].to_string(opt))
```

Comparing the HLO module generated by ...
```
fused_computation {
  ...
  param_0.3 = f64[] parameter(0)
  broadcast.14 = f64[10]{0} broadcast(param_0.3), dimensions={}
  subtract.0 = f64[10]{0} subtract(add.3, broadcast.14)
  exponential.0 = f64[10]{0} exponential(subtract.0)  /////
  ROOT log.0 = f64[10]{0} log(exponential.0)          /////
}
...
```

and

```
fused_computation {
  ...
  param_0.1 = f64[] parameter(0)
  broadcast.14 = f64[10]{0} broadcast(param_0.1), dimensions={}
  ROOT subtract.0 = f64[10]{0} subtract(add.3, broadcast.14)
}
...
```

In the first module the marked (`/////`) `exponential` and `log` instructions are still present; in the second they have been eliminated, leaving only the `subtract`.
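The same effect can be reproduced without the pdf machinery. A minimal sketch (the function name `g` and the claim that XLA's algebraic simplifier cancels the `log(exp(x))` pair are my assumptions, not from the thread):

```python
import jax
import jax.numpy as jnp

def g(x):
    # same log-of-an-exponential pattern as jnp.log(pdf(...))
    return jnp.log(jnp.exp(x))

x = jnp.float32(-2000.0)  # exp(-2000) underflows to 0.0 in float32
print(g(x))               # eager: log(0.0) = -inf
print(jax.jit(g)(x))      # jitted: XLA may cancel log(exp(x)) -> x
```

Eagerly, each primitive runs separately, so the underflow to `0.0` happens before the `log`; under `jit`, the compiler sees both operations at once and is free to simplify them away.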
-
In general, JIT compilation will rearrange the order of operations in your function for efficiency, and this can sometimes change the numerical results of your function. For details and more explanation of this, see the FAQ entry "jit changes the exact numerics of outputs". In this case, the fact that you're computing `jnp.log` of an exponentiated quantity (`multivariate_normal.pdf`) is likely the culprit.

You can achieve a better-behaved version of your function by avoiding taking the log of the exponential in the first place:

```python
def jitless_log_likelihood(x, mu, sigma):
    return jnp.sum(jstats.multivariate_normal.logpdf(x, mean=mu, cov=sigma))
```
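To see why `logpdf` is better behaved, here's a small sketch (variable names mirror the repro above; with JAX's float32 defaults the 500-dimensional pdf underflows to zero):

```python
import jax
import jax.numpy as jnp
import jax.scipy.stats as jstats

key = jax.random.PRNGKey(0)
M = 500
x = jax.random.normal(key, (M,))
mu = jnp.zeros((M,))
sigma = jnp.identity(M)

# The 500-dimensional density is far below float32's smallest normal
# number, so pdf underflows to 0.0 and its log is -inf...
unstable = jnp.log(jstats.multivariate_normal.pdf(x, mean=mu, cov=sigma))
# ...while logpdf computes the log-density directly and stays finite.
stable = jstats.multivariate_normal.logpdf(x, mean=mu, cov=sigma)
print(unstable, stable)
```

Because `logpdf` never materializes the tiny density, there is no `log(exp(...))` pair for the compiler to rescue, and jitted and unjitted results agree.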