Replies: 2 comments 1 reply
-
Thanks for the question! Is this on GPU or TPU?
1 reply
-
Any progress?
-
Hi, in my mixture-of-experts transformer block I have a computation that looks like this:
To be more specific, my embeddings' shape/sharding is the following:
[batch, expert, capacity, embedding] : [data, None, None, None]
I'm performing the all-to-all by swapping axes 0 and 1 and resharding across the new 0th axis (expert); a minimal sketch follows the layout below.
[batch, expert, capacity, embedding] -> [expert, batch, capacity, embedding]
[expert, batch, capacity, embedding] : [data, None, None, None]
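A minimal JAX sketch of the layout and reshard described above (this is not the original snippet; the mesh, shapes, and names are illustrative assumptions):

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D mesh with a single "data" axis over all devices.
mesh = Mesh(np.array(jax.devices()), ("data",))
shard_dim0 = NamedSharding(mesh, P("data", None, None, None))

# [batch, expert, capacity, embedding] : [data, None, None, None]
# (batch and expert are assumed to be divisible by the device count)
x = jax.device_put(jnp.zeros((8, 8, 16, 128)), shard_dim0)

@jax.jit
def to_expert_major(x):
    # [batch, expert, capacity, embedding] -> [expert, batch, capacity, embedding]
    y = jnp.swapaxes(x, 0, 1)
    # Reshard so the new leading (expert) axis is split over "data"; under jit
    # XLA realizes this layout change as an all-to-all.
    return jax.lax.with_sharding_constraint(y, shard_dim0)

y = to_expert_major(x)
```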
In large-scale setups the all-to-all communication is expected to become very slow, so to overcome this I'm trying to pipeline the computation over the "capacity" axis and overlap communication with computation. I've implemented this simply by splitting the capacity axis by some pipeline factor, running the entire computation per chunk, and then stacking the chunks back together (sketched below).
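A hedged sketch of that capacity-axis pipelining, using the same 1-D "data" mesh as above; `expert_fn` and `num_chunks` are placeholders rather than the original code, and the function is meant to be traced under `jax.jit`:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), ("data",))
shard_dim0 = NamedSharding(mesh, P("data", None, None, None))

def pipelined_moe_block(x, expert_fn, num_chunks=4):
    # x: [batch, expert, capacity, embedding]; capacity must be divisible by num_chunks.
    chunks = jnp.split(x, num_chunks, axis=2)
    outs = []
    for chunk in chunks:
        # Dispatch all-to-all: go expert-major and reshard dim 0 over "data".
        y = jax.lax.with_sharding_constraint(jnp.swapaxes(chunk, 0, 1), shard_dim0)
        y = expert_fn(y)  # per-expert computation on this capacity slice
        # Combine all-to-all: back to batch-major, resharded over "data" again.
        y = jax.lax.with_sharding_constraint(jnp.swapaxes(y, 0, 1), shard_dim0)
        outs.append(y)
    # Stack the capacity slices back together.
    return jnp.concatenate(outs, axis=2)
```

The intent is that the dispatch all-to-all of chunk i+1 can overlap with the expert compute of chunk i once the collectives run asynchronously.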

I've also enabled the XLA flags for the latency-hiding scheduler and async all-to-all communication (roughly as below).
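Roughly how such flags can be passed (assuming a GPU backend; the exact flag names, especially the async all-to-all one, are assumptions and have changed across XLA releases, so adjust for your version):

```python
# Must be set before JAX initializes its XLA backend.
import os

os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "")
    + " --xla_gpu_enable_latency_hiding_scheduler=true"
    + " --xla_gpu_enable_async_all_to_all=true"  # assumed flag name; varies by XLA version
)

import jax  # the backend picks up XLA_FLAGS on first initialization
```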
But when I check what is actually going on, I can see that XLA has fused all of the first all-to-all communications into a single blocking one, so no computation happens until it finishes. The second all-to-all, on the other hand, is well overlapped.
How can I avoid that? Is there a way to put boundaries on XLA's fusion mechanism, or somehow disable fusions for a specific part of my function?