
Commit 65fb8c6

merge latest
2 parents 965cdae + 7c1a597

File tree: 22 files changed (+290, -214 lines)


lightllm/common/basemodel/batch_objs.py

Lines changed: 4 additions & 0 deletions
@@ -1,6 +1,7 @@
 import torch
 from dataclasses import dataclass, field
 from typing import Optional
+from typing import List


 @dataclass
@@ -20,6 +21,9 @@ class ModelInput:

     # cpu variables
     mem_indexes_cpu: torch.Tensor = None
+    # Parameters used in the prefill phase. They are not consumed by the inference pass
+    # itself, but by the resource management done outside of inference.
+    b_prefill_has_output_cpu: List[bool] = None  # marks whether each prefill request produces an output

     # Dedicated variables used by certain special models and special modes to pass
     # extra input values. They are only used and take effect in those special model modes.
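As a side note on the new field, here is a minimal, self-contained sketch of how the flag list might be produced on the scheduling side; the request class and its attribute are assumptions made for illustration and are not part of this commit.

from typing import List

# Hypothetical request objects; only the attribute needed for the illustration is modeled.
class _PrefillReq:
    def __init__(self, remaining_prompt_len: int):
        self.remaining_prompt_len = remaining_prompt_len

prefill_reqs = [_PrefillReq(0), _PrefillReq(128), _PrefillReq(0)]
# A prefill request only produces an output token once its whole prompt has been processed,
# e.g. a chunked-prefill segment that still has prompt tokens left does not emit anything yet.
b_prefill_has_output_cpu: List[bool] = [r.remaining_prompt_len == 0 for r in prefill_reqs]
print(b_prefill_has_output_cpu)  # [True, False, True]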

lightllm/common/basemodel/triton_kernel/gather_token_id.py

Lines changed: 65 additions & 130 deletions
@@ -4,86 +4,41 @@
 import triton.language as tl


-@triton.jit
-def _fwd_kernel_gather_and_scatter(
-    probs_idx,
-    probs_sort,
-    req_to_next_token_ids,
-    req_to_next_token_probs,
-    sampled_index,
-    b_req_idx,
-    probs_idx_stride,
-    probs_sort_stride,
-    req_to_next_token_ids_stride,
-    req_to_next_token_probs_stride,
-):
-    cur_index = tl.program_id(0)
-    cur_req_idx = tl.load(b_req_idx + cur_index)
-    cur_sampled_index = tl.load(sampled_index + cur_index)
-    cur_token_index = tl.load(probs_idx + cur_index * probs_idx_stride + cur_sampled_index)
-    cur_token_probs = tl.load(probs_sort + cur_index * probs_sort_stride + cur_sampled_index)
-    tl.store(req_to_next_token_ids + cur_req_idx * req_to_next_token_ids_stride, cur_token_index)
-    tl.store(req_to_next_token_probs + cur_req_idx * req_to_next_token_probs_stride, tl.log(cur_token_probs))
-    return
-
-
-@torch.no_grad()
-def gather_and_scatter_token_to_cpu(
-    probs_idx: torch.Tensor,
-    probs_sort: torch.Tensor,
-    req_to_next_token_ids: torch.Tensor,
-    req_to_next_token_probs: torch.Tensor,
-    sampled_index: torch.Tensor,
-    b_req_idx: torch.Tensor,
-):
-    """
-    This function is used to gather the next_token_id(GPU tensor) and next_token_probs(GPU tensor)
-    info to the req_to_next_token_ids and req_to_next_token_probs(CPU tensor).
-    Args:
-        probs_idx: (batch_size, vocab_size)
-        probs_sort: (batch_size, vocab_size)
-        req_to_next_token_ids: (max_req_num,)
-        req_to_next_token_probs: (max_req_num,)
-        sampled_index: (batch_size,)
-        b_req_idx: (batch_size,)
-    """
-    assert probs_idx.shape == probs_sort.shape
-    assert sampled_index.shape[0] == b_req_idx.shape[0]
-    batch_size = b_req_idx.shape[0]
-    grid = (batch_size,)
-    num_warps = 1
-
-    _fwd_kernel_gather_and_scatter[grid](
-        probs_idx,
-        probs_sort,
-        req_to_next_token_ids,
-        req_to_next_token_probs,
-        sampled_index,
-        b_req_idx,
-        probs_idx.stride(0),
-        probs_sort.stride(0),
-        req_to_next_token_ids.stride(0),
-        req_to_next_token_probs.stride(0),
-        num_warps=num_warps,
-        num_stages=1,
-    )
-    return
-
-
 @triton.jit
 def _fwd_kernel_scatter(
     next_token_ids,
     req_to_next_token_ids,
     b_req_idx,
     b_mtp_index,
+    b_has_out,
     req_to_next_token_ids_stride,
     req_to_next_token_ids_stride_1,
+    num_size,
+    HAS_OUT_IS_NONE: tl.constexpr,
+    BLOCK: tl.constexpr,
 ):
-    cur_index = tl.program_id(0)
-    cur_req_idx = tl.load(b_req_idx + cur_index)
-    cur_mtp_index = tl.load(b_mtp_index + cur_index)
-    cur_next_token_id = tl.load(next_token_ids + cur_index)
-    tl.store(req_to_next_token_ids + cur_req_idx * req_to_next_token_ids_stride + cur_mtp_index, cur_next_token_id)
+    block_index = tl.program_id(0)
+    block_range = block_index * BLOCK + tl.arange(0, BLOCK)
+    block_mask = block_range < num_size
+
+    cur_req_idx = tl.load(b_req_idx + block_range, mask=block_mask)
+    cur_mtp_index = tl.load(b_mtp_index + block_range, mask=block_mask)
+    cur_next_token_id = tl.load(next_token_ids + block_range, mask=block_mask)
+
+    if not HAS_OUT_IS_NONE:
+        cur_has_out = tl.load(b_has_out + block_range, mask=block_mask, other=False)
+        tl.store(
+            req_to_next_token_ids + cur_req_idx * req_to_next_token_ids_stride + cur_mtp_index,
+            cur_next_token_id,
+            mask=cur_has_out & block_mask,
+        )
+    else:
+        tl.store(
+            req_to_next_token_ids + cur_req_idx * req_to_next_token_ids_stride + cur_mtp_index,
+            cur_next_token_id,
+            mask=block_mask,
+        )
+
     return


@@ -93,6 +48,7 @@ def scatter_token(
     req_to_next_token_ids: torch.Tensor,
     b_req_idx: torch.Tensor,
     b_mtp_index: torch.Tensor,
+    b_has_out: torch.Tensor = None,
 ):
     """
     This function is used to scatter the token_info(GPU tensor) to the req_to_token_info(CPU tensor).
@@ -104,16 +60,22 @@ def scatter_token(
     """
     assert next_token_ids.shape[0] == b_req_idx.shape[0]
     batch_size = b_req_idx.shape[0]
-    grid = (batch_size,)
+    BLOCK = 256
+
+    grid = (triton.cdiv(batch_size, BLOCK),)
     num_warps = 1

     _fwd_kernel_scatter[grid](
-        next_token_ids,
-        req_to_next_token_ids,
-        b_req_idx,
-        b_mtp_index,
-        req_to_next_token_ids.stride(0),
-        req_to_next_token_ids.stride(1),
+        next_token_ids=next_token_ids,
+        req_to_next_token_ids=req_to_next_token_ids,
+        b_req_idx=b_req_idx,
+        b_mtp_index=b_mtp_index,
+        b_has_out=b_has_out,
+        req_to_next_token_ids_stride=req_to_next_token_ids.stride(0),
+        req_to_next_token_ids_stride_1=req_to_next_token_ids.stride(1),
+        num_size=batch_size,
+        HAS_OUT_IS_NONE=b_has_out is None,
+        BLOCK=BLOCK,
         num_warps=num_warps,
         num_stages=1,
     )
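A hypothetical prefill call site for the extended scatter_token (the tensor values and variable names here are assumptions for illustration, not code from this commit): the per-request flags carried in ModelInput.b_prefill_has_output_cpu become a bool mask, so requests whose prefill step produces no output token keep their previous table entry untouched.

import torch
from lightllm.common.basemodel.triton_kernel.gather_token_id import scatter_token

# Small per-request table as in the tests below: one next-token slot per request, pinned on the CPU.
req_to_next_token_ids = torch.zeros((1000, 1), dtype=torch.int32, pin_memory=True)
next_token_ids = torch.tensor([101, 102, 103], dtype=torch.int32, device="cuda")
b_req_idx = torch.tensor([5, 6, 7], dtype=torch.int32, device="cuda")
b_mtp_index = torch.zeros(3, dtype=torch.int32, device="cuda")

# Flags taken from ModelInput.b_prefill_has_output_cpu, moved to the GPU for the kernel.
b_has_out = torch.tensor([True, False, True], dtype=torch.bool, device="cuda")
scatter_token(next_token_ids, req_to_next_token_ids, b_req_idx, b_mtp_index, b_has_out=b_has_out)
# Only rows 5 and 7 are written; row 6 keeps its old value because its flag is False.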
@@ -128,12 +90,18 @@ def _fwd_kernel_gather(
     output,
     b_req_idx,
     b_mtp_index,
+    num_size,
+    BLOCK: tl.constexpr,
 ):
-    cur_index = tl.program_id(0)
-    cur_req_idx = tl.load(b_req_idx + cur_index)
-    cur_mtp_index = tl.load(b_mtp_index + cur_index)
-    cur_next_token_id = tl.load(req_to_next_token_ids + cur_req_idx * req_to_next_token_ids_stride + cur_mtp_index)
-    tl.store(output + cur_index, cur_next_token_id)
+    block_index = tl.program_id(0)
+    block_range = block_index * BLOCK + tl.arange(0, BLOCK)
+    block_mask = block_range < num_size
+    cur_req_idx = tl.load(b_req_idx + block_range, mask=block_mask)
+    cur_mtp_index = tl.load(b_mtp_index + block_range, mask=block_mask)
+    cur_next_token_id = tl.load(
+        req_to_next_token_ids + cur_req_idx * req_to_next_token_ids_stride + cur_mtp_index, mask=block_mask
+    )
+    tl.store(output + block_range, cur_next_token_id, mask=block_mask)
     return

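Both kernels above now use the same blocked launch: instead of one Triton program per request, each program handles BLOCK rows and block_mask switches off the out-of-range lanes of the last program. A quick check of the launch arithmetic, with purely illustrative batch sizes:

import triton

BLOCK = 256
for batch_size in (1, 256, 300, 1024):
    # e.g. batch_size = 300 launches cdiv(300, 256) = 2 programs; lanes 300..511 are masked off.
    print(batch_size, triton.cdiv(batch_size, BLOCK))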

@@ -148,72 +116,40 @@ def gather_token(req_to_next_token_ids: torch.Tensor, b_req_idx: torch.Tensor, b
         output: (batch_size,)
     """
     batch_size = b_req_idx.shape[0]
-    output = torch.empty_like(b_req_idx)
-    grid = (batch_size,)
+    output = torch.empty(batch_size, dtype=req_to_next_token_ids.dtype, device="cuda")
+    BLOCK = 256
+    grid = (triton.cdiv(batch_size, BLOCK),)
     num_warps = 1
     _fwd_kernel_gather[grid](
-        req_to_next_token_ids,
-        req_to_next_token_ids.stride(0),
-        req_to_next_token_ids.stride(1),
-        output,
-        b_req_idx,
-        b_mtp_index,
+        req_to_next_token_ids=req_to_next_token_ids,
+        req_to_next_token_ids_stride=req_to_next_token_ids.stride(0),
+        req_to_next_token_ids_stride_1=req_to_next_token_ids.stride(1),
+        output=output,
+        b_req_idx=b_req_idx,
+        b_mtp_index=b_mtp_index,
+        num_size=batch_size,
+        BLOCK=BLOCK,
         num_warps=num_warps,
         num_stages=1,
     )
     return output


-def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor):
-    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
-
-    probs_sum = torch.cumsum(probs_sort, dim=-1)
-    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
-
-    probs_sort[torch.arange(0, probs.shape[-1], device="cuda").view(1, -1) >= top_ks.view(-1, 1)] = 0.0
-
-    return probs_sort, probs_idx
-
-
-def test_gather_and_scatter_token_to_cpu():
-    batch_size = 30
-    vocab_size = 60000
-    req_to_next_token_ids = torch.ones((1000,), dtype=torch.int32, pin_memory=True)
-    req_to_next_token_probs = torch.ones((1000,), dtype=torch.float32, pin_memory=True)
-    req_ids = torch.arange(20, 20 + batch_size, dtype=torch.int32).cuda()
-    probs = torch.randn((batch_size, vocab_size)).cuda()
-    top_ps = torch.rand((batch_size,)).cuda()
-    top_ks = torch.ones((batch_size,), dtype=torch.int32).cuda()
-    probs_sort, probs_idx = _top_p_top_k(probs, top_ps, top_ks)
-    sampled_index = torch.multinomial(probs_sort, num_samples=1, replacement=True)
-    batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index)
-    batch_next_token_probs = torch.gather(probs_sort, dim=1, index=sampled_index)
-
-    gather_and_scatter_token_to_cpu(
-        probs_idx, probs_sort, req_to_next_token_ids, req_to_next_token_probs, sampled_index, req_ids
-    )
-    diff_ids = (req_to_next_token_ids[20 : 20 + batch_size].cuda() - batch_next_token_ids.view(-1)).abs().max()
-    diff_probs = (req_to_next_token_probs[20 : 20 + batch_size].cuda() - batch_next_token_probs.view(-1)).abs().max()
-    assert diff_ids < 1e-6
-    assert diff_probs < 1e-6
-    print("test_gather_and_scatter_token_to_cpu passed")
-
-
 def test_scatter_token_to_cpu():
     batch_size = 30
-    req_to_token_info = torch.zeros((1000,), dtype=torch.float32, pin_memory=True)
+    req_to_token_info = torch.zeros((1000, 1), dtype=torch.float32, pin_memory=True)
     token_info = torch.randn((batch_size,)).cuda()
     req_ids = torch.arange(20, 20 + batch_size, dtype=torch.int32).cuda()
     mtp_index = torch.zeros((batch_size,), dtype=torch.int32).cuda()
     scatter_token(token_info, req_to_token_info, req_ids, mtp_index)
-    diff = (req_to_token_info[20 : 20 + batch_size].cuda() - token_info).abs().max()
+    diff = (req_to_token_info[20 : 20 + batch_size].cuda().view(-1) - token_info).abs().max()
     assert diff < 1e-6
     print("test_scatter_token_to_cpu passed")


 def test_gather_token():
     batch_size = 30
-    req_to_token_info = torch.zeros((1000,), dtype=torch.int32, pin_memory=True)
+    req_to_token_info = torch.zeros((1000, 1), dtype=torch.float32, pin_memory=True)
     token_info = torch.randn((batch_size,)).cuda()
     req_ids = torch.arange(20, 20 + batch_size, dtype=torch.int32).cuda()
     mtp_index = torch.zeros((batch_size,), dtype=torch.int32).cuda()
@@ -225,6 +161,5 @@ def test_gather_token():


 if __name__ == "__main__":
-    test_gather_and_scatter_token_to_cpu()
     test_scatter_token_to_cpu()
     test_gather_token()
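For completeness, a small sketch of the read path (an illustration, not code from this commit, assuming the (table, b_req_idx, b_mtp_index) argument order suggested by the signature above): gather_token pulls each request's stored next token id back into a (batch_size,) CUDA tensor whose dtype now follows the table rather than b_req_idx.

import torch
from lightllm.common.basemodel.triton_kernel.gather_token_id import gather_token

# Per-request table filled with recognizable values: request i stores token id i.
req_to_next_token_ids = torch.arange(1000, dtype=torch.int64).view(1000, 1).pin_memory()
b_req_idx = torch.tensor([20, 21, 22], dtype=torch.int32, device="cuda")
b_mtp_index = torch.zeros(3, dtype=torch.int32, device="cuda")

next_ids = gather_token(req_to_next_token_ids, b_req_idx, b_mtp_index)
print(next_ids.dtype, next_ids.tolist())  # torch.int64 [20, 21, 22]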
