@@ -1,6 +1,8 @@
 import os
 import torch
 import torch.multiprocessing as mp
+import triton
+import triton.language as tl
 from typing import List
 from lightllm.utils.log_utils import init_logger
 from .gqa_flash_decoding_config import MlaDecodeAttentionKernelConfig
@@ -44,27 +46,19 @@ def gqa_token_decode_attention_flash_decoding(
         out_dtype=torch.bfloat16,
     )

+    BLOCK_N = run_config["BLOCK_N"]
+
     from .gqa_flash_decoding_stage1 import flash_decode_stage1
     from .gqa_flash_decoding_stage2 import flash_decode_stage2

     o_tensor = alloc_tensor_func(q_nope.shape, q_nope.dtype, q_nope.device) if out is None else out

-    mid_o_block_seq = torch.empty([1], dtype=torch.int64, device="cuda")
-    mid_o_batch_start_index = alloc_tensor_func(
-        [
-            batch_size,
-        ],
-        dtype=torch.int64,
-        device="cuda",
-    )
-
+    fake_decode_att_block_seq = torch.empty([0], dtype=torch.int64, device="cuda")
     mid_o = torch.empty([q_head_num, 0, kv_lora_rank], dtype=torch.float32, device="cuda")
     mid_o_logexpsum = torch.empty([q_head_num, 0], dtype=torch.float32, device="cuda")

     vsm_count = flash_decode_stage1(
-        infer_state.total_token_num_tensor,
-        mid_o_block_seq,
-        mid_o_batch_start_index,
+        fake_decode_att_block_seq,
         q_nope.view(calcu_shape1),
         q_rope.view(calcu_shape2),
         kv_nope,
@@ -79,13 +73,40 @@ def gqa_token_decode_attention_flash_decoding(
         **run_config
     )

+    if not hasattr(infer_state, "decode_att_block_seq"):
+        assert batch_size <= 2048
+        decode_att_block_seq = torch.empty(
+            [
+                1,
+            ],
+            dtype=torch.int64,
+            device="cuda",
+        )
+        mid_o_batch_start_index = torch.empty(
+            [
+                batch_size,
+            ],
+            dtype=torch.int64,
+            device="cuda",
+        )
+        _fwd_kernel_calcu_index_and_block_seq[(1,)](
+            infer_state.b_seq_len,
+            decode_att_block_seq,
+            mid_o_batch_start_index,
+            vsm_count,
+            batch_size,
+            BLOCK_N=BLOCK_N,
+            num_warps=4,
+        )
+
+        infer_state.decode_att_block_seq = decode_att_block_seq
+        infer_state.mid_o_batch_start_index = mid_o_batch_start_index
+
     mid_o = torch.empty([q_head_num, vsm_count * 4 + batch_size, kv_lora_rank], dtype=torch.float32, device="cuda")
     mid_o_logexpsum = torch.empty([q_head_num, vsm_count * 4 + batch_size], dtype=torch.float32, device="cuda")

     flash_decode_stage1(
-        infer_state.total_token_num_tensor,
-        mid_o_block_seq,
-        mid_o_batch_start_index,
+        infer_state.decode_att_block_seq,
         q_nope.view(calcu_shape1),
         q_rope.view(calcu_shape2),
         kv_nope,
@@ -101,12 +122,35 @@ def gqa_token_decode_attention_flash_decoding(
101122 )
102123
103124 flash_decode_stage2 (
104- mid_o_block_seq ,
105- mid_o_batch_start_index ,
125+ infer_state . decode_att_block_seq ,
126+ infer_state . mid_o_batch_start_index ,
106127 mid_o ,
107128 mid_o_logexpsum ,
108129 infer_state .b_seq_len ,
109130 o_tensor .view (calcu_shape1 ),
110131 ** run_config
111132 )
112133 return o_tensor
134+
135+
136+ @triton .jit
137+ def _fwd_kernel_calcu_index_and_block_seq (
138+ b_seq_len_ptr ,
139+ mid_o_decode_att_block_seq_ptr ,
140+ mid_o_batch_start_index_ptr ,
141+ num_sm ,
142+ batch_size ,
143+ BLOCK_N : tl .constexpr ,
144+ ):
145+ b_seq_len = tl .load (b_seq_len_ptr + tl .arange (0 , 2048 ), mask = tl .arange (0 , 2048 ) < batch_size , other = 0 )
146+ total_token_num = tl .sum (b_seq_len )
147+
148+ block_seq = tl .cast (total_token_num / (num_sm * 4 ), dtype = tl .int32 ) + 1
149+ block_seq = tl .cdiv (block_seq , BLOCK_N ) * BLOCK_N
150+
151+ block_seq_len = tl .cdiv (b_seq_len , block_seq )
152+ cumsum_seq_len = tl .cumsum (block_seq_len )
153+ batch_start_index = cumsum_seq_len - block_seq_len
154+ tl .store (mid_o_batch_start_index_ptr + tl .arange (0 , 2048 ), batch_start_index , mask = tl .arange (0 , 2048 ) < batch_size )
155+ tl .store (mid_o_decode_att_block_seq_ptr , block_seq )
156+ return
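For readers of the patch: the new Triton kernel runs as a single program and only does integer bookkeeping. Below is a rough CPU sketch of the same computation; the helper name and the plain-PyTorch formulation are illustrative assumptions, not part of the patch. It picks a chunk length block_seq so the whole batch splits into roughly num_sm * 4 KV chunks, rounds it up to a multiple of BLOCK_N, and returns the exclusive prefix sum of per-request chunk counts as each request's start offset into mid_o.

import torch

def calcu_index_and_block_seq_ref(b_seq_len: torch.Tensor, num_sm: int, block_n: int):
    # Total decode tokens across the batch (sum of per-request KV lengths).
    total_token_num = int(b_seq_len.sum().item())
    # Aim for roughly num_sm * 4 chunks over the whole batch, then round the
    # chunk length up to a multiple of BLOCK_N so stage1 tiles evenly.
    block_seq = total_token_num // (num_sm * 4) + 1
    block_seq = (block_seq + block_n - 1) // block_n * block_n
    # Chunks contributed by each request, and where each request's chunks
    # start inside mid_o (exclusive prefix sum).
    block_seq_len = (b_seq_len + block_seq - 1) // block_seq
    batch_start_index = torch.cumsum(block_seq_len, dim=0) - block_seq_len
    return block_seq, batch_start_index

Because sum(ceil(len_i / block_seq)) <= total_token_num / block_seq + batch_size, this choice of block_seq bounds the total chunk count by roughly num_sm * 4 + batch_size, which is why mid_o and mid_o_logexpsum in the function above are allocated with vsm_count * 4 + batch_size rows.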