Try putting jit on the outside of grad so that we can push more of the computation to XLA:

In [4]: %timeit f(w).block_until_ready()
24.1 µs ± 13.7 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [5]: %timeit grad(f)(w).block_until_ready()
3.27 ms ± 198 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [6]: grad_f = jit(grad(f))

In [7]: %timeit grad_f(w).block_until_ready()
32.2 µs ± 15.6 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
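
For anyone who wants to reproduce the comparison, here is a minimal self-contained sketch. The thread does not show the definitions of `f` and `w`, so the ones below are hypothetical stand-ins (any scalar-valued function of an array would do), and the timings above will not carry over exactly.

```python
import jax.numpy as jnp
from jax import grad, jit

# Hypothetical stand-in for the f in this thread: a scalar-valued function of w.
def f(w):
    return jnp.sum(jnp.tanh(w) ** 2)

# Hypothetical input; the original w is not shown in the thread.
w = jnp.linspace(-1.0, 1.0, 1000)

# grad alone: the gradient is re-traced and dispatched op by op on each call.
grad_eager = grad(f)

# jit on the outside: the whole gradient computation is compiled by XLA once
# and the compiled version is reused on later calls with the same input shape.
grad_compiled = jit(grad(f))

# The first call to the jitted function includes compilation, so call it once
# before timing. block_until_ready() waits for the asynchronous result.
g_compiled = grad_compiled(w).block_until_ready()
g_eager = grad_eager(w).block_until_ready()
```

Both versions should return the same gradient; only the per-call overhead differs. In IPython, `%timeit grad_eager(w).block_until_ready()` and `%timeit grad_compiled(w).block_until_ready()` can then be compared directly.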

Answer selected by fbartolic