
Commit e2a39e4 (1 parent: d5edea4)

hiworldwzj and wangzaijun authored
better cuda graph mla decode kernel. (#694)
Co-authored-by: wangzaijun <[email protected]>

23 files changed: +448 -430 lines
@@ -1 +1 @@
-{"256": {"1": {"BLOCK_SEQ": 32, "BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 8, "stage2_num_stages": 1}, "8": {"BLOCK_SEQ": 32, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 1}, "16": {"BLOCK_SEQ": 64, "BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 8, "stage2_num_stages": 1}, "32": {"BLOCK_SEQ": 128, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 2, "stage2_num_stages": 1}, "64": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 1}, "128": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 1, "stage2_num_stages": 1}, "256": {"BLOCK_SEQ": 256, "BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 2, "stage1_num_stages": 3, "stage2_num_warps": 2, "stage2_num_stages": 1}}, "512": {"1": {"BLOCK_SEQ": 32, "BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 8, "stage2_num_stages": 1}, "8": {"BLOCK_SEQ": 64, "BLOCK_N": 32, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 8, "stage2_num_stages": 1}, "16": {"BLOCK_SEQ": 128, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}, "32": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 8, "stage2_num_stages": 1}, "64": {"BLOCK_SEQ": 128, "BLOCK_N": 32, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 1}, "128": {"BLOCK_SEQ": 256, "BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 2, "stage1_num_stages": 3, "stage2_num_warps": 1, "stage2_num_stages": 1}, "256": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 1}}, "1024": {"1": {"BLOCK_SEQ": 32, "BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 8, "stage2_num_stages": 1}, "8": {"BLOCK_SEQ": 128, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 8, "stage2_num_stages": 1}, "16": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 8, "stage2_num_stages": 1}, "32": {"BLOCK_SEQ": 128, "BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 1}, "64": {"BLOCK_SEQ": 256, "BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 2, "stage1_num_stages": 3, "stage2_num_warps": 2, "stage2_num_stages": 1}, "128": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 1, "stage2_num_stages": 2}, "256": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 1, "stage2_num_stages": 1}}, "2048": {"1": {"BLOCK_SEQ": 64, "BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 8, "stage2_num_stages": 1}, "8": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 8, "stage2_num_stages": 1}, "16": {"BLOCK_SEQ": 128, "BLOCK_N": 32, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 8, "stage2_num_stages": 1}, "32": {"BLOCK_SEQ": 256, "BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 2, "stage1_num_stages": 3, "stage2_num_warps": 8, "stage2_num_stages": 1}, "64": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 2, "stage2_num_stages": 1}, "128": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 1, "stage2_num_stages": 1}, "256": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 1}}, "4096": {"1": {"BLOCK_SEQ": 128, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 8, "stage2_num_stages": 1}, "8": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 8, "stage2_num_stages": 1}, "16": {"BLOCK_SEQ": 256, "BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 2, "stage1_num_stages": 3, "stage2_num_warps": 1, "stage2_num_stages": 1}, "32": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}, "64": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 1, "stage2_num_stages": 1}, "128": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 1, "stage2_num_stages": 1}}, "8192": {"1": {"BLOCK_SEQ": 128, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 8, "stage2_num_stages": 1}, "8": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 5}, "16": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 5}, "32": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 64, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 5}, "64": {"BLOCK_SEQ": 256, "BLOCK_N": 32, "BLOCK_Q_HEAD": 32, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 5}}}
+{"256": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 3}, "8": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 3}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 3}, "32": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 3}, "64": {"BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 1, "stage2_num_stages": 1}, "128": {"BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 2, "stage2_num_stages": 1}, "256": {"BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 1, "stage2_num_stages": 1}}, "512": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 3}, "8": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 3}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "32": {"BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 3, "stage2_num_warps": 1, "stage2_num_stages": 3}, "64": {"BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 1}, "128": {"BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 5, "stage2_num_warps": 1, "stage2_num_stages": 1}, "256": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 1}}, "1024": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 3}, "8": {"BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 6, "stage2_num_warps": 4, "stage2_num_stages": 3}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 3}, "32": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 3}, "64": {"BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 1}, "128": {"BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 5, "stage2_num_warps": 2, "stage2_num_stages": 1}, "256": {"BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 6, "stage2_num_warps": 4, "stage2_num_stages": 1}}, "2048": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 3}, "8": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 3}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 3}, "32": {"BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 3}, "64": {"BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 1}, "128": {"BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 5, "stage2_num_warps": 1, "stage2_num_stages": 1}, "256": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 1, "stage2_num_stages": 3}}, "4096": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 3}, "8": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 3, "stage2_num_warps": 4, "stage2_num_stages": 3}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 3}, "32": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 2, "stage2_num_stages": 1}, "64": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 1}, "128": {"BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 5, "stage2_num_warps": 2, "stage2_num_stages": 1}}, "8192": {"1": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 2, "stage2_num_warps": 4, "stage2_num_stages": 3}, "8": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 4, "stage2_num_stages": 3}, "16": {"BLOCK_N": 32, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 8, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 3}, "32": {"BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 4, "stage2_num_warps": 1, "stage2_num_stages": 3}, "64": {"BLOCK_N": 16, "BLOCK_Q_HEAD": 16, "stage1_num_warps": 4, "stage1_num_stages": 6, "stage2_num_warps": 1, "stage2_num_stages": 3}}}

lightllm/common/basemodel/basemodel.py

Lines changed: 0 additions & 1 deletion
@@ -18,7 +18,6 @@
 from lightllm.common.basemodel.cuda_graph import CudaGraph
 from lightllm.common.quantization import Quantcfg
 from lightllm.utils.log_utils import init_logger
-from lightllm.common.basemodel.infer_lock import g_infer_state_lock
 
 logger = init_logger(__name__)
 

lightllm/models/deepseek2/infer_struct.py

Lines changed: 1 addition & 0 deletions
@@ -16,6 +16,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
         # this bookkeeping variable only exists when the decode stage uses the ppl optimized kernels
         if not self.is_prefill:
             self.kv_starts = torch.cat([self.b_start_loc, self.b_start_loc[-1:] + self.b_seq_len[-1:]], dim=0)
+            self.total_token_num_tensor = torch.sum(self.b_seq_len)
 
         if self.enable_dp:
             rank = dist.get_rank()
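
The new total_token_num_tensor keeps the batch's total token count as a 0-dim CUDA tensor rather than a host integer, which matters for the CUDA-graph decode path this commit targets: a Python int baked into a captured graph is frozen at capture time, while a device tensor can be refreshed in place between replays. A minimal sketch of that pattern, with illustrative names rather than lightllm's API:

    import torch

    b_seq_len = torch.tensor([7, 12, 3], device="cuda")
    total = b_seq_len.sum()                    # 0-dim CUDA tensor read by captured kernels

    doubled = torch.zeros((), device="cuda")
    g = torch.cuda.CUDAGraph()
    # real code warms up on a side stream before capture; omitted for brevity
    with torch.cuda.graph(g):
        torch.mul(total, 2, out=doubled)       # captured op reads `total` from device memory

    b_seq_len.fill_(5)                         # batch state changes between steps...
    total.copy_(b_seq_len.sum())               # ...so refresh the device scalar in place
    g.replay()                                 # replay now computes doubled == 30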

lightllm/models/deepseek2/model.py

Lines changed: 4 additions & 2 deletions
@@ -1,9 +1,8 @@
-import os
-import json
 import torch
 from lightllm.models.deepseek2.layer_infer.transformer_layer_infer import Deepseek2TransformerLayerInfer
 from lightllm.models.deepseek2.layer_weights.transformer_layer_weight import Deepseek2TransformerLayerWeight
 from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo
+from lightllm.models.deepseek2.splitfuse_infer_struct import DeepSeekv2SplitFuseInferStateInfo
 from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights
 
 from lightllm.models.llama.model import LlamaTpPartModel
@@ -24,6 +23,9 @@ class Deepseek2TpPartModel(LlamaTpPartModel):
     # infer state class
     infer_state_class = Deepseek2InferStateInfo
 
+    # split fuse state class
+    splitfuse_infer_state_class = DeepSeekv2SplitFuseInferStateInfo
+
     def __init__(self, kvargs):
         super().__init__(kvargs)
         return
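
The model only registers the split-fuse state class here; a sketch of the class-attribute dispatch this presumably relies on, written as an assumption about the shared base model rather than lightllm's actual code:

    class BaseTpPartModel:
        infer_state_class = None            # normal prefill/decode state
        splitfuse_infer_state_class = None  # split-fuse (chunked prefill) state

        def build_infer_state(self, splitfuse_mode: bool):
            # subclasses override the class attributes; the base just instantiates them
            cls = self.splitfuse_infer_state_class if splitfuse_mode else self.infer_state_class
            return cls()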
lightllm/models/deepseek2/splitfuse_infer_struct.py

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+import torch
+import numpy as np
+from lightllm.common.basemodel import SplitFuseInferStateInfo
+from .infer_struct import Deepseek2InferStateInfo
+
+
+class DeepSeekv2SplitFuseInferStateInfo(SplitFuseInferStateInfo):
+
+    inner_infer_state_class = Deepseek2InferStateInfo
+
+    def __init__(self):
+        super().__init__()
+        self.position_cos = None
+        self.position_sin = None
+
+    def init_some_extra_state(self, model, input_ids: torch.Tensor):
+        position_ids = []
+        if self.decode_req_num != 0:
+            position_ids.append((self.decode_b_seq_len - 1).cpu().numpy())
+        if self.prefill_req_num != 0:
+            b_seq_len_numpy = self.prefill_b_seq_len.cpu().numpy()
+            b_ready_cache_len_numpy = self.prefill_b_split_ready_cache_len.cpu().numpy()
+            position_ids.extend(
+                [np.arange(b_ready_cache_len_numpy[i], b_seq_len_numpy[i]) for i in range(len(b_seq_len_numpy))]
+            )
+
+        position_ids = torch.from_numpy(np.concatenate(position_ids, axis=0)).cuda().view(-1)
+        self.position_cos = torch.index_select(model._cos_cached, 0, position_ids).view(position_ids.shape[0], -1)
+        self.position_sin = torch.index_select(model._sin_cached, 0, position_ids).view(position_ids.shape[0], -1)
+        return
+
+    def create_inner_decode_infer_status(self):
+        infer_state = super().create_inner_decode_infer_status()
+        infer_state.total_token_num_tensor = torch.sum(infer_state.b_seq_len)
+        return
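
init_some_extra_state builds one flat position-id vector for a mixed batch: each decode request contributes a single position (seq_len - 1) and each split-prefill request contributes the range [ready_cache_len, seq_len). A toy check of that layout, with made-up lengths:

    import numpy as np

    decode_b_seq_len = np.array([9, 4])      # two decode requests
    prefill_b_seq_len = np.array([10])       # one split-prefill request
    prefill_ready_cache_len = np.array([6])  # 6 of its tokens are already cached

    position_ids = [decode_b_seq_len - 1]    # -> [8, 3]
    position_ids += [
        np.arange(prefill_ready_cache_len[i], prefill_b_seq_len[i])
        for i in range(len(prefill_b_seq_len))
    ]                                        # -> [6, 7, 8, 9]
    print(np.concatenate(position_ids))      # [8 3 6 7 8 9]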

lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding.py

Lines changed: 43 additions & 9 deletions
@@ -4,6 +4,7 @@
 from typing import List
 from lightllm.utils.log_utils import init_logger
 from .gqa_flash_decoding_config import MlaDecodeAttentionKernelConfig
+from lightllm.utils.device_utils import get_device_sm_count
 
 logger = init_logger(__name__)
 
@@ -43,36 +44,69 @@ def gqa_token_decode_attention_flash_decoding(
         out_dtype=torch.bfloat16,
     )
 
-    BLOCK_SEQ = run_config["BLOCK_SEQ"]
-
     from .gqa_flash_decoding_stage1 import flash_decode_stage1
     from .gqa_flash_decoding_stage2 import flash_decode_stage2
 
     o_tensor = alloc_tensor_func(q_nope.shape, q_nope.dtype, q_nope.device) if out is None else out
 
-    mid_o = alloc_tensor_func(
-        [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, kv_lora_rank], dtype=torch.float32, device="cuda"
+    mid_o_block_seq = torch.empty([1], dtype=torch.int64, device="cuda")
+    mid_o_batch_start_index = alloc_tensor_func(
+        [
+            batch_size,
+        ],
+        dtype=torch.int64,
+        device="cuda",
     )
-    mid_o_logexpsum = alloc_tensor_func(
-        [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float32, device="cuda"
+
+    mid_o = torch.empty([q_head_num, 0, kv_lora_rank], dtype=torch.float32, device="cuda")
+    mid_o_logexpsum = torch.empty([q_head_num, 0], dtype=torch.float32, device="cuda")
+
+    vsm_count = flash_decode_stage1(
+        infer_state.total_token_num_tensor,
+        mid_o_block_seq,
+        mid_o_batch_start_index,
+        q_nope.view(calcu_shape1),
+        q_rope.view(calcu_shape2),
+        kv_nope,
+        kv_rope,
+        infer_state.req_manager.req_to_token_indexs,
+        infer_state.b_req_idx,
+        infer_state.b_seq_len,
+        mid_o,
+        mid_o_logexpsum,
+        softmax_scale,
+        get_sm_count=True,
+        **run_config
     )
 
+    mid_o = torch.empty([q_head_num, vsm_count * 4 + batch_size, kv_lora_rank], dtype=torch.float32, device="cuda")
+    mid_o_logexpsum = torch.empty([q_head_num, vsm_count * 4 + batch_size], dtype=torch.float32, device="cuda")
+
     flash_decode_stage1(
+        infer_state.total_token_num_tensor,
+        mid_o_block_seq,
+        mid_o_batch_start_index,
         q_nope.view(calcu_shape1),
         q_rope.view(calcu_shape2),
         kv_nope,
         kv_rope,
         infer_state.req_manager.req_to_token_indexs,
         infer_state.b_req_idx,
         infer_state.b_seq_len,
-        infer_state.max_len_in_batch,
         mid_o,
         mid_o_logexpsum,
-        BLOCK_SEQ,
         softmax_scale,
+        get_sm_count=False,
         **run_config
     )
+
     flash_decode_stage2(
-        mid_o, mid_o_logexpsum, infer_state.b_seq_len, o_tensor.view(calcu_shape1), BLOCK_SEQ, **run_config
+        mid_o_block_seq,
+        mid_o_batch_start_index,
+        mid_o,
+        mid_o_logexpsum,
+        infer_state.b_seq_len,
+        o_tensor.view(calcu_shape1),
+        **run_config
    )
     return o_tensor
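
The rewritten flow calls flash_decode_stage1 twice. The first call, made with get_sm_count=True and zero-row mid buffers, only reports how many SMs (vsm_count) the persistent kernel will occupy; the host then allocates mid_o and mid_o_logexpsum with vsm_count * 4 + batch_size rows and launches for real. The per-batch split (mid_o_block_seq, mid_o_batch_start_index) is computed on device from total_token_num_tensor, so buffer shapes depend only on batch size and SM count, not on max_len_in_batch, which is what makes the decode path capturable by CUDA graphs. A hedged sketch of the sizing arithmetic; the real split happens inside the Triton kernel, and the rounding rule below is a guess:

    import math

    def plan_decode_split(total_token_num: int, batch_size: int, vsm_count: int):
        # aim for roughly vsm_count * 4 chunks across the whole batch (assumed target)
        block_seq = max(16, math.ceil(total_token_num / (vsm_count * 4)))
        block_seq = 2 ** math.ceil(math.log2(block_seq))  # hypothetical power-of-two rounding
        # worst case each request adds one partial chunk, hence the extra batch_size rows
        max_chunks = vsm_count * 4 + batch_size
        return block_seq, max_chunks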

lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_config.py

Lines changed: 0 additions & 1 deletion
@@ -38,7 +38,6 @@ def try_to_get_best_config(
             return config
         else:
             config = {
-                "BLOCK_SEQ": 64,
                 "BLOCK_N": 16,
                 "BLOCK_Q_HEAD": 16,
                 "stage1_num_warps": 4,
