@@ -17,9 +17,9 @@ scaled_dot_product_attention
17
17
参数
18
18
::::::::::
19
19
20
- - **query ** (Tensor) - 注意力模块中的查询张量。具有以下形状的四维张量:[batch_size, seq_len, num_heads, head_dim]。数据类型可以是 float61 或 bfloat16。
21
- - **key ** (Tensor) - 注意力模块中的关键张量。具有以下形状的四维张量:[batch_size, seq_len, num_heads, head_dim]。数据类型可以是 float61 或 bfloat16。
22
- - **value ** (Tensor) - 注意力模块中的值张量。具有以下形状的四维张量: [batch_size, seq_len, num_heads, head_dim]。数据类型可以是 float61 或 bfloat16。
20
+ - **query ** (Tensor) - 注意力模块中的查询张量。具有以下形状的四维张量:[batch_size, seq_len, num_heads, head_dim],或者三维张量:[seq_len, num_heads, head_dim] 。数据类型可以是 float61 或 bfloat16。
21
+ - **key ** (Tensor) - 注意力模块中的关键张量。具有以下形状的四维张量:[batch_size, seq_len, num_heads, head_dim],或者三维张量:[seq_len, num_heads, head_dim] 。数据类型可以是 float61 或 bfloat16。
22
+ - **value ** (Tensor) - 注意力模块中的值张量。具有以下形状的四维张量: [batch_size, seq_len, num_heads, head_dim],或者三维张量:[seq_len, num_heads, head_dim] 。数据类型可以是 float61 或 bfloat16。
23
23
- **attn_mask ** (Tensor, 可选) - 与添加到注意力分数的 ``query ``、 ``key ``、 ``value `` 类型相同的浮点掩码, 默认值为空。
24
24
- **dropout_p ** (float) - ``dropout `` 的比例, 默认值为 0.00 即不进行正则化。
25
25
- **is_causal ** (bool) - 是否启用因果关系, 默认值为 False 即不启用。
@@ -30,7 +30,7 @@ scaled_dot_product_attention
30
30
返回
31
31
::::::::::
32
32
33
- - ``out `` (Tensor): 形状为 ``[batch_size, seq_len, num_heads, head_dim] `` 的 4 维张量。数据类型可以是 float16 或 bfloat16。
33
+ - ``out `` (Tensor): 形状为 ``[batch_size, seq_len, num_heads, head_dim] `` 的 4 维张量或者形状为 `` [seq_len, num_heads, head_dim] `` 的 3 维张量。数据类型可以是 float16 或 bfloat16。
34
34
35
35
36
36
代码示例
0 commit comments