-
I just noticed that the xformers library provides a notable speedup when its binaries are compiled for a specific hardware architecture, and in that configuration it seems faster than JAX. Could JAX achieve the same kind of improvement if we build it from source?
-
Thanks for the question! The numerical code that JAX executes is just-in-time (JIT) compiled for the specific hardware you run it on, right before execution. That includes things like autotuning for your particular GPU. So if I understand correctly, I don't think there's a change to be made here.
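To make that concrete, here is a minimal sketch of how JIT compilation is triggered in JAX (the function and array names are just for illustration):

```python
import jax
import jax.numpy as jnp

# jax.jit traces the Python function once, then XLA compiles
# the traced computation for the exact accelerator it runs on
# (CPU, GPU, or TPU), including backend-specific autotuning.
@jax.jit
def matmul(a, b):
    return a @ b

x = jnp.ones((128, 128))

# The first call triggers compilation for your hardware;
# subsequent calls with the same shapes/dtypes reuse the
# cached executable.
result = matmul(x, x)
```

Because compilation happens at runtime on the target machine, there is no separate "build for my architecture" step the way there is with an ahead-of-time compiled library.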
-
Then, is there any other acceleration library that can enhance speed by leveraging a specific hardware architecture? For example, PyTorch has numerous acceleration libraries such as triton, xformers, fastertransformer, megatron-lm, flexgen, and fairscale. I just want to integrate them with JAX.