Replies: 6 comments 2 replies
-
Thanks for the question! The issue is your use of nested for-loops within JIT. These are unrolled by the JIT compiler, meaning your program ends up sending a sequence of over 16000 individual instructions to XLA, which accounts for the slow compile times. Rather than nested for-loops, you should try expressing your program's logic in terms of vectorized array computations (similar to how NumPy achieves fast performance) or using JAX-specific tools like vmap. I'd show an example based on your function, but it seems to be a bit over-simplified. Note also that if you call the jitted function a second time on similar input (so that compilation time is not included), execution will be very fast, because XLA can optimize away these 16000 no-ops.
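For instance (a minimal sketch with a made-up function, not the original code), compare a Python loop that gets unrolled at trace time with a single vectorized reduction:

```python
import jax
import jax.numpy as jnp

# Loop version: jit unrolls the Python loop at trace time,
# emitting one XLA operation per iteration.
@jax.jit
def sum_squares_loop(x):
    total = 0.0
    for i in range(x.shape[0]):
        total = total + x[i] ** 2
    return total

# Vectorized version: a single reduction op, whatever the array size.
@jax.jit
def sum_squares_vec(x):
    return jnp.sum(x ** 2)

x = jnp.arange(4.0)
print(sum_squares_loop(x))  # 14.0
print(sum_squares_vec(x))   # 14.0
```

Both return the same value, but the vectorized version compiles in roughly constant time, while the loop version's compile time grows with the array length.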
-
I just don't see how to use vmap in my function foo, because I don't know JAX very well. Could you please help me in that case? Thank you.
-
There's no way to use vmap in your function

```python
def foo(a):
    return 0
```

I assume that your real function is more complex than this.
-
Is the benchmark example trying to measure dispatch time? It might be useful to figure out what you are trying to measure here.
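If the goal is to measure steady-state execution time rather than compilation or dispatch, one common pattern (a sketch with a made-up function f) is to warm up once and call block_until_ready(), since JAX dispatches work asynchronously and the call would otherwise return before the computation finishes:

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    return jnp.sum(x ** 2)

x = jnp.ones(1000)
f(x).block_until_ready()  # warm-up: trigger compilation once, outside the timing

start = time.perf_counter()
for _ in range(100):
    f(x).block_until_ready()  # block so async dispatch doesn't skew the timing
elapsed = time.perf_counter() - start
print(f"mean per call: {elapsed / 100:.2e} s")
```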
-
Actually, I'd like to improve the performance of the following functions:

```python
import jax.numpy as jnp

def f1(x):
    res = 0
    for e in x:
        res += (3.5 * e) ** 3
    return res

def f2(x):
    res = jnp.zeros(sorties)  # sorties: defined elsewhere
    for i in range(len(res)):
        res = res.at[i].set((x * i) ** 3)  # .at[].set() returns a new array
    return res
```
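For reference, vectorized equivalents of these loops might look like the following sketch (the names f1_vec and f2_vec are made up, and sorties is taken as an explicit argument here rather than a free variable):

```python
import jax.numpy as jnp

def f1_vec(x):
    # Sum of (3.5 * e) ** 3 over all elements, as one elementwise op
    # plus one reduction, instead of a Python loop.
    return jnp.sum((3.5 * x) ** 3)

def f2_vec(x, sorties):
    # (x * i) ** 3 for i = 0 .. sorties - 1, computed by broadcasting
    # x against an index array instead of filling entries one by one.
    return (x * jnp.arange(sorties)) ** 3
```

Both versions trace to a fixed number of XLA operations regardless of the input size, so jit compile times stay constant.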
-
Thank you for answering me. I have another question: inside a function, I have a JAX tracer object like this: Traced<ConcreteArray([1. 2. 3. 4. 5.])>with<JVPTrace(level=2/0)>
-
Hello,
For my internship, I need to compare performance between JAX and Autograd.
But when I try to run this code (for example), I get worse performance with JAX than with Autograd:
The outputs I got are:
I don't know where this issue comes from. I work in Visual Studio Code and run my program on a GPU server running Ubuntu. The CUDA version is the following:
I also tried the code in a Jupyter notebook and got the same results.
Thanks in advance for your help.