Figure :numref:`ch-deploy/memory` shows the memory hierarchy with the
corresponding bandwidths. The main goal of FlashAttention is to avoid
reading and writing the large attention matrix to and from HBM and to
perform as much of the computation as possible in SRAM.

The standard Scaled Dot-Product Attention [@attention] formula is

$$\textbf{A} = Softmax(\frac{\textbf{QK}^T}{\sqrt{d_k}})\textbf{V}$$
:eqlabel:`equ:std_attn`

As $d_k$ is a scalar, we can break the computation into three parts:

$$
\begin{aligned}
\textbf{S} = \textbf{QK}^T\\
\textbf{P} = Softmax(\textbf{S})\\
\textbf{O} = \textbf{PV}
\end{aligned}
$$
:eqlabel:`equ:attn_sep`

The matrices **K**, **Q**, and **V** are all stored in HBM. The standard
implementation of attention follows these steps:

1. Load **K** and **Q** from HBM, compute $\textbf{S} = \textbf{QK}^T$, and
   write **S** to HBM.

2. Read **S** from HBM, compute $\textbf{P} = Softmax(\textbf{S})$, and
   write **P** to HBM.

3. Load **P** and **V** from HBM, compute $\textbf{O} = \textbf{PV}$, and
   write **O** to HBM. Finally, return **O**.

The standard implementation therefore performs frequent I/O with HBM,
reading and writing large matrices at every step, and these intensive
memory accesses slow the computation down. Moreover, it keeps the large
intermediate matrices **S** and **P** in HBM for backward propagation.
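
For concreteness, here is a minimal NumPy sketch of these three steps
(function name and shapes are illustrative only, not any library's API);
the full $\textbf{S}$ and $\textbf{P}$ matrices it materializes are
exactly the intermediates that live in HBM.

```python
import numpy as np

def standard_attention(Q, K, V):
    """Naive three-step attention; Q, K, V have shape (N, d)."""
    d_k = Q.shape[-1]
    # Step 1: S = Q K^T (an N x N matrix; written to HBM in practice).
    S = Q @ K.T / np.sqrt(d_k)
    # Step 2: P = Softmax(S), applied row-wise (another N x N matrix in HBM).
    P = np.exp(S - S.max(axis=-1, keepdims=True))
    P /= P.sum(axis=-1, keepdims=True)
    # Step 3: O = P V, the final (N, d) output.
    return P @ V

# Example usage with random inputs.
rng = np.random.default_rng(0)
N, d = 128, 64
Q, K, V = rng.standard_normal((3, N, d))
print(standard_attention(Q, K, V).shape)  # (128, 64)
```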

To address these issues, FlashAttention splits the inputs **Q**, **K**,
and **V** into blocks. These blocks are loaded from the slower HBM into
the faster SRAM, where the attention output is computed block by block.
The two key techniques involved are **tiling** and **recomputation**.

**Tiling**: For a vector $x \in \mathbb{R}^D$, the (numerically stable)
Softmax can be computed as:

$$
\begin{aligned}
m(x) = \max\limits_{i} x_i\\
l_{1}(x) = [e^{x_{1} - m(x)},\, ...\,,e^{x_{D} - m(x)}]\\
s_{1}(x) = \sum_{i} l_{1}(x)_i\\
Softmax(x) = \frac{l_{1}(x)}{s_{1}(x)}
\end{aligned}
$$
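
A minimal NumPy sketch of this computation, with variable names matching
$m$, $l_{1}$, and $s_{1}$ above (illustrative only):

```python
import numpy as np

def safe_softmax(x):
    """Numerically stable softmax of a 1-D vector, following m, l1, s1 above."""
    m = x.max()             # m(x): maximum entry
    l1 = np.exp(x - m)      # l1(x): exponentials shifted by the maximum
    s1 = l1.sum()           # s1(x): normalization constant
    return l1 / s1

x = np.array([1.0, 2.0, 3.0])
print(safe_softmax(x))      # same result as np.exp(x) / np.exp(x).sum()
```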

Attention can be computed block by block, so a large Softmax can be
decomposed into separate parts. To elaborate, consider a vector
$x \in \mathbb{R}^{2D}$ split into two blocks:

$$
\begin{aligned}
x = [x_{1}, \,x_{2}], \quad x_{1}, \, x_{2} \in \mathbb{R}^D\\
m(x) = \max(m(x_{1}), \,m(x_{2}))\\
l(x) = [e^{m(x_{1})-m(x)}\,l_{1}(x_{1}),\; e^{m(x_{2})-m(x)}\,l_{1}(x_{2})]\\
s(x) = e^{m(x_{1})-m(x)}\,s_{1}(x_{1}) + e^{m(x_{2})-m(x)}\,s_{1}(x_{2})\\
Softmax(x) = \frac{l(x)}{s(x)}
\end{aligned}
$$
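
To make the decomposition concrete, the following NumPy sketch (the
helper name and block sizes are our own choices) checks that merging the
per-block statistics with the rescaling above reproduces the Softmax of
the full vector:

```python
import numpy as np

def block_stats(x):
    """Per-block statistics m(x), l(x), s(x) of a 1-D block."""
    m = x.max()
    l = np.exp(x - m)
    return m, l, l.sum()

rng = np.random.default_rng(0)
x1, x2 = rng.standard_normal(4), rng.standard_normal(4)
x = np.concatenate([x1, x2])

# Statistics of each block computed independently (as if one block at a time in SRAM).
m1, l1, s1 = block_stats(x1)
m2, l2, s2 = block_stats(x2)

# Merge the blocks by rescaling with the global maximum m(x).
m = max(m1, m2)
l = np.concatenate([np.exp(m1 - m) * l1, np.exp(m2 - m) * l2])
s = np.exp(m1 - m) * s1 + np.exp(m2 - m) * s2
softmax_blocked = l / s

# Reference: softmax over the full vector in one pass.
softmax_full = np.exp(x - x.max()) / np.exp(x - x.max()).sum()
print(np.allclose(softmax_blocked, softmax_full))  # True
```

This is what allows FlashAttention to process one block of **K** and
**V** at a time in SRAM while carrying only the running statistics $m$
and $s$ between blocks.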