With context parallelism (CP), we can shard the inputs across the sequence dimension, resulting in each device only processing a chunk of the full context and computing a smaller portion of the full, prohibitively large, attention matrix. To see how this works, recall that the attention computation is described by the equation:

$$\text{Attention}(Q, K, V) = \text{softmax}(QK^T)V$$

Where \\( Q \\), \\( K \\), and \\( V \\) are the query, key, and value matrices respectively. Each query vector (row, or input embedding) of \\( Q \\) must compute the attention scores against *every* key vector of \\( K \\) in the entire sequence to correctly apply the softmax normalisation. These attention scores are then weighted with *all* value vectors in \\( V \\).
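To make the notation concrete, here is a minimal PyTorch sketch of that computation (toy shapes, with the \\( 1/\sqrt{d} \\) scaling omitted to match the equations in this section; this is illustrative code, not part of Accelerate):

```python
import torch

n, d = 8, 16                # toy sequence length and head dimension
Q = torch.randn(n, d)       # one query vector per sequence position
K = torch.randn(n, d)       # one key vector per sequence position
V = torch.randn(n, d)       # one value vector per sequence position

scores = Q @ K.T                         # (n, n): every query scores every key
weights = torch.softmax(scores, dim=-1)  # each row is normalised over *all* keys
out = weights @ V                        # (n, d): weighted sum over *all* values
```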
The crucial detail here lies in the fact that each row in \\( Q \\) can compute its attention score independently of one another, but each query vector still requires the full \\( K \\) and \\( V \\) matrices. In other words, given an input with sequence length \\( n \\), we can expand our above attention equation as:
$$
\begin{align}
\text{Attention}(Q, K, V)_1 &= \text{softmax}(Q_1 K^T) V \\
\text{Attention}(Q, K, V)_2 &= \text{softmax}(Q_2 K^T) V \\
&\vdots \\
\text{Attention}(Q, K, V)_n &= \text{softmax}(Q_n K^T) V
\end{align}
$$

where we denote each row of the query matrix as \\( Q_1, Q_2, ..., Q_n \\). This can be generalized as:
\\( \text{Attention}(Q, K, V)_i = \text{softmax}(Q_i K^T) V \quad \forall i \in \{1, 2, ..., n\} \\)
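As a quick sanity check of this row-wise decomposition, the standalone snippet below (illustrative code, not from the Accelerate codebase) computes attention one query row at a time against the full \\( K \\) and \\( V \\) and confirms it matches the full computation:

```python
import torch

n, d = 8, 16
Q, K, V = torch.randn(n, d), torch.randn(n, d), torch.randn(n, d)

# Full attention over the whole sequence in one shot.
full = torch.softmax(Q @ K.T, dim=-1) @ V

# The same result, computed one query row at a time against the *full* K and V.
rows = [torch.softmax(Q[i : i + 1] @ K.T, dim=-1) @ V for i in range(n)]
row_wise = torch.cat(rows, dim=0)

print(torch.allclose(full, row_wise, atol=1e-6))  # True: each row is independent
```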
When we shard the inputs across devices, the resulting \\( Q \\), \\( K \\), and \\( V \\) matrices (computed from these input shards) are also automatically sharded along the sequence dimension - each GPU computes queries, keys, and values only for its portion of the sequence. For example, with a world size of \\( W \\) GPUs and sequence length \\( n \\), GPU 0 holds \\( Q_{1:n/W} \\), GPU 1 holds \\( Q_{n/W+1:2n/W} \\), and so on (and likewise for \\( K \\) and \\( V \\)).
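In code, this sharding amounts to a split along the sequence dimension. The sketch below is a single-process illustration with made-up shapes; Accelerate handles the actual dispatch of shards to ranks for you:

```python
import torch

world_size = 4                           # W: number of GPUs in the context-parallel group
batch, seq_len, hidden = 2, 4096, 1024   # illustrative shapes
hidden_states = torch.randn(batch, seq_len, hidden)

# Split the sequence into W contiguous chunks; the Q, K, V projections computed
# from a chunk are therefore sharded along the sequence dimension in the same way.
shards = hidden_states.chunk(world_size, dim=1)  # W tensors of shape (2, 1024, 1024)
local_hidden = shards[0]                         # the slice "GPU 0" would process
```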
How do we ensure the attention is computed correctly? As established above, each device only needs its own shard of \\( Q \\), but requires the full \\( K \\) and \\( V \\) matrices to compute the attention correctly. We can achieve this by using a technique called [RingAttention](https://openreview.net/forum?id=WsRHpHH4s0), which works as follows (a small code sketch of the ring pass follows the list):

1. Initially, each GPU holds its shard of \\( Q \\), \\( K \\), \\( V \\) (e.g., GPU 0 holds \\( Q_{1:n/W} \\), \\( K_{1:n/W} \\), \\( V_{1:n/W} \\)).
2. Each GPU then computes a partial attention matrix \\( A_{i,j} \\) for its shard of \\( Q_i \\) and its local shard of \\( K_j \\), \\( V_j \\).
3. Each GPU sends its shard of \\( K \\), \\( V \\) to the next GPU in the ring.
4. Each GPU receives a different shard of \\( K \\), \\( V \\) from the previous GPU in the ring.
5. Each GPU computes additional partial attention matrices \\( A_{i,j+1} \\), \\( A_{i,j+2} \\), etc. using the received \\( K \\), \\( V \\) shards.
6. Each GPU repeats this process until all shards of \\( K \\), \\( V \\) have been received and all partial attention matrices \\( A_{i,*} \\) have been computed.
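The snippet below simulates this ring pass in a single process so the data flow is easy to follow: Python lists stand in for the GPUs, the `(rank + step) % W` index plays the role of the send/receive in steps 3-4, and the partial results are merged with a streaming (online) softmax, which is one way the partial attention matrices can be combined without ever materialising the full one. This is a toy sketch for intuition, not the distributed kernel used in practice:

```python
import torch

def ring_attention_simulation(q_shards, k_shards, v_shards):
    """Single-process simulation of the RingAttention communication pattern.

    Each list has W entries of shape (n / W, d), one per simulated GPU.
    Returns per-rank outputs that together equal full attention over the sequence.
    """
    W = len(q_shards)
    outputs = []
    for rank in range(W):
        q = q_shards[rank]                              # this rank's Q shard (kept local)
        m = torch.full((q.shape[0], 1), float("-inf"))  # running row-wise max of the scores
        l = torch.zeros(q.shape[0], 1)                  # running softmax denominator
        acc = torch.zeros_like(q)                       # running weighted sum of V
        for step in range(W):
            # The ring delivers the K/V shard originally owned by rank (rank + step) % W.
            j = (rank + step) % W
            scores = q @ k_shards[j].T                  # partial attention scores A_{rank, j}
            m_new = torch.maximum(m, scores.max(dim=-1, keepdim=True).values)
            scale = torch.exp(m - m_new)                # rescale previous partial results
            p = torch.exp(scores - m_new)
            l = l * scale + p.sum(dim=-1, keepdim=True)
            acc = acc * scale + p @ v_shards[j]
            m = m_new
        outputs.append(acc / l)                         # final softmax normalisation
    return outputs

# Check against ordinary single-device attention.
torch.manual_seed(0)
W, n, d = 4, 16, 8
Q, K, V = torch.randn(n, d), torch.randn(n, d), torch.randn(n, d)
ring_out = torch.cat(ring_attention_simulation(list(Q.chunk(W)), list(K.chunk(W)), list(V.chunk(W))))
full_out = torch.softmax(Q @ K.T, dim=-1) @ V
print(torch.allclose(ring_out, full_out, atol=1e-5))  # True
```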
<figure class="image text-center">
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/accelerate-nd-parallel/cp.png" alt="Diagram for Context Parallel">
Below are some additional tips you may find useful when working in distributed settings:
cpu_ram_efficient_loading: True
activation_checkpointing: True
```
Note that gradient checkpointing typically increases training time by ~20-30% due to activation recomputation, but can reduce activation memory by 60-80%, making it particularly valuable when training very large models or using long sequence lengths.
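Besides the `activation_checkpointing` flag in the Accelerate config above, you can also enable it directly on a transformers model in your training script; the model checkpoint below is just an example:

```python
from transformers import AutoModelForCausalLM

# Any causal LM works here; the checkpoint name is only an example.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")

# Recomputes activations during the backward pass, trading extra compute
# for a large reduction in activation memory.
model.gradient_checkpointing_enable()
```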