Skip to content

Commit 19146ca

Browse files
committed
docs: flash_attention
1 parent 1ca2359 commit 19146ca

File tree

1 file changed

+189
-0
lines changed

1 file changed

+189
-0
lines changed
Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
# `Flash Attention`
2+
3+
`Flash Attention`,是一种高效的**自注意力机制实现**,在加速注意力计算的同时也减少了内存占用。其核心原理是将输入分块,在每个块上分别进行注意力计算,从而减少对 HBM 的读写次数以提高计算效率。
4+
5+
标准 Attention 的计算为:
6+
1. 计算 $QK^T$,得到初步的注意力分数 attention_score;
7+
2. 添加位置编码以区别不同位置的元素,并乘以缩放系数;
8+
3. 根据注意力掩码屏蔽对应位置的元素,得到 masked_attention_score;
9+
4. 最后经过 softmax 计算,并与 $V$ 相乘,得到最终的注意力分数 attention_out。
10+
11+
从中可以发现,若 SRAM 无法存储完整的计算数据,则计算过程中需要频繁访问 HBM,I/O 需求为 $O(Nd+N^2)$。此前,由于 softmax 的计算需要遍历整个序列以得到整体最大值,所以优化的方式往往是通过近似的方式来降低数据占用的空间。
12+
13+
而 FlashAttention 给出了 softmax 的分块计算方式,以此代替原公式中的 softmax 计算,过程如下:
14+
1. 提前对输入的 QKV 进行分块(分别按照 QKV 的 sequence length 方向切分);
15+
2. 增量计算分块 softmax ,并维护两个全局变量
16+
- $m$:已处理块的最大值
17+
- $\ell$:已处理块的指数和
18+
19+
以计算 attention_out 中的第 j 块为例,需遍历 $Q$ 切分后的每一块
20+
1. $Q_0$ 参与计算时,正常计算,并将计算结果直接保存在 attention_out[j] 中,并更新 $m$ 与 $\ell$;
21+
2. 此后, $Q_i$ 参与计算时,先统计当前块 masked_attention_score 的最大值 $m_i$,并计算其指数和 $\ell_i$;
22+
```
23+
P = e^{masked_attention_score - m_i}
24+
l_i = rowsum(P)
25+
```
26+
3. 更新 $m_{new}$ 为 $\max(m_i,m)$ ,并根据新的最大值计算 $\ell_{new}$,再由此计算新的 attention_out[j];
27+
```
28+
l_new = e^{m-m_new}*l+e^{m_i-m_new}
29+
attention_out = (attention_out * l * e^{m-m_new} + e^{m_i-m_new} * P * V_j) / l_new
30+
```
31+
4. 分别用 $m_{new}$ 与 $\ell_{max}$ 更新 $m$ 与 $\ell$,开启新一块的计算;
32+
5. 遍历完 $Q$ 之后,即可得到完整的 attention_out[j]
33+
34+
FlashAttention-2 则在 FlashAttention 的基础上进行了两点改进:
35+
36+
1. 在计算局部 Attention 时,先不考虑分母的指数和,而是 $\mathbf{O}_{i}^{(j)}=\mathrm{diag}(e^{m_{i}^{(j-1)}-m_{i}^{(j)}})\mathbf{O}_{i}^{(j-1)}+\tilde{\mathbf{P}}_{i}^{(j)}\mathbf{V}_{j}$
37+
2. 在最后一步再带入计算,得到正确的结果 $\mathbf{O}_{i}=\mathrm{diag}(\ell_{i}^{(T_{c})})^{-1}\mathbf{O}_{i}^{(T_{c})}$
38+
39+
同时,FlashAttention-2 将 $Q$ 移到了外循环,而 $K$ 移到了内循环。由于改进了算法,使得 warps 之间不需要相互通信,所以外循环可以放在不同的 thread block 上。
40+
41+
42+
## 接口
43+
44+
### 计算
45+
46+
```c
47+
infiniStatus_t infiniopFlashAttention(
48+
infiniopFlashAttentionDescriptor_t desc,
49+
void *workspace,
50+
size_t workspace_size,
51+
void *out,
52+
void *l,
53+
const void *q,
54+
const void *k,
55+
const void *v,
56+
const void *mask,
57+
void *stream
58+
);
59+
```
60+
61+
<div style="background-color: lightblue; padding: 1px;"> 参数: </div>
62+
63+
- `desc`:
64+
已使用 `infiniopCreateFlashAttentionDescriptor()` 初始化的算子描述符;
65+
- `workspace`:
66+
指向算子计算所需的额外工作空间;
67+
- `workspace_size`:
68+
`workspace` 的大小,单位:字节;
69+
- `out`:
70+
注意力计算结果地址。张量限制见[创建算子描述](#创建算子描述)部分。
71+
- `l`:
72+
logsumexp,用于反向传播计算的暂存结果。张量限制见[创建算子描述](#创建算子描述)部分。
73+
- `q`:
74+
查询(Query)张量数据指针。张量限制见[创建算子描述](#创建算子描述)部分。
75+
- `k`:
76+
键(Key)张量数据指针。张量限制见[创建算子描述](#创建算子描述)部分。
77+
- `v`:
78+
值(Value)张量数据指针。张量限制见[创建算子描述](#创建算子描述)部分。
79+
- `mask`:
80+
注意力掩码的数据指针,可选参数。取值为 `0` 时表示保留对应位置的元素(参与计算);取值为 `-inf` (负无穷)时表示屏蔽对应位置的元素(即跳过,不参与计算)。张量限制见[创建算子描述](#创建算子描述)部分。
81+
- `stream`:
82+
计算流/队列。
83+
84+
<div style="background-color: lightblue; padding: 1px;"> 返回值:</div>
85+
86+
- [`INFINI_STATUS_SUCCESS`], [`INFINI_STATUS_NULL_POINTER`], [`INFINI_STATUS_INSUFFICIENT_WORKSPACE`], [`INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED`], [`INFINI_STATUS_INTERNAL_ERROR`].
87+
88+
### 创建算子描述
89+
90+
```c
91+
infiniStatus_t infiniopCreateFlashAttentionDescriptor(
92+
infiniopHandle_t handle,
93+
infiniopFlashAttentionDescriptor_t *desc_ptr,
94+
infiniopTensorDescriptor_t out_desc,
95+
infiniopTensorDescriptor_t l_desc,
96+
infiniopTensorDescriptor_t q_desc,
97+
infiniopTensorDescriptor_t k_desc,
98+
infiniopTensorDescriptor_t v_desc,
99+
infiniopTensorDescriptor_t mask_desc,
100+
infiniopAttentionMaskType_t mask_type)
101+
```
102+
103+
<div style="background-color: lightblue; padding: 1px;"> 参数:</div>
104+
105+
- `handle`:
106+
`infiniopHandle_t`类型的硬件控柄。详见 [`InfiniopHandle_t`]。
107+
- `desc_ptr`:
108+
`infiniopFlashAttentionDescriptor_t` 指针,指向将被初始化的算子描述符地址。
109+
- `out_desc` - { dT | ((batch_size,) seq_len_q, num_heads_q, head_dim) | ($\ldots, 1$)}:
110+
算子计算参数 `out` 的张量描述,四维或者三维,最后一维连续。
111+
- `l_desc` - { dT | ((batch_szie,) seq_len_q, num_heads_q) | ($\ldots, 1$)};
112+
算子计算参数 `l` 的张量描述,三维或者二维。
113+
- `q_desc` - { dT | ((batch_size,) seq_len_q, num_heads_q, head_dim) | ($\ldots, 1$)}:
114+
算子计算参数 `q` 的张量描述,形状与 `out_desc` 一致,最后一维连续。
115+
- `k_desc` - { dT | ((batch_size,) seq_len_kv, num_heads_kv, head_dim) | ($\ldots, 1$)}:
116+
算子计算参数 `k` 的张量描述,形状与 `out_desc` 一致,最后一维连续。
117+
- `v_desc` - { dT | ((batch_size,) seq_len_kv, num_heads_kv, head_dim) | ($\ldots, 1$)}:
118+
算子计算参数 `v` 的张量描述,形状与 `out_desc` 一致,最后一维连续。
119+
- `mask_desc` - { dM | (seq_len_q, seq_len_kv) | (~)}:
120+
算子计算参数 `mask` 的张量描述,当 `mask_type=INFINIOP_ATTENTION_MASK_TYPE_FULL` 时,`mask` 不可为空,其余情况 `mask` 可为`nullptr`。
121+
- `mask_type` - `infiniopAttentionMaskType_t`:
122+
注意力类型参数,有三种类型可选,详细见参数限制。
123+
124+
参数限制:
125+
126+
- `dT`: `Float16`, `Float32` 或 `BFloat16`。
127+
- `dM`: `Flaot32`。
128+
- `seq_len_q` 与 `seq_len_kv` 可以不同。
129+
- `num_heads_q` 与 `num_heads_kv` 可以不同,但需满足前者是后者的整数倍(非0整数)。
130+
- 当 $N_q/N_{kv}=1$ 时,即为 MQA (multi-query attention)
131+
- 当 $N_q/N_{kv}>1$ 时,即为 GQA (grouped-query attention)
132+
- `mask_type` 的三种类型:
133+
- `INFINIOP_ATTENTION_MASK_TYPE_NONE=0`: 不使用注意力掩码,忽略 `mask` 取值;
134+
- `INFINIOP_ATTENTION_MASK_TYPE_FULL=1`: 使用完整 mask 矩阵,此时 `mask` 不可为空;
135+
- `INFINIOP_ATTENTION_MASK_TYPE_CAUSAL=2`: 使用标准因果掩码,对应以左上顶点划分的下三角场景,忽略 `mask` 取值;
136+
137+
<div style="background-color: lightblue; padding: 1px;"> 返回值:</div>
138+
139+
- [`INFINI_STATUS_SUCCESS`], [`INFINI_STATUS_BAD_PARAM`], [`INFINI_STATUS_BAD_TENSOR_SHAPE`], [`INFINI_STATUS_BAD_TENSOR_DTYPE`], [`INFINI_STATUS_BAD_TENSOR_STRIDES`], [`INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED`].
140+
141+
### 计算额外工作空间
142+
143+
```c
144+
infiniStatus_t infiniopGetFlashAttentionWorkspaceSize(
145+
infiniopFlashAttentionDescriptor_t desc,
146+
size_t *size
147+
);
148+
```
149+
150+
<div style="background-color: lightblue; padding: 1px;"> 参数:</div>
151+
152+
- `desc`:
153+
已使用 `infiniopCreateFlashAttentionDescriptor()` 初始化的算子描述符;
154+
- `size`:
155+
额外空间大小的计算结果的写入地址;
156+
157+
<div style="background-color: lightblue; padding: 1px;"> 返回值:</div>
158+
159+
- [`INFINI_STATUS_SUCCESS`], [`INFINI_STATUS_BAD_PARAM`], [`INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED`].
160+
161+
### 销毁算子描述符
162+
163+
```c
164+
infiniStatus_t infiniopDestoryFlashAttentionDescriptor(
165+
infiniopFlashAttentionDescriptor_t desc
166+
);
167+
```
168+
169+
<div style="background-color: lightblue; padding: 1px;"> 参数: </div>
170+
171+
- `desc`:
172+
输入。 待销毁的算子描述符;
173+
174+
<div style="background-color: lightblue; padding: 1px;"> 返回值: </div>
175+
176+
- [`INFINI_STATUS_SUCCESS`], [`INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED`].
177+
178+
<!-- 链接 -->
179+
[`InfiniopHandle_t`]: /infiniop/handle/README.md
180+
181+
[`INFINI_STATUS_SUCCESS`]: /common/status/README.md#INFINI_STATUS_SUCCESS
182+
[`INFINI_STATUS_BAD_PARAM`]: /common/status/README.md#INFINI_STATUS_BAD_PARAM
183+
[`INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED`]: /common/status/README.md#INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED
184+
[`INFINI_STATUS_BAD_TENSOR_SHAPE`]: /common/status/README.md#INFINI_STATUS_BAD_TENSOR_SHAPE
185+
[`INFINI_STATUS_BAD_TENSOR_DTYPE`]: /common/status/README.md#INFINI_STATUS_BAD_TENSOR_DTYPE
186+
[`INFINI_STATUS_BAD_TENSOR_STRIDES`]: /common/status/README.md#INFINI_STATUS_BAD_TENSOR_STRIDES
187+
[`INFINI_STATUS_NULL_POINTER`]:/common/status/README.md#INFINI_STATUS_NULL_POINTER
188+
[`INFINI_STATUS_INSUFFICIENT_WORKSPACE`]:/common/status/README.md#INFINI_STATUS_INSUFFICIENT_WORKSPACE
189+
[`INFINI_STATUS_INTERNAL_ERROR`]:/common/status/README.md#INFINI_STATUS_INTERNAL_ERROR

0 commit comments

Comments
 (0)