-
I've translated Facebook Research's original Llama implementation to JAX. Mostly it was a straightforward port, except that the JAX version runs 40 times slower. If I interpret the profiler correctly, JAX spends most of the time on CPU, while the prevalent operation on GPU is memory copy. If we break down the CPU section, the numbers add up: attention takes most of the time (33 ms), followed by feedforward (11 ms). The JAX variables include parameters and cache values; I checked that both are allocated on GPU before the test.
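Roughly, that placement check is of this shape (a minimal sketch; `params` and `cache` below are just placeholders for the actual model state, and `.device()` follows the JAX API used elsewhere in this thread):

```python
import jax
import jax.numpy as jnp

# Placeholders for the actual parameter and cache pytrees of the model.
params = {"attention": {"wq": jnp.zeros((4096, 4096))}}
cache = {"layer_0": {"k": jnp.zeros((1, 2048, 32, 128))}}

# Every leaf should report a GPU device, e.g. cuda(id=0).
for leaf in jax.tree_util.tree_leaves((params, cache)):
    print(leaf.device())
```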
Any idea why this may be happening?
Replies: 2 comments
-
I must be misunderstanding how JAX/Flax work, but I'm getting super strange results even on simpler examples.

```python
import jax

rng = jax.random.PRNGKey(0)
x = jax.random.normal(rng, (8, 4096))
w = jax.random.normal(rng, (4096, 1024))
x.device()  # cuda(id=0)
w.device()  # cuda(id=0)

import torch

pt_x = torch.randn((8, 4096)).to(torch.device("cuda"))
pt_w = torch.randn((4096, 1024)).to(torch.device("cuda"))

import timeit

N = 10_000
timeit.timeit(lambda: (x @ w).block_until_ready(), number=N)  # 0.95 seconds
timeit.timeit(lambda: pt_x @ pt_w, number=N)                  # 0.24 seconds
```

So a simple matrix multiplication is ~4 times slower in JAX than in PyTorch. I also tested it with Flax modules, and there the difference is almost 100x!

```python
import flax.linen as nn

dense = nn.Dense(1024, use_bias=False)
variables = dense.init(rng, x)
jax.tree_util.tree_leaves(variables)[0].device()  # cuda(id=0)
timeit.timeit(lambda: dense.apply(variables, x).block_until_ready(), number=N)  # 25.1 seconds (!)

import torch.nn as tnn

pt_dense = tnn.Linear(4096, 1024, bias=False).to(torch.device("cuda"))
timeit.timeit(lambda: pt_dense(pt_x), number=N)  # 0.3 seconds
```

Am I testing this the right way? If so, what can be the reason for such a difference in performance?

System:
-
As discussed in another thread, JIT-compiling the code resolves the issue. With JIT, the JAX version actually runs slightly faster than the PyTorch one.
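For concreteness, a minimal sketch of what that looks like for the `nn.Dense` benchmark above (the first call compiles, so it is excluded from the timing; actual numbers will vary by hardware):

```python
import timeit

import jax
import flax.linen as nn

rng = jax.random.PRNGKey(0)
x = jax.random.normal(rng, (8, 4096))

dense = nn.Dense(1024, use_bias=False)
variables = dense.init(rng, x)

# JIT-compile the forward pass once; later calls reuse the compiled
# executable instead of dispatching each op from Python.
apply_jit = jax.jit(dense.apply)
apply_jit(variables, x).block_until_ready()  # warm-up / compilation

N = 10_000
print(timeit.timeit(lambda: apply_jit(variables, x).block_until_ready(), number=N))
```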