-
I'm currently working on an implementation of autoregressive decoding for transformers in JAX. When my models are small, everything works fine. However, as models become larger, jit compilation time grows. What is especially surprising is that jit compilation time for a model with the same number of transformer blocks and heads, but with large MLP dimensions, is significantly higher than for exactly the same model with small MLP dims. I would have imagined that jit complexity should be independent of tensor shapes if the computation graph is the same. I should also mention that the compilation time of a forward or a forward-backward pass is small and doesn't seem to depend on parameter dimensions. Is this something that can plausibly happen, or does it more likely indicate a bug in the jit implementation?
-
Here’s a way that you could write a function such that compilation time grows with the shape of the input:

```python
from jax import jit

@jit
def f(x):
    y = 0
    # Python-level loop: jit tracing unrolls it, so the traced program
    # contains one copy of the body per element along x.shape[0].
    for i in range(x.shape[0]):
        y += do_something(x)  # do_something: placeholder for an arbitrary array op
    return y
```

The reason for this is that JAX’s JIT unrolls Python control flow, so here the effective size of the program grows proportionally to the size of the input array, and compilation time grows with program size. But if your program consists only of simple array operations without this kind of shape-dependent Python control flow, I wouldn’t generally expect compilation time to change with the shape of the input.
TL;DR: it seems the problem was caused by the model weights being captured by the decoding function as a closure instead of being passed in as an argument, so from jit's perspective they were a huge constant and, I guess, the constant-folding optimization wasn't particularly happy about them.
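To illustrate the pattern, here is a minimal sketch of the two ways of handing weights to a jitted function; the names (`decode_closure`, `decode_arg`) and the plain matmul standing in for the real decoding step are hypothetical, not the original code:

```python
import jax
import jax.numpy as jnp

# Stand-in for the real model weights: one large matrix instead of a full transformer.
params = jnp.ones((4096, 4096))

# Closure-captured weights: `params` is baked into the jaxpr as a constant,
# which the compiler may try to constant-fold, inflating compile time.
@jax.jit
def decode_closure(tokens):
    return tokens @ params

# Weights passed as an argument: the compiler only sees their shape/dtype,
# not their values, so there is nothing to fold.
@jax.jit
def decode_arg(params, tokens):
    return tokens @ params

tokens = jnp.ones((1, 4096))
decode_closure(tokens)       # weights embedded as a constant
decode_arg(params, tokens)   # weights traced as an abstract argument
```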
How I found this out: