I would like to start trying out Pallas on GPU (ideally just on Colab). However, I can't figure out how to get started. It does not seem to be in 0.4.14. I installed from
What is the recommended installation procedure?
-
Thanks for the interest! TL;DR: here is Pallas on GPU working on Colab: https://colab.research.google.com/drive/1P0d9S3lQHiDCyWvtXf8pdBWAyY1uBptZ?usp=sharing (there is no Pallas on TPU Colab ready yet, unfortunately). The installation story is a bit complicated at the moment, which is why we don't have an official installation guide yet. Pallas is now part of JAX but depends on
The solution I am using in the Colab is to pin versions of everything to a set that I know works together. Once Triton cuts a release on PyPI, we can release JAX and JAX-Triton simultaneously with pinned dependencies. Until then, the solution I recommend is to:
This is not a great user experience for now, but we have some long-term solutions in mind. The most compelling one is to have Pallas depend not on JAX-Triton but on the Triton bindings that ship with jaxlib itself (jaxlib contains Triton as well, because XLA GPU uses Triton).
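If you go the pinned-versions route, a small script can confirm that an environment actually matches the pins before you start debugging anything else. This is a minimal sketch using only the standard library's `importlib.metadata`; the `check_pins` helper and the `X.Y.Z` pins are hypothetical placeholders, not the versions used in the Colab above.

```python
from importlib import metadata


def check_pins(pins):
    """Compare installed distribution versions against pinned ones.

    Returns a dict mapping package name -> (pinned, installed) for every
    package that is missing (installed is None) or whose installed version
    string does not exactly equal its pin.
    """
    mismatches = {}
    for name, pinned in pins.items():
        try:
            installed = metadata.version(name)
        except metadata.PackageNotFoundError:
            installed = None
        if installed != pinned:
            mismatches[name] = (pinned, installed)
    return mismatches


if __name__ == "__main__":
    # Hypothetical pins: replace "X.Y.Z" with versions you know work together.
    PINS = {"jax": "X.Y.Z", "jaxlib": "X.Y.Z", "jax-triton": "X.Y.Z", "triton": "X.Y.Z"}
    for name, (pinned, installed) in check_pins(PINS).items():
        print(f"{name}: pinned {pinned}, installed {installed or 'not installed'}")
```

The comparison is an exact string match on purpose: a dev or nightly build (e.g. a `.dev` suffix) is reported as a mismatch, which is usually what you want when the whole point is reproducing a known-good set.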
-
Stable Pallas installation is still tricky; work on this is being tracked in #18603.
(Nvm. Figured it out!)
Running into this issue:
Is this a JAX/CUDA versioning issue?
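When asking whether an error comes from mismatched JAX/CUDA versions, it helps to post the installed versions and what JAX actually sees as its backend. Below is a minimal sketch under the assumption that `jax`, `jaxlib`, `jax-triton` and `triton` are the relevant distributions here; the `version_report` and `backend_report` helpers are hypothetical, and CUDA plugin package names are deliberately left out because they vary between releases.

```python
from importlib import metadata

# Assumed set of distributions relevant to Pallas on GPU at this point.
PACKAGES = ("jax", "jaxlib", "jax-triton", "triton")


def version_report(packages=PACKAGES):
    """Map each distribution name to its installed version, or None if absent."""
    report = {}
    for name in packages:
        try:
            report[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = None
    return report


def backend_report():
    """Describe the default JAX backend and devices, or why they are unavailable."""
    try:
        import jax
    except Exception as e:  # e.g. jax installed without a matching jaxlib
        return f"jax not importable: {e!r}"
    try:
        return f"{jax.default_backend()}: {jax.devices()}"
    except Exception as e:  # e.g. CUDA driver/runtime initialisation errors
        return f"backend error: {e!r}"


if __name__ == "__main__":
    for name, version in version_report().items():
        print(f"{name:<12} {version or 'not installed'}")
    print(backend_report())
```

If `backend_report()` says `cpu` on a GPU machine, JAX silently fell back to CPU, which usually points at the jaxlib/CUDA pairing rather than at Pallas itself.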