Commit 678bb5f

inference overlap

1 parent af6f547 commit 678bb5f

File tree

13 files changed: +319 -86 lines changed

lightllm/common/basemodel/basemodel.py

Lines changed: 13 additions & 4 deletions

@@ -17,10 +17,11 @@
 from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
 from lightllm.common.basemodel.cuda_graph import CudaGraph
 from lightllm.common.quantization import Quantcfg
+from lightllm.common.basemodel.triton_kernel.gather_token_id import gather_token_from_cpu
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.dist_utils import get_dp_world_size
 from lightllm.utils.envs_utils import get_env_start_args
-from lightllm.distributed.communication_op import CustomProcessGroup, dist_group_manager
+from lightllm.distributed.communication_op import dist_group_manager
 from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput
 from lightllm.utils.custom_kernel_utis import pad2dim_tensor_to_new_batch
 from lightllm.utils.envs_utils import set_model_init_status

@@ -237,6 +238,7 @@ def _init_custom(self):

     @torch.no_grad()
     def forward(self, model_input: ModelInput):
+        model_input.to_cuda()
         assert model_input.mem_indexes.is_cuda

         if model_input.is_prefill:

@@ -339,13 +341,20 @@ def _prefill(
             infer_state.mem_index,
         )

-        infer_state.init_some_extra_state(self, model_input.input_ids)
+        infer_state.init_some_extra_state(self, model_input)
         return self._context_forward(model_input.input_ids, infer_state)

     def _decode(
         self,
         model_input: ModelInput,
     ) -> ModelOutput:
+        # for overlap mode
+        if model_input.input_ids is None:
+            model_input.input_ids = gather_token_from_cpu(
+                self.req_manager.req_sampling_params_manager.req_to_next_token_ids_cpu,
+                model_input.b_req_idx,
+            )
+
         if self.graph is not None and self.graph.can_run(model_input.batch_size, model_input.max_len_in_batch):
             find_graph_batch_size = self.graph.find_closest_graph_batch_size(model_input.batch_size)
             padded_model_input = self._create_padded_decode_model_input(model_input, find_graph_batch_size)

@@ -356,7 +365,7 @@ def _decode(
             infer_state.b_seq_len,
             infer_state.mem_index,
         )
-        infer_state.init_some_extra_state(self, padded_model_input.input_ids)
+        infer_state.init_some_extra_state(self, padded_model_input)

         if self.graph.need_capture(find_graph_batch_size):
             infer_state.is_cuda_graph = True

@@ -377,7 +386,7 @@ def _decode(
             infer_state.b_seq_len,
             infer_state.mem_index,
         )
-        infer_state.init_some_extra_state(self, model_input.input_ids)
+        infer_state.init_some_extra_state(self, model_input)
         model_output = self._token_forward(model_input.input_ids, infer_state)

         return model_output
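
The overlap-mode guard added to _decode is the consumer half of the new pipeline: when the router submits a decode step before the previous step's sampled tokens have been copied back to the host, input_ids arrives as None and is rebuilt on the GPU from the pinned per-request table. A minimal sketch of that guard as a standalone helper; resolve_decode_input_ids is a hypothetical name, and the import path is the one added to basemodel.py above.

from lightllm.common.basemodel.triton_kernel.gather_token_id import gather_token_from_cpu


def resolve_decode_input_ids(model_input, req_sampling_params_manager):
    # Sketch only, not part of the commit: mirrors the guard inside BaseModel._decode.
    # In overlap mode the previous step's sampled token ids still live only in the
    # pinned CPU table, so input_ids is rebuilt on the GPU from that table.
    if model_input.input_ids is None:
        model_input.input_ids = gather_token_from_cpu(
            req_sampling_params_manager.req_to_next_token_ids_cpu,
            model_input.b_req_idx,
        )
    return model_input.input_ids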

lightllm/common/basemodel/batch_objs.py

Lines changed: 9 additions & 0 deletions

@@ -24,6 +24,15 @@ class ModelInput:
     # input for the draft model
     deepseekv3_mtp_draft_input_hiddens: Optional[torch.Tensor] = None

+    def to_cuda(self):
+        if self.input_ids is not None:
+            self.input_ids = self.input_ids.cuda(non_blocking=True)
+        self.mem_indexes = self.mem_indexes.cuda(non_blocking=True)
+        self.b_req_idx = self.b_req_idx.cuda(non_blocking=True)
+        self.b_seq_len = self.b_seq_len.cuda(non_blocking=True)
+        if self.b_ready_cache_len is not None:
+            self.b_ready_cache_len = self.b_ready_cache_len.cuda(non_blocking=True)
+

 @dataclass
 class ModelOutput:
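
A brief aside, not from the commit: to_cuda() relies on non_blocking=True, which only overlaps the host-to-device copy with already-queued GPU work when the source tensor is pinned; from pageable memory PyTorch falls back to an effectively blocking transfer. A minimal, runnable illustration (requires a CUDA device):

import torch

pinned = torch.zeros(4096, dtype=torch.int32, pin_memory=True)
pageable = torch.zeros(4096, dtype=torch.int32)

a = pinned.cuda(non_blocking=True)    # true asynchronous H2D copy, can overlap with queued kernels
b = pageable.cuda(non_blocking=True)  # staged through pageable memory, effectively blocking
torch.cuda.synchronize()              # make both copies observable before using a and b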

lightllm/common/basemodel/infer_struct.py

Lines changed: 6 additions & 5 deletions

@@ -6,6 +6,7 @@
 from .triton_kernel.gen_prefill_params import gen_prefill_params
 from .triton_kernel.gen_decode_params import gen_decode_params
 from .triton_kernel.multimodal_emb import mark_multimodal_obj
+from .batch_objs import ModelInput


 class InferStateInfo:

@@ -64,7 +65,7 @@ def __init__(self):
         # used as input; other models and scenarios do not use it
         self.deepseekv3_mtp_draft_input_hiddens: Optional[torch.Tensor] = None

-    def init_some_extra_state(self, model, input_ids: torch.Tensor):
+    def init_some_extra_state(self, model, model_input: ModelInput):
         if self.is_prefill:
             (
                 self.b_q_seq_len,

@@ -75,7 +76,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
                 self.max_q_seq_len,
                 self.max_kv_seq_len,
             ) = gen_prefill_params(
-                input_token_num=input_ids.shape[0],
+                input_token_num=model_input.input_ids.shape[0],
                 b_ready_cache_len=self.b_ready_cache_len,
                 b_seq_len=self.b_seq_len,
             )

@@ -87,10 +88,10 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor):
                 self.b_kv_seq_len,
                 self.b1_cu_kv_seq_len,
                 self.position_ids,
-                self.max_q_seq_len,
-                self.max_kv_seq_len,
-            ) = gen_decode_params(b_seq_len=self.b_seq_len)
+            ) = gen_decode_params(self.b_seq_len)
             self.b_start_loc = self.b1_cu_kv_seq_len[0:-1]
+            self.max_q_seq_len = 1
+            self.max_kv_seq_len = model_input.max_len_in_batch

     def copy_for_cuda_graph(self, new_infer_state: "InferStateInfo"):
         for attr_name, attr_value in vars(new_infer_state).items():
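
The decode branch now sets max_q_seq_len and max_kv_seq_len itself instead of receiving them from gen_decode_params (see the gen_decode_params.py change below). The motivation, sketched here rather than stated in the commit: computing these maxima on the GPU ends in .max().item(), which forces a device-to-host synchronization and defeats the overlap, while during decode both values are already known on the host.

import torch

b_seq_len = torch.randint(1, 512, (64,), dtype=torch.int32, device="cuda")

# old path: the reduction plus .item() stalls the host until the GPU catches up
max_kv_seq_len_sync = b_seq_len.max().item()

# new path: decode queries always have length 1, and the scheduler already
# tracks the largest kv length on the host as model_input.max_len_in_batch
max_q_seq_len = 1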

lightllm/common/basemodel/triton_kernel/gather_token_id.py

Lines changed: 209 additions & 0 deletions

@@ -0,0 +1,209 @@
+import torch
+
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _fwd_kernel_gather_and_scatter(
+    probs_idx,
+    probs_sort,
+    req_to_next_token_ids,
+    req_to_next_token_probs,
+    sampled_index,
+    b_req_idx,
+    probs_idx_stride,
+    probs_sort_stride,
+    req_to_next_token_ids_stride,
+    req_to_next_token_probs_stride,
+):
+    cur_index = tl.program_id(0)
+    cur_req_idx = tl.load(b_req_idx + cur_index)
+    cur_sampled_index = tl.load(sampled_index + cur_index)
+    cur_token_index = tl.load(probs_idx + cur_index * probs_idx_stride + cur_sampled_index)
+    cur_token_probs = tl.load(probs_sort + cur_index * probs_sort_stride + cur_sampled_index)
+    tl.store(req_to_next_token_ids + cur_req_idx * req_to_next_token_ids_stride, cur_token_index)
+    tl.store(req_to_next_token_probs + cur_req_idx * req_to_next_token_probs_stride, tl.log(cur_token_probs))
+    return
+
+
+@torch.no_grad()
+def gather_and_scatter_token_to_cpu(
+    probs_idx: torch.Tensor,
+    probs_sort: torch.Tensor,
+    req_to_next_token_ids: torch.Tensor,
+    req_to_next_token_probs: torch.Tensor,
+    sampled_index: torch.Tensor,
+    b_req_idx: torch.Tensor,
+):
+    """
+    Gather each request's sampled next token id and the log of its probability (GPU tensors)
+    and scatter them into req_to_next_token_ids and req_to_next_token_probs (pinned CPU tensors).
+    Args:
+        probs_idx: (batch_size, vocab_size)
+        probs_sort: (batch_size, vocab_size)
+        req_to_next_token_ids: (max_req_num,)
+        req_to_next_token_probs: (max_req_num,)
+        sampled_index: (batch_size,)
+        b_req_idx: (batch_size,)
+    """
+    assert probs_idx.shape == probs_sort.shape
+    assert sampled_index.shape[0] == b_req_idx.shape[0]
+    batch_size = b_req_idx.shape[0]
+    grid = (batch_size,)
+    num_warps = 1
+
+    _fwd_kernel_gather_and_scatter[grid](
+        probs_idx,
+        probs_sort,
+        req_to_next_token_ids,
+        req_to_next_token_probs,
+        sampled_index,
+        b_req_idx,
+        probs_idx.stride(0),
+        probs_sort.stride(0),
+        req_to_next_token_ids.stride(0),
+        req_to_next_token_probs.stride(0),
+        num_warps=num_warps,
+        num_stages=1,
+    )
+    return
+
+
+@triton.jit
+def _fwd_kernel_scatter(
+    token_info,
+    req_to_token_info,
+    b_req_idx,
+    req_to_token_info_stride,
+):
+    cur_index = tl.program_id(0)
+    cur_req_idx = tl.load(b_req_idx + cur_index)
+    cur_token_info = tl.load(token_info + cur_index)
+    tl.store(req_to_token_info + cur_req_idx * req_to_token_info_stride, cur_token_info)
+    return
+
+
+@torch.no_grad()
+def scatter_token_to_cpu(token_info: torch.Tensor, req_to_token_info: torch.Tensor, b_req_idx: torch.Tensor):
+    """
+    Scatter token_info (GPU tensor) into req_to_token_info (pinned CPU tensor).
+    Args:
+        token_info: (batch_size,)
+        req_to_token_info: (max_req_num,)
+        b_req_idx: (batch_size,)
+    """
+    assert token_info.shape[0] == b_req_idx.shape[0]
+    batch_size = b_req_idx.shape[0]
+    grid = (batch_size,)
+    num_warps = 1
+
+    _fwd_kernel_scatter[grid](
+        token_info,
+        req_to_token_info,
+        b_req_idx,
+        req_to_token_info.stride(0),
+        num_warps=num_warps,
+        num_stages=1,
+    )
+    return
+
+
+@triton.jit
+def _fwd_kernel_gather(
+    req_to_token_info,
+    output,
+    b_req_idx,
+):
+    cur_index = tl.program_id(0)
+    cur_req_idx = tl.load(b_req_idx + cur_index)
+    cur_token_info = tl.load(req_to_token_info + cur_req_idx)
+    tl.store(output + cur_index, cur_token_info)
+    return
+
+
+def gather_token_from_cpu(req_to_token_info: torch.Tensor, b_req_idx: torch.Tensor):
+    """
+    Gather per-request token_info from req_to_token_info (pinned CPU tensor) into a new GPU tensor.
+    Args:
+        req_to_token_info: (max_req_num,)
+        b_req_idx: (batch_size,)
+    Returns:
+        output: (batch_size,)
+    """
+    batch_size = b_req_idx.shape[0]
+    output = torch.empty_like(b_req_idx)
+    grid = (batch_size,)
+    num_warps = 1
+    _fwd_kernel_gather[grid](
+        req_to_token_info,
+        output,
+        b_req_idx,
+        num_warps=num_warps,
+        num_stages=1,
+    )
+    return output
+
+
+def _top_p_top_k(probs: torch.Tensor, top_ps: torch.Tensor, top_ks: torch.Tensor):
+    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
+
+    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
+
+    probs_sort[torch.arange(0, probs.shape[-1], device="cuda").view(1, -1) >= top_ks.view(-1, 1)] = 0.0
+
+    return probs_sort, probs_idx
+
+
+def test_gather_and_scatter_token_to_cpu():
+    batch_size = 30
+    vocab_size = 60000
+    req_to_next_token_ids = torch.ones((1000,), dtype=torch.int32, pin_memory=True)
+    req_to_next_token_probs = torch.ones((1000,), dtype=torch.float32, pin_memory=True)
+    req_ids = torch.arange(20, 20 + batch_size, dtype=torch.int32).cuda()
+    probs = torch.randn((batch_size, vocab_size)).cuda()
+    top_ps = torch.rand((batch_size,)).cuda()
+    top_ks = torch.ones((batch_size,), dtype=torch.int32).cuda()
+    probs_sort, probs_idx = _top_p_top_k(probs, top_ps, top_ks)
+    sampled_index = torch.multinomial(probs_sort, num_samples=1, replacement=True)
+    batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index)
+    batch_next_token_probs = torch.gather(probs_sort, dim=1, index=sampled_index)
+
+    gather_and_scatter_token_to_cpu(
+        probs_idx, probs_sort, req_to_next_token_ids, req_to_next_token_probs, sampled_index, req_ids
+    )
+    diff_ids = (req_to_next_token_ids[20 : 20 + batch_size].cuda() - batch_next_token_ids.view(-1)).abs().max()
+    diff_probs = (req_to_next_token_probs[20 : 20 + batch_size].cuda() - batch_next_token_probs.view(-1).log()).abs().max()
+    assert diff_ids < 1e-6
+    assert diff_probs < 1e-6
+    print("test_gather_and_scatter_token_to_cpu passed")
+
+
+def test_scatter_token_to_cpu():
+    batch_size = 30
+    req_to_token_info = torch.zeros((1000,), dtype=torch.float32, pin_memory=True)
+    token_info = torch.randn((batch_size,)).cuda()
+    req_ids = torch.arange(20, 20 + batch_size, dtype=torch.int32).cuda()
+    scatter_token_to_cpu(token_info, req_to_token_info, req_ids)
+    diff = (req_to_token_info[20 : 20 + batch_size].cuda() - token_info).abs().max()
+    assert diff < 1e-6
+    print("test_scatter_token_to_cpu passed")
+
+
+def test_gather_token_from_cpu():
+    batch_size = 30
+    req_to_token_info = torch.zeros((1000,), dtype=torch.int32, pin_memory=True)
+    token_info = torch.randint(0, 60000, (batch_size,), dtype=torch.int32).cuda()
+    req_ids = torch.arange(20, 20 + batch_size, dtype=torch.int32).cuda()
+    scatter_token_to_cpu(token_info, req_to_token_info, req_ids)
+    output = gather_token_from_cpu(req_to_token_info, req_ids)
+    diff = (token_info - output).abs().max()
+    assert diff < 1e-6
+    print("test_gather_token_from_cpu passed")
+
+
+if __name__ == "__main__":
+    test_gather_and_scatter_token_to_cpu()
+    test_scatter_token_to_cpu()
+    test_gather_token_from_cpu()
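
A usage sketch for the producer half, not part of the commit: after sampling, each request's token id and log-probability go straight into the pinned tables added in req_manager.py, keyed by request index, and the next decode step reads them back through gather_token_from_cpu (as BaseModel._decode does above). publish_sampled_tokens is a hypothetical wrapper name; the attribute path mirrors the one used in basemodel.py, and the module path is inferred from the import added there.

from lightllm.common.basemodel.triton_kernel.gather_token_id import gather_and_scatter_token_to_cpu


def publish_sampled_tokens(req_manager, probs_idx, probs_sort, sampled_index, b_req_idx):
    # Write the sampled token id and its log-prob for every request into the
    # pinned CPU tables; the GPU drives the copy, so no host sync is needed here.
    sampling_mgr = req_manager.req_sampling_params_manager
    gather_and_scatter_token_to_cpu(
        probs_idx,
        probs_sort,
        sampling_mgr.req_to_next_token_ids_cpu,
        sampling_mgr.req_to_next_token_probs_cpu,
        sampled_index,
        b_req_idx,
    )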

lightllm/common/basemodel/triton_kernel/gen_decode_params.py

Lines changed: 2 additions & 3 deletions

@@ -10,6 +10,5 @@ def gen_decode_params(b_seq_len: torch.Tensor):
     position_ids = b_seq_len - 1
     b_q_seq_len = torch.ones_like(b_seq_len)
     b1_cu_q_seq_len, b1_cu_kv_seq_len = gen_cumsum_pad0_tensor(b_q_seq_len, b_kv_seq_len)
-    max_q_seq_len = b_q_seq_len.max().item()
-    max_kv_seq_len = b_kv_seq_len.max().item()
-    return b_q_seq_len, b1_cu_q_seq_len, b_kv_seq_len, b1_cu_kv_seq_len, position_ids, max_q_seq_len, max_kv_seq_len
+
+    return b_q_seq_len, b1_cu_q_seq_len, b_kv_seq_len, b1_cu_kv_seq_len, position_ids

lightllm/common/req_manager.py

Lines changed: 6 additions & 0 deletions

@@ -110,6 +110,12 @@ def __init__(self, max_request_num):
         self.req_to_presence_penalty = torch.zeros(max_request_num + 1, dtype=torch.float32, device="cuda")
         self.req_to_frequency_penalty = torch.zeros(max_request_num + 1, dtype=torch.float32, device="cuda")
         self.req_to_repetition_penalty = torch.zeros(max_request_num + 1, dtype=torch.float32, device="cuda")
+        self.req_to_next_token_ids_cpu = torch.zeros(
+            max_request_num + 1, dtype=torch.int32, device="cpu", pin_memory=True
+        )
+        self.req_to_next_token_probs_cpu = torch.zeros(
+            max_request_num + 1, dtype=torch.float32, device="cpu", pin_memory=True
+        )
         self.req_to_exponential_decay_length_penalty = torch.zeros(
             max_request_num + 1, dtype=torch.float32, device="cuda"
         )
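
A short aside, not from the commit: the two new tables are allocated in pinned host memory so that the Triton kernels in gather_token_id.py can read and write them directly from device code (zero-copy host access) instead of issuing explicit device-to-host copies. A quick sanity check that the allocation really is pinned:

import torch

max_request_num = 1000
ids_cpu = torch.zeros(max_request_num + 1, dtype=torch.int32, device="cpu", pin_memory=True)
probs_cpu = torch.zeros(max_request_num + 1, dtype=torch.float32, device="cpu", pin_memory=True)
assert ids_cpu.is_pinned() and probs_cpu.is_pinned()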
