Skip to content

Commit 581e2d0

Browse files
committed
update post
1 parent 3c1b535 commit 581e2d0

File tree

1 file changed

+26
-13
lines changed

1 file changed

+26
-13
lines changed

_posts/2025-11-23-d2l_attention-mechanisms.md

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ author: Pianfan
1010

1111
意识的聚集和专注使灵长类动物能够在复杂的视觉环境中将注意力引向感兴趣的物体,例如猎物和天敌。只关注一小部分信息的能力对进化更加有意义,使人类得以生存和成功<!-- more -->
1212

13-
### 10.1. 注意力提示
13+
## 10.1. 注意力提示
1414

1515
注意力是稀缺、有价值的资源,存在机会成本
1616

@@ -126,9 +126,19 @@ class NWKernelRegression(nn.Module):
126126

127127
## 10.3. 注意力评分函数
128128

129-
注意力汇聚输出:值的加权和,即 $f(\mathbf{q}, (\mathbf{k}_1, \mathbf{v}_1), \ldots, (\mathbf{k}_m, \mathbf{v}_m)) = \sum_{i=1}^m \alpha(\mathbf{q}, \mathbf{k}_i) \mathbf{v}_i$
129+
注意力汇聚输出:值的加权和,即
130130

131-
注意力权重:由**注意力评分函数(attention scoring function)**经 softmax 得到,即 $\alpha(\mathbf{q}, \mathbf{k}_i) = \frac{\exp(a(\mathbf{q}, \mathbf{k}_i))}{\sum_{j=1}^m \exp(a(\mathbf{q}, \mathbf{k}_j))}$,其中 $a$ 为评分函数
131+
$$
132+
f(\mathbf{q}, (\mathbf{k}_1, \mathbf{v}_1), \ldots, (\mathbf{k}_m, \mathbf{v}_m)) = \sum_{i=1}^m \alpha(\mathbf{q}, \mathbf{k}_i) \mathbf{v}_i
133+
$$
134+
135+
注意力权重:由**注意力评分函数(attention scoring function)**经 softmax 得到,即
136+
137+
$$
138+
\alpha(\mathbf{q}, \mathbf{k}_i) = \frac{\exp(a(\mathbf{q}, \mathbf{k}_i))}{\sum_{j=1}^m \exp(a(\mathbf{q}, \mathbf{k}_j))}
139+
$$
140+
141+
其中 $a$ 为评分函数
132142

133143
### 10.3.1. 掩蔽 softmax 操作(masked softmax operation)
134144

@@ -269,38 +279,38 @@ $$
269279

270280
其中参数:$\mathbf{W}_i^{(q)} \in \mathbb{R}^{p_q \times d_q}$,$\mathbf W_i^{(k)}\in\mathbb R^{p_k\times d_k}$,$\mathbf W_i^{(v)}\in\mathbb R^{p_v\times d_v}$
271281

272-
最终输出:$\mathbf{W}_o \begin{bmatrix}\mathbf{h}_1\\\vdots\\\mathbf{h}_h\end{bmatrix} \in \mathbb{R}^{p_o}$,$\mathbf{W}_o \in \mathbb{R}^{p_o \times hp_v}$
282+
最终输出:$\mathbf{W}_o \begin{bmatrix} \mathbf{h}_1 \\ \vdots \\ \mathbf{h}_h \end{bmatrix} \in \mathbb{R}^{p_o}$,$\mathbf{W}_o \in \mathbb{R}^{p_o \times hp_v}$
273283

274284
### 10.5.2. 实现
275285

276286
1. **MultiHeadAttention 类**
277287

278-
包含参数:num_heads(头数)、attention(缩放点积注意力)、W_q/W_k/W_v/W_o(线性层)
288+
包含参数:`num_heads`(头数)、`attention`(缩放点积注意力)、`W_q`/`W_k`/`W_v`/`W_o`(线性层)
279289

280290
前向传播步骤:
281291

282292
- 对查询、键、值进行线性变换
283-
- 通过 transpose_qkv 转换形状以并行计算多头
293+
- 通过 `transpose_qkv` 转换形状以并行计算多头
284294
- 应用注意力机制
285-
- 通过 transpose_output 还原形状并经 W_o 输出
295+
- 通过 `transpose_output` 还原形状并经 `W_o` 输出
286296

287297
2. **关键变换函数**
288298

289-
transpose_qkv:将输入形状从 (batch_size, 序列长度num_hiddens) 转换为 (batch_size*num_heads, 序列长度num_hiddens/num_heads),实现多头并行
299+
`transpose_qkv`:将输入形状从 `(batch_size, 序列长度, num_hiddens)` 转换为 `(batch_size*num_heads, 序列长度, num_hiddens/num_heads)`,实现多头并行
290300

291-
transpose_output:逆转 transpose_qkv 的操作,拼接多头结果
301+
`transpose_output`:逆转 `transpose_qkv` 的操作,拼接多头结果
292302

293303
3. **参数设置**
294304

295305
通常设 $p_q = p_k = p_v = p_o/h$,避免计算和参数代价激增
296306

297-
num_hiddens 为输出特征维度
307+
`num_hiddens` 为输出特征维度
298308

299309
**输入输出形状**
300310

301-
输入:queries/keys/values 为 (batch_size, 序列长度num_hiddens)valid_lens 为 (batch_size,) 或 (batch_size, 查询数)
311+
输入:`queries`/`keys`/`values``(batch_size, 序列长度, num_hiddens)``valid_lens``(batch_size,)``(batch_size, 查询数)`
302312

303-
输出:(batch_size, 查询数num_hiddens)
313+
输出:`(batch_size, 查询数, num_hiddens)`
304314

305315
## 10.6. 自注意力和位置编码
306316

@@ -310,7 +320,7 @@ $$
310320

311321
公式:$\mathbf{y}_i = f(\mathbf{x}_i, (\mathbf{x}_1, \mathbf{x}_1), \ldots, (\mathbf{x}_n, \mathbf{x}_n)) \in \mathbb{R}^d$,其中 $f$ 为注意力汇聚函数
312322

313-
输入输出形状:(批量大小序列长度隐藏维度),输入输出形状相同
323+
输入输出形状:`(批量大小, 序列长度, 隐藏维度)`,输入输出形状相同
314324

315325
PyTorch 实现示例:
316326

@@ -406,8 +416,11 @@ class PositionalEncoding(nn.Module):
406416
4. **注意力机制**
407417

408418
多头自注意力:并行执行多个缩放点积注意力
419+
409420
编码器自注意力:查询、键、值均来自前一层编码器输出
421+
410422
解码器自注意力:仅允许关注当前位置及之前位置
423+
411424
编码器-解码器注意力:查询来自解码器,键/值来自编码器输出
412425

413426
### 10.7.3. 核心实现(PyTorch)

0 commit comments

Comments
 (0)