jit(grad(f)) isn't cached? #7871

Answered by cgarciae · josephrocca asked this question in Q&A
Is it expected behavior that repeatedly calling `jit(grad(f))(a)` is slower than calling the cached `jit_grad_f(a)`?

```python
import jax
from jax import grad, jit
import jax.numpy as np

f = lambda x: np.mean(np.dot(x, x.T))
jit_grad_f = jit(grad(f))  # cache/warm-up

a = np.ones([2000, 2000])
jit(grad(f))(a)
jit_grad_f(a)

%timeit jit(grad(f))(a).block_until_ready()
# 28.2 ms ± 351 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

%timeit jit_grad_f(a).block_until_ready()
# 5.51 ms ± 226 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```
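For context (an aside, not from the thread): `grad(f)` returns a new Python function object each time it is called, so each `jit(grad(f))` expression produces a distinct wrapper that does not share cached traces with any earlier one. A quick check:

```python
from jax import grad, jit
import jax.numpy as jnp

f = lambda x: jnp.mean(jnp.dot(x, x.T))

# Each call to grad(f) builds a fresh function object, so the two jit
# wrappers below are distinct objects with separate caches.
g1 = jit(grad(f))
g2 = jit(grad(f))
print(g1 is g2)  # False
```

This is why hoisting `jit_grad_f = jit(grad(f))` out of the timing loop matters: the gap between the two timings above is the per-call cost of building a new wrapper (and potentially re-tracing) on every iteration.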
Answered by cgarciae · Sep 10, 2021
You need to call `.block_until_ready()` to benchmark properly: https://jax.readthedocs.io/en/latest/async_dispatch.html

Answer selected by josephrocca
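As the linked page explains, JAX dispatches work to the device asynchronously, so a timer that does not wait on the result only measures dispatch time, not compute time. A minimal sketch of a manual benchmark (array size reduced here, and `time.perf_counter` used in place of `%timeit`) might look like:

```python
import time
from jax import grad, jit
import jax.numpy as jnp

f = lambda x: jnp.mean(jnp.dot(x, x.T))
jit_grad_f = jit(grad(f))

a = jnp.ones([200, 200])
jit_grad_f(a).block_until_ready()  # warm-up: trace and compile once

start = time.perf_counter()
for _ in range(100):
    result = jit_grad_f(a)      # returns before the device finishes
result.block_until_ready()      # wait, so the timer sees the real cost
elapsed_ms = (time.perf_counter() - start) / 100 * 1e3
print(f"{elapsed_ms:.3f} ms per call")
```

Blocking only on the final result is enough here because the queued calls execute in order; without that final `block_until_ready()` the loop would appear misleadingly fast.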