I'm trying to implement Mixtral in JAX and am running into one problem with the way the MoE block works. If you aren't familiar, the PyTorch version works like this: after self-attention, we multiply the hidden states by the router weights (plus some additional processing), which gives a mapping from each token in the sequence to its top-2 experts. We then iterate over each expert, gather the tokens routed to it, apply that expert's parameters, scatter the results back together across experts, and move on. The problem is that we can't do quite the same thing in JAX, because the number of tokens per expert is data-dependent and varies a lot, so the shapes aren't static. Does that mean we cannot implement Mixtral in JAX? The PyTorch code snippet for the MoE block will be provided in the next message.
Any suggestions on how to approach this are very welcome.
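One shape-static way to express this (a rough sketch, not the reference Mixtral implementation; the function and parameter names here are made up, and the expert MLP is simplified from Mixtral's gated MLP) is to run every expert over every token and combine the outputs with the top-2 router weights, so no data-dependent gather is needed:

```python
import jax
import jax.numpy as jnp

def dense_moe(hidden, router_w, expert_w1, expert_w2, top_k=2):
    """Shape-static MoE sketch: every expert processes every token and the
    outputs are combined with the top-k router weights. Wasteful in FLOPs,
    but there are no data-dependent shapes, so it jits cleanly.

    hidden:    (tokens, d_model)
    router_w:  (d_model, n_experts)
    expert_w1: (n_experts, d_model, d_ff)
    expert_w2: (n_experts, d_ff, d_model)
    """
    probs = jax.nn.softmax(hidden @ router_w, axis=-1)         # (tokens, n_experts)

    # Zero out everything except each token's top-k experts, then renormalize.
    top_vals, top_idx = jax.lax.top_k(probs, top_k)            # (tokens, k)
    mask = jax.nn.one_hot(top_idx, probs.shape[-1]).sum(axis=1)
    gates = probs * mask
    gates = gates / gates.sum(axis=-1, keepdims=True)

    # Simplified two-matrix expert (Mixtral's real experts are gated MLPs).
    def run_expert(w1, w2):
        return jax.nn.silu(hidden @ w1) @ w2                   # (tokens, d_model)

    expert_out = jax.vmap(run_expert)(expert_w1, expert_w2)    # (n_experts, tokens, d_model)
    return jnp.einsum('te,etd->td', gates, expert_out)
```

This trades extra FLOPs for static shapes, so it is mainly useful as a correctness baseline before trying a capacity-based or sorted/segmented dispatch.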
Edit: I've tried using a fixed number of tokens for each expert (seq_len // 8), but after a few decoder layers the hidden states seem to carry basically no information.
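For the fixed-capacity variant, one common formulation gives each expert a buffer of `capacity` slots, drops whatever overflows (those tokens then contribute only through the block's residual connection), and combines with the gate weights. As a sanity check on the numbers: with top-2 routing over 8 experts the expected load per expert is seq_len / 4, so a capacity of seq_len // 8 would drop roughly half of the routed assignments on average. A minimal sketch, with made-up names and no load balancing:

```python
import jax
import jax.numpy as jnp

def capacity_moe(hidden, gates, expert_w1, expert_w2, capacity):
    """Fixed-capacity dispatch sketch: each expert gets `capacity` slots;
    assignments beyond that are dropped, so those tokens get a zero MoE
    output and rely on the layer's residual path.

    hidden: (tokens, d_model)
    gates:  (tokens, n_experts) with top-k gate weights, zeros elsewhere
    """
    assigned = gates > 0                                        # (tokens, n_experts)
    # Position of each token inside its expert's buffer (0-based).
    position = jnp.cumsum(assigned, axis=0) - 1
    keep = assigned & (position < capacity)

    # dispatch[t, e, c] = 1 if token t occupies slot c of expert e.
    dispatch = keep[..., None] * jax.nn.one_hot(position, capacity)
    combine = gates[..., None] * dispatch                       # gate-weighted combine tensor

    expert_in = jnp.einsum('tec,td->ecd', dispatch, hidden)     # (n_experts, capacity, d_model)

    def run_expert(x, w1, w2):                                  # simplified expert MLP
        return jax.nn.silu(x @ w1) @ w2

    expert_out = jax.vmap(run_expert)(expert_in, expert_w1, expert_w2)
    return jnp.einsum('tec,ecd->td', combine, expert_out)       # dropped tokens get zeros
```

Dropped tokens produce a zero MoE output, so if the capacity sits far below the average load, a large fraction of tokens pass through each layer on the residual path alone.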