Commit d941cad ("debug"), 1 parent b46e00e

2 files changed, +99 -0 lines


chapter_model_deployment/Advanced_Efficient_Techniques.md

Lines changed: 63 additions & 0 deletions

@@ -160,3 +160,66 @@

Figure :numref:`ch-deploy/memory` shows the memory hierarchy with
corresponding bandwidths. The main goal of FlashAttention is to avoid
reading and writing the large attention matrix to and from HBM, and to
perform as much of the computation as possible in SRAM.

The standard Scaled Dot-Product Attention [@attention] formula is

$$\textbf{A} = Softmax(\frac{\textbf{QK}^T}{\sqrt{d_k}})\textbf{V}$$
:eqlabel:`equ:std_attn`

Since $\sqrt{d_k}$ is just a scalar, we can set it aside and split the computation into three parts:

$$
\begin{aligned}
\textbf{S} = \textbf{QK}^T\\
\textbf{P} = Softmax(\textbf{S})\\
\textbf{O} = \textbf{PV}
\end{aligned}$$
:eqlabel:`equ:attn_sep`

The matrices **Q**, **K**, and **V** are all stored in HBM. The standard
implementation of attention follows these steps:

1. Load **Q** and **K** from HBM, compute $\textbf{S} = \textbf{QK}^T$, and write **S** back to HBM.

2. Read **S** from HBM, compute $\textbf{P} = Softmax(\textbf{S})$, and write **P** back to HBM.

3. Load **P** and **V** from HBM, compute $\textbf{O} = \textbf{PV}$, write **O** back to HBM, and return **O**.

This standard implementation repeatedly reads and writes large matrices
to and from HBM, and this intensive memory traffic slows the computation
down. Moreover, it keeps the large intermediate matrices in HBM for the
backward pass.
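
A minimal NumPy sketch of this non-fused baseline (illustrative only; the
matrix names mirror the equations above, and the $1/\sqrt{d_k}$ scaling is
folded into the first step):

```python
import numpy as np

def naive_attention(Q, K, V):
    """Reference (non-fused) attention: materializes S and P in full,
    mirroring the three round-trips through memory described above."""
    d_k = Q.shape[-1]
    S = Q @ K.T / np.sqrt(d_k)                  # step 1: S = QK^T (with scaling)
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    P /= P.sum(axis=-1, keepdims=True)          # step 2: P = Softmax(S)
    O = P @ V                                   # step 3: O = PV
    return O

# Toy sizes: sequence length 8, head dimension 4.
rng = np.random.default_rng(0)
Q, K, V = (rng.standard_normal((8, 4)) for _ in range(3))
print(naive_attention(Q, K, V).shape)           # (8, 4)
```

Note that both **S** and **P** have shape (sequence length, sequence length),
which is exactly the large intermediate state FlashAttention avoids writing out.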

To address these issues, FlashAttention divides the inputs **Q**, **K**,
and **V** into blocks, transfers the blocks from the slower HBM to the
faster SRAM, and computes the attention output block by block. The two
key strategies involved are **tiling** and **recomputation**.

**Tiling**: For a vector $x\in \mathbb{R}^D$, the Softmax can be computed
in a numerically stable way as:

$$
\begin{aligned}
m(x) = \max\limits_{i} x_i\\
l_{1}(x) = [e^{x_{1} - m(x)},\, ...\,,e^{x_{D} - m(x)}]\\
s_{1}(x) = \sum_{i} l_{1}(x)_i\\
Softmax(x) = \frac{l_{1}(x)}{s_{1}(x)}
\end{aligned}
$$
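
A small NumPy sketch of this max-shifted computation (illustrative; it
also shows why the shift matters numerically):

```python
import numpy as np

def stable_softmax(x):
    """Softmax computed via the statistics m(x), l_1(x), s_1(x) above."""
    m = x.max()                # m(x): the maximum entry
    l = np.exp(x - m)          # l_1(x): shifted exponentials
    s = l.sum()                # s_1(x): their sum
    return l / s

x = np.array([1000.0, 1001.0, 1002.0])
print(stable_softmax(x))               # finite values that sum to 1
print(np.exp(x) / np.exp(x).sum())     # the unshifted form overflows to nan
```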

Attention can be computed block by block, so a large Softmax can be decomposed
into separate parts. To elaborate, assume a vector $x \in\mathbb{R}^{2D}$:

$$
\begin{aligned}
x = [x_{1}, \,x_{2}], \quad x_{1}, \, x_{2} \in\mathbb{R}^D\\
m(x) = \max(m(x_{1}), \,m(x_{2}))\\
l(x) = [e^{m(x_{1})-m(x)}l_{1}(x_1),\, ... \, ,e^{m(x_2)-m(x)}l_{1}(x_2)]\\
s(x) = e^{m(x_{1})-m(x)}s_{1}(x_1) + e^{m(x_2)-m(x)}s_{1}(x_2)\\
Softmax(x) = \frac{l(x)}{s(x)}
\end{aligned}
$$
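
To make the rescaling concrete, here is a short NumPy check (a sketch, not
the FlashAttention kernel) that merging two blocks' statistics this way
reproduces the full Softmax:

```python
import numpy as np

def block_stats(x):
    """Per-block statistics m(x), l_1(x), s_1(x) from the tiling formulas."""
    m = x.max()
    l = np.exp(x - m)
    return m, l, l.sum()

def blockwise_softmax(x1, x2):
    """Softmax of [x1, x2] assembled only from per-block statistics."""
    m1, l1, s1 = block_stats(x1)
    m2, l2, s2 = block_stats(x2)
    m = max(m1, m2)                                        # m(x)
    l = np.concatenate([np.exp(m1 - m) * l1,
                        np.exp(m2 - m) * l2])              # l(x)
    s = np.exp(m1 - m) * s1 + np.exp(m2 - m) * s2          # s(x)
    return l / s

rng = np.random.default_rng(0)
x1, x2 = rng.standard_normal(4), rng.standard_normal(4)
x = np.concatenate([x1, x2])
full = np.exp(x - x.max()) / np.exp(x - x.max()).sum()     # reference Softmax
assert np.allclose(blockwise_softmax(x1, x2), full)
```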

Lines changed: 36 additions & 0 deletions

@@ -0,0 +1,36 @@

# Further Reading

1. A Distributed Graph-Theoretic Framework for Automatic
   Parallelization in Multi-Core Systems[^1]

2. SCOP: Scientific Control for Reliable Neural Network Pruning[^2]

3. Searching for Low-Bit Weights in Quantized Neural Networks[^3]

4. GhostNet: More Features from Cheap Operations[^4]

5. AdderNet: Do We Really Need Multiplications in Deep Learning?[^5]

6. Blockwise Parallel Decoding for Deep Autoregressive Models[^6]

7. Medusa: Simple Framework for Accelerating LLM Generation with
   Multiple Decoding Heads[^7]

8. FlashAttention-2: Faster Attention with Better Parallelism and Work
   Partitioning[^8]

[^1]: <https://proceedings.mlsys.org/paper/2021/file/a5e00132373a7031000fd987a3c9f87b-Paper.pdf>

[^2]: <https://arxiv.org/abs/2010.10732>

[^3]: <https://arxiv.org/abs/2009.08695>

[^4]: <https://arxiv.org/abs/1911.11907>

[^5]: <https://arxiv.org/abs/1912.13200>

[^6]: <https://arxiv.org/abs/1811.03115>

[^7]: <https://www.together.ai/blog/medusa>

[^8]: <https://arxiv.org/abs/2307.08691>
