
Commit 1115890

Add missed doc string for block_multihead_attention API (#60072) (#60139)
* add missed doc test=document_fix
* test=document_fix
1 parent 78c5e68 commit 1115890


python/paddle/incubate/nn/functional/block_multihead_attention.py

Lines changed: 23 additions & 1 deletion
@@ -67,13 +67,27 @@ def block_multihead_attention(
  cu_seqlens_k (Tensor): The cum sequence lengths of key. Its shape is [batchsize + 1, 1].
  block_tables (Tensor): The block tables, used to index the cache. Its shape is [batchsize, block_num_per_seq].
  pre_key_cache (Tensor): The pre caches of key. Its shape is [batchsize, num_head, pre_cache_length, head_size].
- pre_key_value (Tensor): The pre caches of value. Its shape is [batchsize, num_head, pre_cache_length, head_size].
+ pre_value_cache (Tensor): The pre caches of value. Its shape is [batchsize, num_head, pre_cache_length, head_size].
+ cache_k_quant_scales (Tensor): The quant scales of cache key. Its shape depends on the quant mode (dynamic or static). If dynamic quantization is enabled, its shape is [batchsize, num_head], otherwise its shape is [num_head].
+ cache_v_quant_scales (Tensor): The quant scales of cache value. Its shape depends on the quant mode (dynamic or static). If dynamic quantization is enabled, its shape is [batchsize, num_head], otherwise its shape is [num_head].
+ cache_k_dequant_scales (Tensor): The dequant scales of cache key. Its shape depends on the quant mode (dynamic or static). If dynamic quantization is enabled, its shape is [batchsize, num_head], otherwise its shape is [num_head].
+ cache_v_dequant_scales (Tensor): The dequant scales of cache value. Its shape depends on the quant mode (dynamic or static). If dynamic quantization is enabled, its shape is [batchsize, num_head], otherwise its shape is [num_head].
+ qkv_out_scale (Tensor): The dequant scale of qkv, which is the input of BLHA. If the dtype of qkv is `int32`, this input will be applied. Its shape is [3 * num_head * head_size], and its dtype should be `float32`.
+ qkv_bias (Tensor): The bias of qkv. Its shape is [3 * num_head * head_size].
+ out_shift (Tensor): Shift bias of fmha_out, which is the 1st return value. Its shape is [num_head * head_size].
+ out_smooth (Tensor): Smooth weight of fmha_out. Its shape is [num_head * head_size].
  rope_emb (Tensor): The RoPE embedding. Its shape is [2, batchsize, max_seq_len, 1, head_size // 2].
  mask (Tensor): The mask of qk_matmul in encoder. Its shape is [batchsize, 1, max_seq_len, max_seq_len].
  tgt_mask (Tensor): The mask of qk_matmul in decoder. Its shape is [batchsize, 1, 1, max_seq_len].
  max_seq_len (Int): The max length of the input. Default is -1.
  block_size (Int): The block_size of cache. Default is 64.
  use_neox_style (Bool): Whether neox_style RoPE is used or not. Default is False.
+ use_dynamic_cachekv_quant (Bool): Whether dynamic cache kv quantization is applied or not. Default is False.
+ quant_round_type (Int): The quant round type in cache kv quantization and fmha_out quantization. If 0 is set, the value will be rounded to nearest, ties to even. If 1 is set, the value will be rounded to nearest, ties away from zero.
+ quant_max_bound (Float32): The max bound of float type to int type.
+ quant_min_bound (Float32): The min bound of float type to int type.
+ out_scale (Float32): The quant scale of fmha_out. Default is -1, which means do not apply quantization for fmha_out.
+ compute_dtype (Str): The compute dtype, used to represent the input data type. Default is "default", which means the compute dtype is determined by the input dtype. However, if the dtype of the input is `int32`, this value should be set to the actual dtype of the model.
  Returns:
      Tensor|(output, qkv_out, cache_k_out, cache_v_out), where output is the output of the
      block_multihead_attention layer, qkv_out is inplace with input `qkv`, and cache_k_out and cache_v_out are inplace with input `cache_k` and `cache_v`.
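
The cache-kv (de)quant scale arguments documented above change shape between dynamic and static quantization. The following is a minimal, hypothetical sketch (not code from this commit) of how scales with those shapes could be derived; `cache_k`'s layout and the sizes batch_size, num_head, max_seq_len, head_size are illustrative assumptions.

# Hypothetical sketch: per-head cache-kv (de)quant scales with the documented shapes.
import paddle

batch_size, num_head, max_seq_len, head_size = 2, 8, 64, 128
quant_max_bound = 127.0  # e.g. int8 upper bound (assumption)

# Assumed cache layout [batch_size, num_head, max_seq_len, head_size].
cache_k = paddle.randn([batch_size, num_head, max_seq_len, head_size], dtype="float32")

# Dynamic quantization: one scale per (batch, head) -> shape [batch_size, num_head].
abs_max = paddle.amax(paddle.abs(cache_k), axis=[2, 3])
cache_k_quant_scales = quant_max_bound / abs_max      # applied when quantizing the cache
cache_k_dequant_scales = abs_max / quant_max_bound    # inverse, applied when dequantizing

# Static quantization would instead use one calibrated scale per head -> shape [num_head].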
@@ -229,6 +243,14 @@ def block_multihead_attention(
  ... block_tables,
  ... None, # pre_key_cache
  ... None, # pre_value_cache
+ ... None, # cache_k_quant_scales
+ ... None, # cache_v_quant_scales
+ ... None, # cache_k_dequant_scales
+ ... None, # cache_v_dequant_scales
+ ... None, # qkv_out_scale
+ ... None, # qkv_bias
+ ... None, # out_shift
+ ... None, # out_smooth
  ... None, # rotary_embs
  ... None, # attn_mask
  ... None, # tgt_mask
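
For the quant_round_type flag documented in the first hunk, the two modes correspond to the two standard tie-breaking rules. A small illustration in plain NumPy (not code from this commit):

import numpy as np

x = np.array([0.5, 1.5, 2.5, -0.5, -1.5], dtype=np.float32)

# quant_round_type = 0: round to nearest, ties to even (banker's rounding)
ties_to_even = np.round(x)                                        # [ 0.  2.  2. -0. -2.]

# quant_round_type = 1: round to nearest, ties away from zero
ties_away_from_zero = np.copysign(np.floor(np.abs(x) + 0.5), x)   # [ 1.  2.  3. -1. -2.]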
