Commit fed664e

Fix (#3020)
1 parent aaad4c0 commit fed664e

1 file changed

accelerate-nd-parallel.md

Lines changed: 26 additions & 24 deletions
@@ -157,41 +157,43 @@ Since the attention operation in transformers scales quadratically with context
 
 With context parallelism (CP), we can shard the inputs across the sequence dimension, resulting in each device only processing a chunk of the full context and computing a smaller portion of the full, prohibitively large, attention matrix. To see how this works, recall that the attention computation is described by the equation:
 
-$$\text{Attention}(Q, K, V) = \text{softmax}(QK^T)V$$
+\\( \text{Attention}(Q, K, V) = \text{softmax}(QK^T)V \\)
 
-Where $Q$, $K$, and $V$ are the query, key, and value matrices respectively. Each query vector (row, or input embedding) of $Q$ must compute the attention scores against *every* key vector of $K$ in the entire sequence to correctly apply the softmax normalisation. These attention scores are then weighted with *all* value vectors in $V$.
+Where \\( Q \\), \\( K \\), and \\( V \\) are the query, key, and value matrices respectively. Each query vector (row, or input embedding) of \\( Q \\) must compute the attention scores against *every* key vector of \\( K \\) in the entire sequence to correctly apply the softmax normalisation. These attention scores are then weighted with *all* value vectors in \\( V \\).
 
-The crucial detail here lies in the fact that each row in $Q$ can compute its attention score independently of one another, but each query vector still requires the full $K$ and $V$ matrices. In other words, given an input with sequence length $n$, we can expand our above attention equation as:
+The crucial detail here lies in the fact that each row in \\( Q \\) can compute its attention score independently of one another, but each query vector still requires the full \\( K \\) and \\( V \\) matrices. In other words, given an input with sequence length $n$, we can expand our above attention equation as:
 
-$$\begin{align}
+$$
+\begin{align}
 \text{Attention}(Q, K, V)_1 &= \text{softmax}(Q_1 K^T) V \\
 \text{Attention}(Q, K, V)_2 &= \text{softmax}(Q_2 K^T) V \\
 &\vdots \\
 \text{Attention}(Q, K, V)_n &= \text{softmax}(Q_n K^T) V
-\end{align}$$
+\end{align}
+$$
 
-where we denote each row of the query matrix as $Q_1, Q_2, ..., Q_n$. This can be generalized as:
+where we denote each row of the query matrix as \\( Q_1, Q_2, ..., Q_n \\). This can be generalized as:
 
-$$\text{Attention}(Q, K, V)_i = \text{softmax}(Q_i K^T) V \quad \forall i \in \{1, 2, ..., n\}$$
+\\( \text{Attention}(Q, K, V)_i = \text{softmax}(Q_i K^T) V \quad \forall i \in \{1, 2, ..., n\} \\)
 
-When we shard the inputs across devices, the resulting $Q$, $K$, and $V$ matrices (computed from these input shards) are also automatically sharded along the sequence dimension - each GPU computes queries, keys, and values only for its portion of the sequence. For example, with a world size of $W$ GPUs and sequence length $n$:
+When we shard the inputs across devices, the resulting \\( Q \\), \\( K \\), and \\( V \\) matrices (computed from these input shards) are also automatically sharded along the sequence dimension - each GPU computes queries, keys, and values only for its portion of the sequence. For example, with a world size of \\( W \\) GPUs and sequence length \\( n \\):
 
-- GPU 0 computes $Q_{1:n/W}$, $K_{1:n/W}$, $V_{1:n/W}$
-- GPU 1 computes $Q_{n/W+1:2n/W}$, $K_{n/W+1:2n/W}$, $V_{n/W+1:2n/W}$
+- GPU 0 computes \\( Q_{1:n/W} \\), \\( K_{1:n/W} \\), \\( V_{1:n/W} \\)
+- GPU 1 computes \\( Q_{n/W+1:2n/W} \\), \\( K_{n/W+1:2n/W} \\), \\( V_{n/W+1:2n/W} \\)
 - ...
-- GPU $(W-1)$ computes $Q_{(W-1)n/W+1:n}$, $K_{(W-1)n/W+1:n}$, $V_{(W-1)n/W+1:n}$
-
-How do we ensure the attention is computed correctly? As established above, each device only needs its own shard of $Q$, but requires the full $K$ and $V$ matrices to compute the attention correctly. We can achieve this by using a technique called [RingAttention](https://openreview.net/forum?id=WsRHpHH4s0), which works as follows:
-1. Initially, each GPU holds its shard of $Q$, $K$, $V$ (e.g., GPU 0 holds $Q_{1:n/W}$, $K_{1:n/W}$,
-$V_{1:n/W}$).
-2. Each GPU then computes a partial attention matrix $A_{i,j}$ for its shard of $Q_i$ and its local
-shard of $K_j$, $V_j$.
-3. Each GPU sends its shard of $K$, $V$ to the next GPU in the ring.
-4. Each GPU receives a different shard of $K$, $V$ from the previous GPU in the ring.
-5. Each GPU computes additional partial attention matrices $A_{i,j+1}$, $A_{i,j+2}$, etc. using
-the received $K$, $V$ shards.
-6. Each GPU repeats this process until all shards of $K$, $V$ have been received and all partial
-attention matrices $A_{i,*}$ have been computed.
+- GPU \\( (W-1) \\) computes \\( Q_{(W-1)n/W+1:n} \\), \\( K_{(W-1)n/W+1:n} \\), \\( V_{(W-1)n/W+1:n} \\)
+
+How do we ensure the attention is computed correctly? As established above, each device only needs its own shard of \\( Q \\), but requires the full \\( K \\) and \\( V \\) matrices to compute the attention correctly. We can achieve this by using a technique called [RingAttention](https://openreview.net/forum?id=WsRHpHH4s0), which works as follows:
+1. Initially, each GPU holds its shard of \\( Q \\), \\( K \\), \\( V \\) (e.g., GPU 0 holds \\( Q_{1:n/W} \\), \\( K_{1:n/W} \\),
+\\( V_{1:n/W} \\)).
+2. Each GPU then computes a partial attention matrix \\( A_{i,j} \\) for its shard of \\( Q_i \\) and its local
+shard of \\( K_j \\), \\( V_j \\).
+3. Each GPU sends its shard of \\( K \\), \\( V \\) to the next GPU in the ring.
+4. Each GPU receives a different shard of \\( K \\), \\( V \\) from the previous GPU in the ring.
+5. Each GPU computes additional partial attention matrices \\( A_{i,j+1} \\), \\( A_{i,j+2} \\), etc. using
+the received \\( K \\), \\( V \\) shards.
+6. Each GPU repeats this process until all shards of \\( K \\), \\( V \\) have been received and all partial
+attention matrices \\( A_{i,*} \\) have been computed.
 
 <figure class="image text-center">
 <img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/blog/accelerate-nd-parallel/cp.png" alt="Diagram for Context Parallel">
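The row-wise decomposition derived in the hunk above is easy to sanity-check numerically: each block of query rows needs only its own slice of Q, but the full K and V. Below is a minimal PyTorch sketch of that check; the tensor shapes and names are illustrative and not taken from the post or from this commit, and the 1/sqrt(d) scaling is omitted to match the post's formula.

```python
import torch

# Toy sizes; purely illustrative.
n, d, W = 16, 8, 4          # sequence length, head dim, number of GPUs (world size)
q, k, v = (torch.randn(n, d) for _ in range(3))

# Full attention, matching the post's formula softmax(QK^T)V.
full = torch.softmax(q @ k.T, dim=-1) @ v

# Row-block decomposition: each query shard uses only its own rows of Q
# but the *full* K and V; concatenating the per-shard results matches exactly.
sharded = torch.cat(
    [torch.softmax(q_i @ k.T, dim=-1) @ v for q_i in q.chunk(W, dim=0)],
    dim=0,
)
assert torch.allclose(full, sharded, atol=1e-6)
```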
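The six RingAttention steps listed in the hunk can likewise be simulated on a single process: each "GPU" keeps its query shard fixed, receives K/V shards one hop at a time, and merges the partial attention matrices with a running (online) softmax. The helper below is a hypothetical single-process sketch under those assumptions, not the RingAttention paper's or Accelerate's implementation, and again omits the 1/sqrt(d) scaling to match the post.

```python
import torch

def ring_attention_simulated(q, k, v, world_size):
    """Single-process simulation of the ring schedule for one attention head.

    q, k, v: [seq_len, head_dim]. Illustrative only; a real implementation
    overlaps these steps with K/V communication between GPUs.
    """
    q_shards = q.chunk(world_size, dim=0)
    k_shards = k.chunk(world_size, dim=0)
    v_shards = v.chunk(world_size, dim=0)

    outputs = []
    for rank in range(world_size):                 # each iteration plays the role of one GPU
        q_i = q_shards[rank]
        # Running statistics for a numerically stable, incremental softmax.
        running_max = torch.full((q_i.shape[0],), float("-inf"))
        numerator = torch.zeros_like(q_i)
        denominator = torch.zeros(q_i.shape[0])
        for step in range(world_size):             # steps 2-6: K/V shards arrive one hop at a time
            j = (rank + step) % world_size
            scores = q_i @ k_shards[j].T           # partial attention scores A_{i,j}
            new_max = torch.maximum(running_max, scores.max(dim=-1).values)
            correction = torch.exp(running_max - new_max)
            weights = torch.exp(scores - new_max[:, None])
            numerator = numerator * correction[:, None] + weights @ v_shards[j]
            denominator = denominator * correction + weights.sum(dim=-1)
            running_max = new_max
        outputs.append(numerator / denominator[:, None])
    return torch.cat(outputs, dim=0)

q, k, v = (torch.randn(16, 8) for _ in range(3))
reference = torch.softmax(q @ k.T, dim=-1) @ v
assert torch.allclose(ring_attention_simulated(q, k, v, world_size=4), reference, atol=1e-5)
```

The running max/sum bookkeeping is what lets each device merge partial attention matrices without ever materialising the full one; the final division recovers the exact softmax normalisation over the whole sequence.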
@@ -303,4 +305,4 @@ Below are some additional tips you may find useful when working in distributed s
 cpu_ram_efficient_loading: True
 activation_checkpointing: True
 ```
-Note that gradient checkpointing typically increases training time by ~20-30% due to activation recomputation, but can reduce activation memory by 60-80%, making it particularly valuable when training very large models or using long sequence lengths.
+Note that gradient checkpointing typically increases training time by ~20-30% due to activation recomputation, but can reduce activation memory by 60-80%, making it particularly valuable when training very large models or using long sequence lengths.
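The compute-for-memory trade-off described in that note can be illustrated with plain PyTorch's `torch.utils.checkpoint`. This is only a sketch: the toy model and sizes are invented for illustration, and this is not how Accelerate wires up the `activation_checkpointing: True` flag shown in the config above.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint

# Toy stack of transformer-ish blocks; sizes are illustrative only.
blocks = nn.ModuleList(
    [nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)) for _ in range(8)]
)

def forward(x, use_checkpointing=True):
    for block in blocks:
        if use_checkpointing:
            # Intermediate activations inside `block` are discarded after the
            # forward pass and recomputed during backward: less activation
            # memory, more compute (the ~20-30% slowdown noted above).
            x = checkpoint(block, x, use_reentrant=False)
        else:
            x = block(x)
    return x

x = torch.randn(4, 1024, requires_grad=True)
forward(x).sum().backward()
```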
