list of JAX/XLA compiler optimizations #9291
-
Hello! I regularly get surprised by how JAX (or perhaps rather XLA) seems to do some clever optimizations. For example, if you write a function that multiplies a matrix by the inverse of another matrix, the resulting jitted function is as fast as if you had hard-coded the numerical trick of using a Cholesky factorization, so I'm guessing XLA does that optimization automatically. I've also noticed that XLA seems to recognise when a term is computed many times in a function; as a result, the compiler seems to compute it only once. However, I can't find a list of these compiler optimizations on the JAX or XLA website. Is this documented somewhere?
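For concreteness, here is a minimal sketch of the kind of comparison I mean (the shapes and the `cho_solve` variant are just my illustration):

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

def naive(A, B):
    # "Textbook" version: explicitly invert A.
    return jnp.linalg.inv(A) @ B

def clever(A, B):
    # Hand-coded trick: Cholesky factorization + triangular solves.
    # (Assumes A is symmetric positive definite.)
    return cho_solve(cho_factor(A), B)

key = jax.random.PRNGKey(0)
X = jax.random.normal(key, (500, 500))
A = X @ X.T + 500 * jnp.eye(500)  # symmetric positive definite
B = jax.random.normal(key, (500, 500))

naive_jit = jax.jit(naive)
clever_jit = jax.jit(clever)
# Timing both (e.g. with %timeit, calling .block_until_ready() on the
# result) gives surprisingly similar performance in my tests.
```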
-
I do agree with you that it would be interesting to know all the optimisations performed by JAX/XLA. However, it is my understanding that those optimisations happen at several different levels. Below is some of my understanding of the XLA compiler toolchain, but I'd also be interested to hear more about it.
That's most likely not a compiler optimisation but just the fact that the inverse of a matrix `A` is computed by calling `lax.solve(A, identity)`, which in turn uses an LU decomposition and a triangular solve. I call this lowering (of an operation to more fundamental _primitives_). This does not hit th…
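For example, here is a quick way to see that lowering in action (a sketch on my side; the exact set of primitives printed may differ between JAX versions):

```python
import jax
import jax.numpy as jnp

# Inspect how the matrix inverse is lowered to more fundamental
# primitives *before* any XLA optimisation happens.
A = jnp.eye(3)
print(jax.make_jaxpr(jnp.linalg.inv)(A))
# The printed jaxpr contains primitives such as `lu` and
# `triangular_solve` rather than a monolithic "inverse" op.
```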
In my experience, CSE (common subexpression elimination) is the largest XLA-level optimisation. I think it can eliminate 100% of repeated terms within an expression's boundaries (boundaries are, for example, …). A quick sanity check is shown in the sketch below.

I think XLA might also perform some mathematical transformations, like promoting operations like …
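To make the CSE point concrete, a minimal sketch (assuming a recent JAX version with the ahead-of-time `lower`/`compile` API; `sine` is the HLO op name current XLA emits, but the dump format is not a stable interface):

```python
import jax
import jax.numpy as jnp

def f(x):
    # The expensive term appears several times in the source...
    y = jnp.sin(x) @ jnp.sin(x).T
    z = jnp.sin(x) @ jnp.sin(x).T
    return y + z

x = jnp.ones((100, 100))
# Inspect the HLO *after* XLA's optimisation passes.
hlo = jax.jit(f).lower(x).compile().as_text()
# After CSE, the sine op should typically survive only once.
print(hlo.count("sine"))
```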
Once XLA is done, I believe it will output LLVM IR, on which LLVM will perform several additional optimisations, probably another CSE pass among them. But since LLVM passes can be configured, it might be interesting to know which passes are enabled and which are disabled...
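If you want to see what XLA actually did in a given case, one option (a sketch; `--xla_dump_to` and `--xla_dump_hlo_as_text` are real XLA debug flags, but their output layout can change between versions) is to dump the HLO at each stage and diff:

```python
import os
# Must be set before XLA is initialised, i.e. before importing jax.
os.environ["XLA_FLAGS"] = "--xla_dump_to=/tmp/xla_dump --xla_dump_hlo_as_text"

import jax
import jax.numpy as jnp

jax.jit(lambda x: jnp.sin(x) + jnp.sin(x))(jnp.ones(8))
# /tmp/xla_dump now contains text files for the HLO module before
# and after the optimisation pipeline.
```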