Skip to content

Commit 8e4573e

Browse files
authored
在scaled_dot_product_attention函数中,加入3D的输入和输出 (#7353)
modified: docs/api/paddle/nn/functional/scaled_dot_product_attention_cn.rst
1 parent e29abf2 commit 8e4573e

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

docs/api/paddle/nn/functional/scaled_dot_product_attention_cn.rst

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@ scaled_dot_product_attention
1717
参数
1818
::::::::::
1919

20-
- **query** (Tensor) - 注意力模块中的查询张量。具有以下形状的四维张量:[batch_size, seq_len, num_heads, head_dim]。数据类型可以是 float61 或 bfloat16。
21-
- **key** (Tensor) - 注意力模块中的关键张量。具有以下形状的四维张量:[batch_size, seq_len, num_heads, head_dim]。数据类型可以是 float61 或 bfloat16。
22-
- **value** (Tensor) - 注意力模块中的值张量。具有以下形状的四维张量: [batch_size, seq_len, num_heads, head_dim]。数据类型可以是 float61 或 bfloat16。
20+
- **query** (Tensor) - 注意力模块中的查询张量。具有以下形状的四维张量:[batch_size, seq_len, num_heads, head_dim],或者三维张量:[seq_len, num_heads, head_dim]。数据类型可以是 float61 或 bfloat16。
21+
- **key** (Tensor) - 注意力模块中的关键张量。具有以下形状的四维张量:[batch_size, seq_len, num_heads, head_dim],或者三维张量:[seq_len, num_heads, head_dim]。数据类型可以是 float61 或 bfloat16。
22+
- **value** (Tensor) - 注意力模块中的值张量。具有以下形状的四维张量: [batch_size, seq_len, num_heads, head_dim],或者三维张量:[seq_len, num_heads, head_dim]。数据类型可以是 float61 或 bfloat16。
2323
- **attn_mask** (Tensor, 可选) - 与添加到注意力分数的 ``query``、 ``key``、 ``value`` 类型相同的浮点掩码, 默认值为空。
2424
- **dropout_p** (float) - ``dropout`` 的比例, 默认值为 0.00 即不进行正则化。
2525
- **is_causal** (bool) - 是否启用因果关系, 默认值为 False 即不启用。
@@ -30,7 +30,7 @@ scaled_dot_product_attention
3030
返回
3131
::::::::::
3232

33-
- ``out`` (Tensor): 形状为 ``[batch_size, seq_len, num_heads, head_dim]`` 的 4 维张量。数据类型可以是 float16 或 bfloat16。
33+
- ``out`` (Tensor): 形状为 ``[batch_size, seq_len, num_heads, head_dim]`` 的 4 维张量或者形状为 ``[seq_len, num_heads, head_dim]`` 的 3 维张量。数据类型可以是 float16 或 bfloat16。
3434

3535

3636
代码示例

0 commit comments

Comments
 (0)