This takes a long time to run because you are timing the tracing and compilation of the function, not just its execution. Once compilation is finished, subsequent runs will be fast:

import jax
import jax.numpy as jnp
from jax import jit

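@jit
def foo(a):
  # These Python loops execute while JAX traces foo, so the traced
  # program records 32 * 32 = 1024 separate a[0] indexing operations.
  for i in range(32):
    for j in range(32):
      s = a[0]
  return 0

a = jnp.zeros(128)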
%time foo(a).block_until_ready()  # trace, compile & run
# CPU times: user 2.38 s, sys: 28.2 ms, total: 2.41 s
# Wall time: 2.4 s
%time foo(a).block_until_ready()  # run the compiled function
# CPU times: user 1.03 ms, sys: 0 ns, total: 1.03 ms
# Wall time: 827 µs

Why is compilation so slow here? Well, jax.jit effectively unrolls all Python loops in the …
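The effect of unrolling can be inspected with `jax.make_jaxpr`, which returns the traced program. The sketch below is a variation on the toy function above, not the original. It accumulates `a[0]` so that every iteration contributes to the result. It then compares the Python-loop version against `jax.lax.fori_loop`, one standard way to write a loop that JAX traces once instead of unrolling. The function names `unrolled` and `rolled` are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax import lax


def unrolled(a):
  # Python loops run at trace time: each of the 32 * 32 iterations
  # appends its own indexing and add operations to the traced program.
  s = jnp.zeros((), a.dtype)
  for i in range(32):
    for j in range(32):
      s = s + a[0]
  return s


def rolled(a):
  # lax.fori_loop traces the body once and emits a single loop primitive
  # (scan or while, depending on JAX version and bounds), so the program
  # size does not grow with the trip count.
  def body(i, s):
    return s + a[0]
  return lax.fori_loop(0, 32 * 32, body, jnp.zeros((), a.dtype))


a = jnp.ones(128)
n_unrolled = len(jax.make_jaxpr(unrolled)(a).jaxpr.eqns)
n_rolled = len(jax.make_jaxpr(rolled)(a).jaxpr.eqns)
print("equations (Python loops):", n_unrolled)
print("equations (fori_loop):   ", n_rolled)
print(jax.jit(unrolled)(a), jax.jit(rolled)(a))  # both add a[0] 1024 times
```

The trade-off is that `fori_loop` requires the carried value to keep the same shape and dtype on every iteration. Python loops allow arbitrary Python control flow, but at the cost of a larger traced program and longer compile times.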

Answer selected by sharadmv

This discussion was converted from issue #4436 on October 02, 2020 18:08.