SSM (Mamba) in Jax #18907
faresobeid asked this question in Q&A
Replies: 1 comment, 4 replies
Here is an implementation of GateLoop in JAX. I have not yet fully worked out the differences between these methods, but at a high level the ideas seem identical: make the coefficients of an SSM conditional on the input but not on the hidden state. This retains associative scannability while apparently giving enough modelling capacity to compete with transformers, which fixed-coefficient SSMs do not. I'm also curious how well this performs in plain JAX; if it doesn't perform well, it would make a nice test case for Pallas.
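To make "input-conditioned coefficients keep the recurrence associatively scannable" concrete, here is a minimal sketch in plain JAX. It is not the GateLoop code referred to above. It assumes a real-valued, elementwise gated recurrence h_t = a_t * h_{t-1} + b_t, where `a` and `b` come from hypothetical projections `W_a` and `W_b` of the input alone. Each step is an affine map h -> a * h + b, composing affine maps is associative, so `jax.lax.associative_scan` can evaluate all steps in logarithmic depth. A step-by-step `jax.lax.scan` version is included for comparison.

```python
import jax
import jax.numpy as jnp

def combine(left, right):
    # Each element (a, b) stands for the affine map h -> a * h + b.
    # Applying `left` first and then `right` gives h -> a_r * (a_l * h + b_l) + b_r,
    # which is again affine, and this composition is associative.
    a_l, b_l = left
    a_r, b_r = right
    return a_r * a_l, a_r * b_l + b_r

def gated_recurrence_parallel(a, b):
    # h_t = a_t * h_{t-1} + b_t with h_{-1} = 0, evaluated as a parallel prefix scan.
    _, h = jax.lax.associative_scan(combine, (a, b), axis=0)
    return h

def gated_recurrence_sequential(a, b):
    # Reference: the same recurrence, one step at a time.
    def step(h, ab):
        a_t, b_t = ab
        h = a_t * h + b_t
        return h, h
    _, h = jax.lax.scan(step, jnp.zeros_like(b[0]), (a, b))
    return h

# Hypothetical input projections: the gate a_t depends on x_t only, never on h_{t-1}.
k_x, k_a, k_b = jax.random.split(jax.random.PRNGKey(0), 3)
L, D = 16, 8
x = jax.random.normal(k_x, (L, D))
W_a = jax.random.normal(k_a, (D, D)) / jnp.sqrt(D)
W_b = jax.random.normal(k_b, (D, D)) / jnp.sqrt(D)
a = jax.nn.sigmoid(x @ W_a)  # gate values in (0, 1)
b = x @ W_b

h_par = gated_recurrence_parallel(a, b)
h_seq = gated_recurrence_sequential(a, b)
```

The associativity only holds because `a` and `b` are computed from the input alone. If the gate also depended on h_{t-1}, the steps would no longer be fixed affine maps that can be composed ahead of time, and the recurrence would have to run sequentially.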
I recently came across the Mamba paper (https://arxiv.org/abs/2312.00752) and its hardware-aware algorithm, which makes a state space model fast in recurrent mode. There are also many other SSMs that follow this principle, such as RWKV v6. The only implementations with good speed that I've seen are written in CUDA. I was wondering whether, and how, these models can be implemented in JAX to make use of TPUs. One option I've seen is Pallas, but so far I've had no success with it, which might be due to my inexperience. The main premise of these algorithms is described in Appendix D of the paper.
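For reference, here is a minimal plain-JAX sketch of the kind of selective-SSM recurrence that the paper's algorithm accelerates. It is not the paper's hardware-aware kernel and makes several simplifying assumptions: a single sequence with no batch axis, a diagonal `A` of shape (D, N) with negative entries, zero-order hold for `A` and the simplified Δ·B step for `B` (assumed here to match the paper's reference setup), and no skip term, convolution, or gating around the SSM. `selective_scan_recurrent` uses `jax.lax.scan` and carries only the (D, N) state. `selective_scan_parallel` computes the same output with `jax.lax.associative_scan`, which parallelizes over the sequence but materializes all (L, D, N) states, the tensor the paper's fused kernel is designed to avoid writing to HBM.

```python
import jax
import jax.numpy as jnp

def selective_scan_recurrent(x, delta, A, B, C):
    """Recurrent mode: carry only a (D, N) state from step to step.

    Shapes: x, delta (L, D); A (D, N) with negative entries; B, C (L, N).
    """
    def step(h, inputs):
        x_t, delta_t, B_t, C_t = inputs
        A_bar = jnp.exp(delta_t[:, None] * A)         # zero-order hold for A, (D, N)
        Bx = (delta_t * x_t)[:, None] * B_t[None, :]  # simplified discretization of B, times x_t
        h = A_bar * h + Bx                             # coefficients depend on the input, not on h
        return h, h @ C_t                              # y_t = C_t h_t, shape (D,)

    h0 = jnp.zeros(A.shape, dtype=x.dtype)
    _, y = jax.lax.scan(step, h0, (x, delta, B, C))
    return y                                           # (L, D)

def selective_scan_parallel(x, delta, A, B, C):
    """Same output via an associative scan; materializes all (L, D, N) states."""
    A_bar = jnp.exp(delta[:, :, None] * A[None])       # (L, D, N)
    Bx = (delta * x)[:, :, None] * B[:, None, :]       # (L, D, N)

    def combine(left, right):
        # Compose the affine maps h -> a * h + b ("left" is applied first).
        a_l, b_l = left
        a_r, b_r = right
        return a_r * a_l, a_r * b_l + b_r

    _, h = jax.lax.associative_scan(combine, (A_bar, Bx), axis=0)
    return jnp.einsum("ldn,ln->ld", h, C)

# Usage with hypothetical random parameters (in Mamba, delta, B and C are projections of x).
keys = jax.random.split(jax.random.PRNGKey(0), 5)
L, D, N = 12, 4, 3
x = jax.random.normal(keys[0], (L, D))
delta = jax.nn.softplus(jax.random.normal(keys[1], (L, D)))  # positive step sizes
A = -jnp.exp(jax.random.normal(keys[2], (D, N)))             # negative real diagonal
B = jax.random.normal(keys[3], (L, N))
C = jax.random.normal(keys[4], (L, N))

y_rec = selective_scan_recurrent(x, delta, A, B, C)
y_par = selective_scan_parallel(x, delta, A, B, C)
```

The two versions trade memory for parallelism: the `lax.scan` form has small state but runs L sequential steps, while the associative-scan form exposes parallelism over the sequence at the cost of an L-times larger intermediate tensor.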

I understand that this might be GPU-specific, but surely there is a way to replicate its effects on TPU in JAX.
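One way to approximate part of that effect in plain JAX, assuming the main goal is to limit memory rather than to control on-chip memory directly, is a chunked scan: split the sequence into chunks, run `jax.lax.scan` over chunks while carrying only the state, use `jax.lax.associative_scan` inside each chunk, and wrap the chunk body in `jax.checkpoint` so the backward pass recomputes each chunk's intermediates instead of storing them. This is only similar in spirit to the paper's recomputation and gives no kernel fusion; `chunked_linear_scan` and its `chunk_size` parameter are hypothetical names for this sketch, and `chunk_size` must divide the sequence length.

```python
import jax
import jax.numpy as jnp

def combine(left, right):
    # Compose affine maps h -> a * h + b ("left" is applied first).
    a_l, b_l = left
    a_r, b_r = right
    return a_r * a_l, a_r * b_l + b_r

def chunked_linear_scan(a, b, chunk_size):
    """h_t = a_t * h_{t-1} + b_t with h_{-1} = 0, for a, b of shape (L, D).

    chunk_size is a hypothetical tuning knob and must divide L.
    """
    L, D = a.shape
    n_chunks = L // chunk_size
    a_c = a.reshape(n_chunks, chunk_size, D)
    b_c = b.reshape(n_chunks, chunk_size, D)

    # jax.checkpoint: the backward pass recomputes this chunk's intermediates
    # instead of saving them as residuals.
    @jax.checkpoint
    def chunk_body(h_prev, chunk):
        a_k, b_k = chunk
        # Within the chunk: cumulative gates and states, assuming a zero incoming state.
        a_cum, h_local = jax.lax.associative_scan(combine, (a_k, b_k), axis=0)
        # Fold in the state carried over from the previous chunk.
        h = h_local + a_cum * h_prev[None, :]
        return h[-1], h

    h0 = jnp.zeros((D,), dtype=a.dtype)
    _, h = jax.lax.scan(chunk_body, h0, (a_c, b_c))
    return h.reshape(L, D)

# Usage: compare against a plain step-by-step scan and take a gradient through it.
k_a, k_b = jax.random.split(jax.random.PRNGKey(0))
a = jax.nn.sigmoid(jax.random.normal(k_a, (32, 4)))
b = jax.random.normal(k_b, (32, 4))

def reference(a, b):
    def step(h, ab):
        h = ab[0] * h + ab[1]
        return h, h
    return jax.lax.scan(step, jnp.zeros(a.shape[1]), (a, b))[1]

h_chunked = chunked_linear_scan(a, b, chunk_size=8)
h_ref = reference(a, b)
grad_a = jax.grad(lambda a_: chunked_linear_scan(a_, b, 8).sum())(a)
```

Under `jax.jit`, `chunk_size` has to be a static argument (for example `jax.jit(chunked_linear_scan, static_argnums=2)`), since it determines array shapes. Explicit control over on-chip memory and fusing the whole scan into one kernel on TPU would still need something like Pallas.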