
Commit 7a470d5

Author: wangzaijun (committed)
Message: rm nouse code

1 parent 42b4756 · commit 7a470d5

File tree

1 file changed: +0 −62 lines changed


lightllm/common/basemodel/triton_kernel/gen_prefill_params.py

Lines changed: 0 additions & 62 deletions
@@ -102,65 +102,3 @@ def gen_prefill_params(input_token_num: int, b_ready_cache_len: torch.Tensor, b_
     )
     b_kv_seq_len = b_seq_len
     return b_q_seq_len, b1_cu_q_seq_len, b_kv_seq_len, b1_cu_kv_seq_len, position_ids
-
-
-@triton.jit
-def fill_req_to_token_indexes_kernel(
-    req_to_token_indexs_ptr,  # [num_req, max_len]
-    b_req_idx_ptr,  # [B]
-    b_seq_len_ptr,  # [B]
-    b_ready_cache_len_ptr,  # [B]
-    b_start_loc_ptr,  # [B]
-    alloc_mem_index_ptr,  # [total_new_tokens]
-    req_to_token_indexs_stride0,
-    req_to_token_indexs_stride1,
-    BLOCK: tl.constexpr,
-):
-    pid = tl.program_id(0)  # batch id
-    req_idx = tl.load(b_req_idx_ptr + pid)
-    cur_seq_len = tl.load(b_seq_len_ptr + pid)
-    cur_ready_cache_len = tl.load(b_ready_cache_len_ptr + pid)
-    start_loc = tl.load(b_start_loc_ptr + pid)
-
-    copy_len = cur_seq_len - cur_ready_cache_len
-    if copy_len <= 0:
-        return
-
-    # process BLOCK elements per iteration
-    offs = tl.arange(0, BLOCK)
-    for base in range(0, copy_len, BLOCK):
-        idx = base + offs
-        mask = idx < copy_len
-        vals = tl.load(alloc_mem_index_ptr + start_loc + idx, mask=mask, other=0)
-
-        out_ptrs = (
-            req_to_token_indexs_ptr
-            + req_idx * req_to_token_indexs_stride0
-            + (cur_ready_cache_len + idx) * req_to_token_indexs_stride1
-        )
-        tl.store(out_ptrs, vals, mask=mask)
-
-
-def init_req_to_token_indexes_triton(
-    req_to_token_indexs: torch.Tensor,  # [num_req, max_len]
-    b_req_idx: torch.Tensor,  # [B]
-    b_seq_len: torch.Tensor,  # [B]
-    b_ready_cache_len: torch.Tensor,  # [B]
-    b_start_loc: torch.Tensor,  # [B], prefix-sum start offsets into alloc_mem_index
-    alloc_mem_index: torch.Tensor,  # [total_new_tokens]
-    max_q_seq_len: int,
-):
-    BLOCK = 128
-    batch_size = b_seq_len.shape[0]
-    grid = (batch_size,)
-    fill_req_to_token_indexes_kernel[grid](
-        req_to_token_indexs,
-        b_req_idx,
-        b_seq_len,
-        b_ready_cache_len,
-        b_start_loc,
-        alloc_mem_index,
-        req_to_token_indexs.stride(0),
-        req_to_token_indexs.stride(1),
-        BLOCK=BLOCK,
-    )
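
For context, the removed helper copied each request's newly allocated KV-cache slot indices into the req_to_token_indexs table, one program per batch element, tiling the copy in chunks of BLOCK = 128 indices. Below is a minimal pure-PyTorch sketch of that behavior, assuming b_start_loc holds prefix-sum start offsets into alloc_mem_index (as the removed comments state); the function name init_req_to_token_indexes_torch and the loop body are illustrative, not code from the repository.

import torch


def init_req_to_token_indexes_torch(
    req_to_token_indexs: torch.Tensor,   # [num_req, max_len]
    b_req_idx: torch.Tensor,             # [B]
    b_seq_len: torch.Tensor,             # [B]
    b_ready_cache_len: torch.Tensor,     # [B]
    b_start_loc: torch.Tensor,           # [B], assumed prefix-sum offsets into alloc_mem_index
    alloc_mem_index: torch.Tensor,       # [total_new_tokens]
):
    # Illustrative equivalent of the removed Triton helper: for each request,
    # write its newly allocated token indices into the request-to-token table,
    # starting right after the portion already present in the cache.
    for i in range(b_seq_len.shape[0]):
        req_idx = int(b_req_idx[i])
        seq_len = int(b_seq_len[i])
        ready_len = int(b_ready_cache_len[i])
        start = int(b_start_loc[i])
        copy_len = seq_len - ready_len
        if copy_len <= 0:
            continue
        req_to_token_indexs[req_idx, ready_len:seq_len] = alloc_mem_index[start:start + copy_len]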
