jitting a vmapped function or vmapping a jitted function #20505
Which one is better practice? In my test, the vmap of a jitted function consistently performed better (almost a 2x speedup compared to jitting a vmapped function).
Hello,

Short answer: `jit` 'almost' always has to be the outer transformation.

Long answer: there are three small mistakes in your test.

**Asynchronous dispatch**

JAX runs everything asynchronously, so in your code the values are not guaranteed to be 'doubled' until you actually use them. This means:

```python
start1 = timer()
result1 = jitted_vmapped_fn(x_md)
end1 = timer()
# jitted_vmapped_fn might still be running
```

The correct thing to do is:

```python
start1 = timer()
result1 = jitted_vmapped_fn(x_md).block_until_ready()
end1 = timer()
```

For more info: async_dispatch
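The snippets in this thread assume a function `fn` and a `timer` that are not shown; a minimal self-contained sketch, where `fn` is a hypothetical stand-in that doubles its input, could look like this:

```python
import jax
from timeit import default_timer as timer

def fn(x):
    # hypothetical stand-in for the thread's (unshown) function
    return 2.0 * x

# jit on the outside, vmap on the inside
jitted_vmapped_fn = jax.jit(jax.vmap(fn))

x_md = jax.random.normal(jax.random.PRNGKey(0), (1024, 100))

# warm-up call so the timing below excludes compilation
jitted_vmapped_fn(x_md).block_until_ready()

start1 = timer()
result1 = jitted_vmapped_fn(x_md).block_until_ready()  # wait for the actual computation
end1 = timer()
print(f"Time for jit vmapped function: {end1 - start1}")
```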
**Don't profile the jit time**

Jitting more complex code takes more time: functions are jit-compiled the first time you run them, and subsequent executions run much faster. For example:

```python
start1 = timer()
result1 = jitted_vmapped_fn(x_md).block_until_ready()
end1 = timer()
print(f"Time for jit vmapped function: {end1-start1}")

start1 = timer()
result1 = jitted_vmapped_fn(x_md).block_until_ready()
end1 = timer()
print(f"Time for jit vmapped function second time: {end1-start1}")
```

gives this result:

```
Time for jit vmapped function: 2.795020341873169
Time for jit vmapped function second time: 0.01996016502380371
```

You see, much much faster!
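One caveat worth adding to the point above (not from the thread, just how `jax.jit` caching behaves): the compiled executable is cached per input shape and dtype, so calling with a new shape triggers a fresh compilation. A small sketch with an illustrative function:

```python
import jax
import jax.numpy as jnp

@jax.jit
def double(x):
    # illustrative stand-in function
    return 2.0 * x

a = double(jnp.ones((4,)))   # first call: traces and compiles for shape (4,)
b = double(jnp.ones((4,)))   # same shape/dtype: reuses the cached executable
c = double(jnp.ones((8,)))   # new shape: traces and compiles again
```

So when benchmarking, time the second call with the same input shape to measure pure execution.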
**Don't include CPU copy time**

You are using NumPy (CPU) arrays, so you are including the CPU-to-GPU copy time. Do this instead:

```python
# x_md = np.random.rand(10240, 1000)  # CPU array
# x_md = jnp.array(x_md)              # GPU array
x_md = jax.random.normal(jax.random.PRNGKey(0), (10240, 1000))  # array created directly on the GPU

start1 = timer()
result1 = jitted_vmapped_fn(x_md).block_until_ready()
end1 = timer()
print(f"Time for jit vmapped function: {end1-start1}")

start1 = timer()
result1 = jitted_vmapped_fn(x_md).block_until_ready()
end1 = timer()
print(f"Time for jit vmapped function second time: {end1-start1}")
```

This gives:

```
Time for jit vmapped function: 0.020798683166503906
Time for jit vmapped function second time: 0.0008347034454345703
```

So much faster!
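If the data really does start out as a NumPy array, an explicit transfer (here assumed to use `jax.device_put`, done before any timing) keeps the copy out of the compute measurement. A small sketch:

```python
import numpy as np
import jax

x_cpu = np.random.rand(1024, 100)   # host (CPU) array
x_dev = jax.device_put(x_cpu)       # explicit host-to-device transfer
x_dev.block_until_ready()           # wait for the transfer to finish
# ...then time the jitted computation on x_dev, with the copy excluded
```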
**Full comparison**

```python
x_md = jax.random.normal(jax.random.PRNGKey(0), (10240, 1000))  # array created directly on the GPU

start1 = timer()
result1 = jitted_vmapped_fn(x_md).block_until_ready()
end1 = timer()
print(f"Time for jit vmapped function: {end1-start1}")

start1 = timer()
result1 = jitted_vmapped_fn(x_md).block_until_ready()
end1 = timer()
print(f"Time for jit vmapped function second time: {end1-start1}")

jitted_fn = jax.jit(fn)
vmapped_jitted_fn = jax.vmap(jitted_fn)

start2 = timer()
result2 = vmapped_jitted_fn(x_md).block_until_ready()
end2 = timer()
print(f"Time for vmap jitted function: {end2-start2}")

start2 = timer()
result2 = vmapped_jitted_fn(x_md).block_until_ready()
end2 = timer()
print(f"Time for vmap jitted function second time: {end2-start2}")
```

Results:

```
Time for jit vmapped function: 0.022216796875
Time for jit vmapped function second time: 0.00091552734375
Time for vmap jitted function: 0.016040325164794922
Time for vmap jitted function second time: 0.0011944770812988281
```

On the compiled (second) runs, the jitted vmap is about 30% faster than the vmap of jit.
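It is also easy to sanity-check that both orderings compute the same values, so only the compilation strategy differs (again using a hypothetical `fn` in place of the one from the thread):

```python
import jax
import jax.numpy as jnp

def fn(x):
    # hypothetical per-example function; the thread's fn is not shown
    return jnp.sin(x) * 2.0

jit_of_vmap = jax.jit(jax.vmap(fn))   # jit outermost: XLA sees the whole batched computation
vmap_of_jit = jax.vmap(jax.jit(fn))   # vmap outermost

x = jax.random.normal(jax.random.PRNGKey(0), (8, 16))
a = jit_of_vmap(x).block_until_ready()
b = vmap_of_jit(x).block_until_ready()
```

Since the results match, the choice is purely about performance, which is why `jit` is usually recommended as the outer transformation.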