Hello,

Short answer: jit should almost always be the outermost transformation.
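Here is a minimal sketch of the two orderings. It uses a toy per-example function `double`, which is a hypothetical stand-in and not the function from the question. Both compositions compute the same values. With jax.jit outermost, the vmap transformation is traced once, when the function is compiled. jax.vmap(jax.jit(...)) is not itself jitted, so the vmap transformation plausibly gets re-applied in Python on every call. That per-call overhead is one likely reason for the recommendation.

```python
import jax
import jax.numpy as jnp

def double(x):
    # Hypothetical per-example function standing in for the one in the question.
    return 2.0 * x

# Recommended: jit is outermost, so the batched (vmapped) computation
# is traced and compiled once as a single XLA program.
jit_of_vmap = jax.jit(jax.vmap(double))

# Also valid, but the outer vmap is not jitted: the batching transformation
# runs in Python on each call before the inner jitted code is dispatched.
vmap_of_jit = jax.vmap(jax.jit(double))

x = jnp.arange(12.0).reshape(4, 3)
out1 = jit_of_vmap(x)
out2 = vmap_of_jit(x)
```

The numerical results are identical. Only the amount of Python-side work per call differs.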

Long answer:

There are three small mistakes in your test.

Asynchronous dispatch

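JAX dispatches work asynchronously. A call returns as soon as the computation has been queued, so in your code the values are not guaranteed to be 'doubled' yet when the call returns. They are only guaranteed to be computed once you use them (or explicitly wait for them).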

This means that in the following code the timer may stop before the computation has finished:

start1 = timer()
result1 = jitted_vmapped_fn(x_md)
end1 = timer()
# jitted_vmapped_fn might still be running

The correct thing to do is to wait for the result before stopping the timer:

start1 = timer()
result1 = jitted_vmapped_fn(x_md).block_until_ready()
end1 = timer()

For more info, see async_dispatch in the JAX documentation.
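Here is a minimal, self-contained sketch of the difference. It assumes a hypothetical jitted function `double` standing in for jitted_vmapped_fn. Without block_until_ready(), the timer only measures how long it takes to dispatch the work. With it, the timer runs until the result is actually available. A warm-up call is made first so that compilation is not part of either measurement.

```python
from timeit import default_timer as timer

import jax
import jax.numpy as jnp

@jax.jit
def double(x):
    # Hypothetical stand-in for jitted_vmapped_fn.
    return 2.0 * x

x = jnp.ones((1000, 1000))
double(x).block_until_ready()  # warm-up call so compilation is not timed below

start = timer()
y_async = double(x)  # returns once the work is dispatched, possibly before it finishes
dispatch_time = timer() - start

start = timer()
y = double(x).block_until_ready()  # waits until the result has actually been computed
compute_time = timer() - start

print(f"dispatch only: {dispatch_time:.6f}s, until ready: {compute_time:.6f}s")
```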

Don't profile the jit compilation time

Jitting more complex code takes more time. Functions are jit-compiled the first time you run them and the subsequent executio…
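Below is a sketch of a fairer benchmark. It uses a hypothetical helper `time_fn` and a toy `double` function in place of the question's code. The names x_md and jitted_vmapped_fn are reused from the snippets above, but their definitions here are only illustrative. The first call is timed separately because it includes tracing and compilation. The steady-state time is averaged over later calls, which reuse the compiled executable. Every call is followed by block_until_ready() so that asynchronous dispatch does not distort the numbers.

```python
from timeit import default_timer as timer

import jax
import jax.numpy as jnp

def time_fn(fn, *args, n_runs=10):
    """Hypothetical helper: return (first_call_seconds, mean_steady_state_seconds)."""
    start = timer()
    fn(*args).block_until_ready()  # first call: includes tracing and compilation
    first = timer() - start

    start = timer()
    for _ in range(n_runs):
        fn(*args).block_until_ready()  # later calls reuse the cached compiled code
    steady = (timer() - start) / n_runs
    return first, steady

def double(x):
    # Hypothetical per-example function.
    return 2.0 * x

x_md = jnp.ones((128, 256))
jitted_vmapped_fn = jax.jit(jax.vmap(double))
first, steady = time_fn(jitted_vmapped_fn, x_md)
print(f"first call: {first:.6f}s, steady state: {steady:.6f}s")
```

Compare the steady-state numbers when benchmarking different compositions. The first-call time mostly reflects compilation cost.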

2 replies
@andremfreitas
@jakevdp
Answer selected by andremfreitas