Sharded Matrices and How to Multiply Them | How To Scale Your Model #5
Replies: 37 comments 59 replies
-
In the solution for pop quiz 2, the bidirectional ICI bandwidth for a TPU v5e is given as 9e10 bytes/s, which doesn't quite match the value of 1e11 bytes/s given in the table in part 2. Looking at https://cloud.google.com/tpu/docs/v5e, it appears that the value in the table is the correct one.
-
In the section "A quick aside: how would we describe this in code?", the text says: "For instance, in the above example, the local shape of A is [4, 1024] and for B is [2048, 4096]". I think the local shape of A is [2, 1024]?
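If it helps, here is a minimal sketch I used to check this (simulated CPU devices; I'm assuming the example means a 4x2 mesh with A: f32[8, 2048] sharded as A[I_X, J_Y], which is my reading, not necessarily the chapter's exact setup):

```python
# Minimal sketch: check global vs. local (per-device) shapes under a 4x2 mesh.
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"  # simulate 8 devices on CPU

import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ("X", "Y"))  # 8 devices as a 4x2 grid
A = jax.device_put(np.zeros((8, 2048), np.float32), NamedSharding(mesh, P("X", "Y")))  # A[I_X, J_Y]

print(A.shape)                             # global shape: (8, 2048)
print(A.addressable_shards[0].data.shape)  # local shape per device: (2, 1024) = (8/4, 2048/2)
```

This prints (8, 2048) globally and (2, 1024) per device, which is where my [2, 1024] comes from.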
-
On the first picture, you state that the shape of matrix A before sharding is [ . For me, this would mean that A is sharded across its rows, and B is sharded across its columns, so we have everything we need to calculate a single element of the result C, because the contracting dimensions are not sharded. But because you reversed the meaning of , could you clarify this with an image? I think I get the point, but visually it would help a lot to see exactly what you mean by these.
-
Some issues with question 2: In part 2's solution, I think you mean for X to be in the denominator. The result is the same because X = Y in this case. In part 3's solution, you mention TPU v5e, but the question asks about v4p. In part 4, I'm not sure what AllGather with a {U_Z} dimension means. I believe this is not addressed in the text of the chapter. Also, the solution again mentions v5e.
-
The code below it and the in the code really say: 8 TPUs in a 4x2 grid.
-
I believe question 4 may have miscalculated the comms overhead for
-
The flow in this chapter is a little jarring when it drops into the four cases without defining the term "contracting dimensions" or doing other setup to smooth the transition. Maybe an external reference or a bit more connective flow would help?
-
In the solution to question 4, I believe it should be D < C / Wici instead of F < C / Wici when determining when we are comms-bound in strategy 1. The wording is also a bit confusing: it says "In the second case (baseline)", but it appears to be talking about strategy 1, if I'm not mistaken? Also, a small grammatical error at the end of the solution: "we'll shard our parameters" instead of "we'll sharded our parameters".
-
The text says "For example, A[IX,J]⋅B[J,K]→C[IX,K] can be multiplied without any communication because the contracting dimension (J, the one we’re actually summing over) is unsharded. However, if we wanted the output unsharded (i.e. A[IX,J]⋅B[J,K]→C[IX,K]), we would need to copy A or C to every device.". Presumably the last "C[IX,K]" should actually be "C[I,K]".
-
As someone who is fairly familiar with sharding and JAX, I think the flow of this chapter can be refined and the details (along with the notation) can be improved a lot. I am happy to contribute if you are open to contributions. I mean it when I say this is confusing and can be simplified.
-
Can you explain more about AllReduce? Because I think I misunderstand what it actually does in Question 2, Part 3. In my opinion, after we do , because there is no communication between X and Y?
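To make my confusion concrete, here is a toy sketch of what I understand an AllReduce to be (simulated CPU devices, shapes of my own choosing): each device contributes a partial sum, and after the psum every device holds the identical full sum.

```python
# Toy AllReduce: each device holds one partial sum; psum over the mesh axis
# sums them and leaves the full result replicated on every device.
import functools
import os
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"  # simulate 4 devices

import jax
import numpy as np
from jax import lax
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(np.array(jax.devices()), ("X",))

@functools.partial(shard_map, mesh=mesh, in_specs=P("X"), out_specs=P())
def allreduce(partials):
    # per-device block has shape (1, 2); psum sums across the 4 devices
    return lax.psum(partials, axis_name="X")[0]

partials = np.arange(8, dtype=np.float32).reshape(4, 2)  # one (2,)-vector per device
print(allreduce(partials))  # [0+2+4+6, 1+3+5+7] = [12. 16.]
```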
-
In Question 3, why does the answer say "Since we have an axis of size 4 on a TPU v4p, we have a wraparound link, so we can do the AllGather by sending half the bytes in each direction"? In the GIF above, I think each device sends the full bytes in each direction. Is there any difference?
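For reference, here is the accounting as I understand it (taking $W_{ici}$ to be the full bidirectional bandwidth, as in the chapter):

$$T_{\text{uni}} \approx (N-1)\cdot\frac{\text{bytes}/N}{W_{ici}/2} \approx \frac{2\,\text{bytes}}{W_{ici}}, \qquad T_{\text{bi}} \approx \frac{N}{2}\cdot\frac{2\cdot\text{bytes}/N}{W_{ici}} \approx \frac{\text{bytes}}{W_{ici}}.$$

So in the GIF each shard does travel in both directions, but each copy only goes about halfway around the ring, which is why each direction of every link ends up carrying roughly half the array rather than all of it. Is that the right way to read it?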
-
Thanks for the great work! I have a question about the bi-directional all-gather case: since each hop sends
-
This is a fantastic book! Kudos to the authors and a big THANK YOU! I think this section is super critical for appreciating TPU differentiation vs. GPUs, but it needs quite substantial rework:
I hope my feedback is not misconstrued. I feel this book overall is phenomenal in its objectives and style, and definitely stands out in the crowd of similar efforts. Thank you again!
-
In question 4, I believe some math + reasoning for All-Gather being the preferred strategy is incorrect. At the beginning,
So for reasonably common batch sizes, we're ICI-bound for strategy 1, as we are for strategy 2. In that case, we need to compare the ICI times for both strategies to decide which one is best. Strategy 2 is best when:
So basically, for reasonable batch sizes (~1-2K) and D (~4K), strategy 2 is better than strategy 1. I also built a bunch of plots in this Colab, which showed that for certain large values of D & F it's never even beneficial to do strategy 1 (for example, when D=8K, F=16K), while for other values (D=4K, F=16K) it's better to do strategy 2 for B<2K and then slightly better to do strategy 1 for larger values of B.

Unless I screwed up my math above, I believe the recommendation that the "All-Gather" strategy is better for Case 2 should be reconsidered. At smaller batches, the "All-Reduce" strategy seems to be much better. It also makes sense when reasoning about it at a high level: when you have a giant weight matrix (i.e.

A small nit re. the same question: it never mentions that we want to do everything in bfloat16; it would be great to add that info.

Thank you for reading, and also thank you for providing such a great learning resource for the community!
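To spell out the comparison I'm making (my own notation; I'm assuming unsharded activations x[B, D] multiplying weights W[D_X, F] in bfloat16):

$$T^{\text{AllGather}}_{\text{ICI}} \approx \frac{2DF}{W_{ici}} \;(\text{gather } W \text{ over } X), \qquad T^{\text{AllReduce}}_{\text{ICI}} \approx \frac{2\cdot 2BF}{W_{ici}} \;(\text{reduce the partial } [B, F] \text{ output}),$$

so when both strategies are ICI-bound, the AllReduce strategy wins whenever $4BF < 2DF$, i.e. $B < D/2$ (for $D = 4096$ that's $B < 2048$, which matches my plots).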
-
In question 10.1, why is the number of floats communicated by a ReduceScatter the same as that of AllGather? Doesn't ReduceScatter need to communicate less since the partial sums remain scattered and don't need to be gathered?
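To state the counting behind my question (V is the full array size in bytes, ring algorithm as I understand it from the chapter):

$$T_{\text{AllGather}} \approx T_{\text{ReduceScatter}} \approx \frac{V}{W_{ici}},$$

since both move a V/N chunk per device per hop for about N hops: in the AllGather the chunks are copies being spread out, while in the ReduceScatter they are partial sums being passed along and accumulated. Even though the result stays sharded, every contribution still has to traverse the ring. Is that the intended reading?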
-
In Pop quiz 2 Part 1, I wonder if we should use the unidirectional bandwidth (which is 4.5e10) because the Y axis size is smaller than 16. IIUC, the answer should be Tcomms = 34e6 / 4.5e10 ≈ 756μs. I'm curious if I'm missing something.
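Writing out both versions of the arithmetic (34e6 bytes is the number from the quiz; 4.5e10 vs. 9e10 bytes/s are the uni- and bidirectional bandwidths):

$$T_{\text{comms}} = \frac{3.4\times10^{7}}{4.5\times10^{10}} \approx 756\,\mu\text{s} \qquad \text{vs.} \qquad T_{\text{comms}} = \frac{3.4\times10^{7}}{9\times10^{10}} \approx 378\,\mu\text{s}.$$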
-
Hi, thank you for the great explanation. I have a question regarding 10.2. Why is the data size considered to be
-
Could you clarify the I also don't have a great intuition of how an
-
In Case 3 (both multiplicands have sharded contracting dimensions), is the ReduceScatter done in bf16 or f32? Matmul accumulation typically needs to be done in f32, so does that mean the communication cost for the ReduceScatter will be higher?
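To make the cost question concrete (my own arithmetic, just counting bytes per element):

$$T_{\text{RS}} \approx \frac{\text{bytes/elem} \times n_{\text{elems}}}{W_{ici}},$$

so communicating the partial sums in f32 (4 bytes/elem) would cost about twice the bf16 (2 bytes/elem) ReduceScatter, if I'm reading the model correctly.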
-
For question 7, I believe I might be misunderstanding the notation. We want to multiply matrices C and B, and then take the result and multiply by matrix x, correct? In this case, it appears the shapes are incompatible: the result of C * B is [F, F], which is incompatible with the shape of x, [B, D].
-
In the first pop quiz, you write that
but this is wrong because 128 * 2048 * 2 = 524,288 bytes = 512 KiB (about 524 kB).
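Spelling out my conversion (assuming 2 bytes per element, i.e. bfloat16):

$$128 \times 2048 \times 2\,\text{B} = 524{,}288\,\text{B} = \frac{524{,}288}{1024}\,\text{KiB} = 512\,\text{KiB} \approx 524\,\text{kB}.$$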
-
This isn't working in the Colab:
Should be changed to this:
-
Hello, I am confused about how to judge whether a TPU has a wraparound connection. Is it related to the TPU's type?
-
I can't understand "In 2D, the cost actually scales down with the size of the smallest axis" in the article. Why? Can you give an example? Thank you very much.
-
In question 2, the first question: shouldn't the formula be 2BD/(9e10 * X)? Why is Y in the denominator?
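Here is where my X comes from (assuming the array is sharded over both axes, e.g. A[B_X, D_Y], and the AllGather is over Y): each Y-ring only needs to gather the 1/X slice that lives on its row of the mesh, so

$$T_{\text{comms}} \approx \frac{2BD/X}{W_{ici}} = \frac{2BD}{9\times10^{10}\cdot X}.$$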
-
In the "What happens when we AllGather over multiple axes" example, why does the latency-bound component of the total time depend on the sum of the length of each mesh axis, rather than their product? It seems to me that in each round of communication, each device cannot receive more than
-
Could I get some advice on how to approach question 5? I feel like I just went through the cases mentioned in this chapter and tried to derive when each is compute-bound and comms-bound. It feels like I should either: Also, can I make use of the 3rd axis Z? It seems like as of now I'm only sharding with X and Y.
-
Small typo in Q6. I think the first bullet should have J_Y, not J_X.
-
In the description of the bidirectional All-Gather in Case 2, it says "If we do two directions, we have ceil(N/2) hops of size 2⋅bytes/N". Should the number of hops not be floor(N/2) instead? When N is odd, e.g. N=7, we need only 3 hops, and when N is even, e.g. N=8, we need 4 hops.
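My counting: on a bidirectional ring every device already holds its own shard, and the farthest shard is only ⌊N/2⌋ hops away in the nearer direction, so

$$\#\text{hops} = \Big\lfloor\frac{N}{2}\Big\rfloor \quad (N=7 \Rightarrow 3,\; N=8 \Rightarrow 4), \qquad T \approx \Big\lfloor\frac{N}{2}\Big\rfloor \cdot \frac{2\cdot\text{bytes}/N}{W_{ici}} \approx \frac{\text{bytes}}{W_{ici}},$$

so the bandwidth-bound estimate is the same either way; it's just the hop count that I think should be floor rather than ceil.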
-
Sharded matrix multiplications galore!