Skip to content

Commit c5bb0d0

Browse files
committed
fix
1 parent dfabc50 commit c5bb0d0

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

lightllm/models/llama/triton_kernel/ppl_int8kv_flash_decoding_diverse_stage1.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,12 @@ def flash_decode_stage1(
161161
max_batch_group_size: int,
162162
):
163163
"""
164-
该kernel是为多样性生成定制的gqa算子,其中
164+
该kernel是为多样性生成定制的gqa算子,其中 b_mark_shared_group 是一个shape 为 (batch_size,)的tensor,
165+
其内容标记那些请求是共享前缀的请求组。举列说明:
166+
b_shared_seq_len : [10, 10, 10, 11, 11, 11, 11]
167+
b_mark_shared_group: [0, 0, 3, 0, 0, 0, 4]
168+
b_mark_shared_group 中每一个不为0的位置都代表其与前面多少个请求形成一个共享前缀组。属于
169+
同一个共享前缀组的请求, 其在对应的 b_shared_seq_len 中的内容必然相同。
165170
"""
166171
BLOCK_SEQ = block_seq
167172
BLOCK_N = 16

0 commit comments

Comments
 (0)