@@ -102,65 +102,3 @@ def gen_prefill_params(input_token_num: int, b_ready_cache_len: torch.Tensor, b_
102102 )
103103 b_kv_seq_len = b_seq_len
104104 return b_q_seq_len , b1_cu_q_seq_len , b_kv_seq_len , b1_cu_kv_seq_len , position_ids
105-
106-
107- @triton .jit
108- def fill_req_to_token_indexes_kernel (
109- req_to_token_indexs_ptr , # [num_req, max_len]
110- b_req_idx_ptr , # [B]
111- b_seq_len_ptr , # [B]
112- b_ready_cache_len_ptr , # [B]
113- b_start_loc_ptr , # [B]
114- alloc_mem_index_ptr , # [total_new_tokens]
115- req_to_token_indexs_stride0 ,
116- req_to_token_indexs_stride1 ,
117- BLOCK : tl .constexpr ,
118- ):
119- pid = tl .program_id (0 ) # batch id
120- req_idx = tl .load (b_req_idx_ptr + pid )
121- cur_seq_len = tl .load (b_seq_len_ptr + pid )
122- cur_ready_cache_len = tl .load (b_ready_cache_len_ptr + pid )
123- start_loc = tl .load (b_start_loc_ptr + pid )
124-
125- copy_len = cur_seq_len - cur_ready_cache_len
126- if copy_len <= 0 :
127- return
128-
129- # 一次 BLOCK 个线程
130- offs = tl .arange (0 , BLOCK )
131- for base in range (0 , copy_len , BLOCK ):
132- idx = base + offs
133- mask = idx < copy_len
134- vals = tl .load (alloc_mem_index_ptr + start_loc + idx , mask = mask , other = 0 )
135-
136- out_ptrs = (
137- req_to_token_indexs_ptr
138- + req_idx * req_to_token_indexs_stride0
139- + (cur_ready_cache_len + idx ) * req_to_token_indexs_stride1
140- )
141- tl .store (out_ptrs , vals , mask = mask )
142-
143-
144- def init_req_to_token_indexes_triton (
145- req_to_token_indexs : torch .Tensor , # [num_req, max_len]
146- b_req_idx : torch .Tensor , # [B]
147- b_seq_len : torch .Tensor , # [B]
148- b_ready_cache_len : torch .Tensor , # [B]
149- b_start_loc : torch .Tensor , # [B], alloc_mem_index 的 prefix sum 起点
150- alloc_mem_index : torch .Tensor , # [total_new_tokens]
151- max_q_seq_len : int ,
152- ):
153- BLOCK = 128
154- batch_size = b_seq_len .shape [0 ]
155- grid = (batch_size ,)
156- fill_req_to_token_indexes_kernel [grid ](
157- req_to_token_indexs ,
158- b_req_idx ,
159- b_seq_len ,
160- b_ready_cache_len ,
161- b_start_loc ,
162- alloc_mem_index ,
163- req_to_token_indexs .stride (0 ),
164- req_to_token_indexs .stride (1 ),
165- BLOCK = BLOCK ,
166- )
0 commit comments