
Commit ef28098

hiworldwzj and wangzaijun authored
prefill cuda graph. (#1149)
Co-authored-by: wangzaijun <[email protected]>
1 parent 974d775 commit ef28098

File tree: 20 files changed, +644 -171 lines


lightllm/common/basemodel/basemodel.py

Lines changed: 142 additions & 56 deletions
@@ -6,7 +6,7 @@
 import json
 import torch
 import torch.nn.functional as F
-from typing import final
+from typing import final, List
 from tqdm import tqdm
 
 from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights
@@ -19,6 +19,7 @@
 from lightllm.common.basemodel.triton_kernel.copy_kv_index_to_req import copy_kv_index_to_req
 from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
 from lightllm.common.basemodel.cuda_graph import CudaGraph
+from lightllm.common.basemodel.prefill_cuda_graph import PrefillCudaGraph
 from lightllm.common.quantization import Quantcfg
 from lightllm.common.basemodel.triton_kernel.gather_token_id import gather_token
 from lightllm.utils.log_utils import init_logger
@@ -89,6 +90,7 @@ def __init__(self, kvargs):
         self.enable_tpsp_mix_mode = get_env_start_args().enable_tpsp_mix_mode
 
         self.is_deepseekv3_mtp_mode = self.args.mtp_mode in ["deepseekv3_vanilla", "deepseekv3_eagle"]
+        self.prefill_graph: PrefillCudaGraph = None
 
         self._init_config()
         self._verify_must()
@@ -115,6 +117,7 @@ def __init__(self, kvargs):
         # the wait must happen before cuda graph init, to avoid capturing errors
         self._wait_other_modules_ready()
         self._init_cudagraph()
+        self._init_prefill_cuda_graph()
         self._check_max_len_infer()
         torch.cuda.empty_cache()
         set_model_init_status(True)
@@ -240,6 +243,18 @@ def _init_cudagraph(self):
         else:
             self.graph.warmup(self)
 
+    def _init_prefill_cuda_graph(self):
+        self.prefill_graph = (
+            None
+            if not get_env_start_args().enable_prefill_cudagraph
+            else PrefillCudaGraph(decode_cuda_graph=self.graph)
+        )
+        if self.prefill_graph is not None:
+            if get_env_start_args().enable_prefill_microbatch_overlap:
+                self.prefill_graph.warmup_overlap(self)
+            else:
+                self.prefill_graph.warmup(self)
+
     def _init_custom(self):
         pass
 
@@ -332,6 +347,48 @@ def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_s
 
         return new_model_input
 
+    def _create_padded_prefill_model_input(self, model_input: ModelInput, new_handle_token_num: int):
+        assert model_input.total_token_num - model_input.prefix_total_token_num < new_handle_token_num
+
+        padded_token_num = new_handle_token_num - (model_input.total_token_num - model_input.prefix_total_token_num)
+        assert padded_token_num > 0
+        new_model_input = copy.copy(model_input)
+        new_model_input.batch_size = model_input.batch_size + 1
+        new_model_input.total_token_num += padded_token_num
+        new_model_input.max_len_in_batch = max(padded_token_num, model_input.max_len_in_batch)
+        new_model_input.max_q_seq_len = max(padded_token_num, model_input.max_q_seq_len)
+        new_model_input.max_kv_seq_len = max(padded_token_num, model_input.max_kv_seq_len)
+        new_model_input.max_cache_len = max(0, model_input.max_cache_len)
+        new_model_input.input_ids = F.pad(new_model_input.input_ids, (0, padded_token_num), mode="constant", value=1)
+        new_model_input.mem_indexes = F.pad(
+            new_model_input.mem_indexes,
+            (0, padded_token_num),
+            mode="constant",
+            value=self.mem_manager.HOLD_TOKEN_MEMINDEX,
+        )
+        new_model_input.b_req_idx = F.pad(
+            new_model_input.b_req_idx, (0, 1), mode="constant", value=self.req_manager.HOLD_REQUEST_ID
+        )
+        new_model_input.b_mtp_index = F.pad(new_model_input.b_mtp_index, (0, 1), mode="constant", value=0)
+        new_model_input.b_seq_len = F.pad(new_model_input.b_seq_len, (0, 1), mode="constant", value=padded_token_num)
+        new_model_input.b_ready_cache_len = F.pad(new_model_input.b_ready_cache_len, (0, 1), mode="constant", value=0)
+        b_q_seq_len = new_model_input.b_seq_len - new_model_input.b_ready_cache_len
+        new_model_input.b_prefill_start_loc = b_q_seq_len.cumsum(dim=0, dtype=torch.int32) - b_q_seq_len
+        # build a new list; appending in place could mutate the list the caller still references and cause errors
+        new_model_input.b_prefill_has_output_cpu = [e for e in new_model_input.b_prefill_has_output_cpu] + [False]
+        new_model_input.prefix_total_token_num = model_input.prefix_total_token_num
+
+        # TODO: check whether multimodal parameters also need padding
+
+        # model/mode specific padding for special fields
+        if new_model_input.deepseekv3_mtp_draft_input_hiddens is not None:
+            new_model_input.deepseekv3_mtp_draft_input_hiddens = pad2dim_tensor_to_new_batch(
+                input=new_model_input.deepseekv3_mtp_draft_input_hiddens,
+                new_batch_size=new_handle_token_num,
+            )
+
+        return new_model_input
+
     def _create_unpad_decode_model_output(self, model_output: ModelOutput, origin_batch_size: int):
         padded_batch_size = model_output.logits.shape[0]
         if padded_batch_size == origin_batch_size:
@@ -346,10 +403,39 @@ def _create_unpad_decode_model_output(self, model_output: ModelOutput, origin_ba
 
         return new_model_output
 
+    def _create_unpad_prefill_model_output(self, padded_model_output: ModelOutput, origin_handle_token_num: int):
+        if self.return_all_prompt_logics:
+            new_model_output = copy.copy(padded_model_output)
+            new_model_output.logits = new_model_output.logits[0:origin_handle_token_num]
+        else:
+            new_model_output = copy.copy(padded_model_output)
+            # drop the logits belonging to the extra padded request
+            new_model_output.logits = new_model_output.logits[0:-1]
+
+        # model/mode specific unpadding for special fields
+        if new_model_output.deepseekv3_mtp_main_output_hiddens is not None:
+            _hidden_states = new_model_output.deepseekv3_mtp_main_output_hiddens
+            new_model_output.deepseekv3_mtp_main_output_hiddens = _hidden_states[0:origin_handle_token_num]
+
+        return new_model_output
+
     def _prefill(
         self,
         model_input: ModelInput,
     ):
+        origin_handle_token_num = model_input.total_token_num - model_input.prefix_total_token_num
+
+        is_padded_model_input = False
+        if self.prefill_graph is not None and self.prefill_graph.can_run(handle_token_num=origin_handle_token_num):
+            finded_handle_token_num = self.prefill_graph.find_closest_graph_handle_token_num(
+                handle_token_num=origin_handle_token_num
+            )
+            if finded_handle_token_num != origin_handle_token_num:
+                is_padded_model_input = True
+                model_input = self._create_padded_prefill_model_input(
+                    model_input=model_input, new_handle_token_num=finded_handle_token_num
+                )
+
         infer_state = self._create_inferstate(model_input)
         init_req_to_token_indexes(
             req_to_token_indexs=self.req_manager.req_to_token_indexs,
@@ -365,6 +451,10 @@ def _prefill(
 
         infer_state.init_some_extra_state(self, model_input.input_ids)
         model_output = self._context_forward(model_input.input_ids, infer_state)
+        if is_padded_model_input:
+            model_output = self._create_unpad_prefill_model_output(
+                model_output, origin_handle_token_num=origin_handle_token_num
+            )
         model_output.prefill_mem_indexes_ready_event = prefill_mem_indexes_ready_event
         return model_output
 
@@ -419,22 +509,45 @@ def _decode(
     @final
     def _context_forward(self, input_ids, infer_state: InferStateInfo):
         run_mode_index = 1 if self.enable_tpsp_mix_mode else 0
-        g_cache_manager.cache_env_in()
         cuda_input_ids = input_ids
 
         pre_method = (self.pre_infer.context_forward, self.pre_infer.tpsp_context_forward)[run_mode_index]
         input_embs = pre_method(cuda_input_ids, infer_state, self.pre_post_weight)
+        input_tensors = [input_embs]
 
-        for i in range(self.layers_num):
-            layer = self.layers_infer[i]
-            layer_method = (layer.context_forward, layer.tpsp_context_forward)[run_mode_index]
-            input_embs = layer_method(input_embs, infer_state, self.trans_layers_weight[i])
+        def prefill_func(input_tensors, infer_state):
+            _input_embs = input_tensors[0]
+            for i in range(self.layers_num):
+                layer = self.layers_infer[i]
+                layer_method = (layer.context_forward, layer.tpsp_context_forward)[run_mode_index]
+                _input_embs = layer_method(_input_embs, infer_state, self.trans_layers_weight[i])
+            return [_input_embs]
 
-        post_method = (self.post_infer.token_forward, self.post_infer.tpsp_token_forward)[run_mode_index]
-        predict_logits = post_method(input_embs, infer_state, self.pre_post_weight)
+        handle_token_num = input_ids.shape[0]
 
-        g_cache_manager.cache_env_out()
+        if self.prefill_graph is not None and self.prefill_graph.can_run(handle_token_num=handle_token_num):
+            finded_handle_token_num = self.prefill_graph.find_closest_graph_handle_token_num(
+                handle_token_num=handle_token_num
+            )
+            if self.prefill_graph.need_capture(handle_token_num=finded_handle_token_num):
+                output_tensors: List[torch.Tensor] = self.prefill_graph.capture_prefill(
+                    prefill_func=prefill_func,
+                    input_tensors=input_tensors,
+                    infer_state=infer_state,
+                )
+            else:
+                output_tensors: List[torch.Tensor] = self.prefill_graph.replay(
+                    input_tensors=input_tensors, infer_state=infer_state
+                )
 
+        else:
+            g_cache_manager.cache_env_in()
+            output_tensors: List[torch.Tensor] = prefill_func(input_tensors, infer_state)
+            g_cache_manager.cache_env_out()
+
+        input_embs = output_tensors[0]
+        post_method = (self.post_infer.token_forward, self.post_infer.tpsp_token_forward)[run_mode_index]
+        predict_logits = post_method(input_embs, infer_state, self.pre_post_weight)
         model_output = ModelOutput(logits=predict_logits)
 
         # extra outputs for special models / special modes
@@ -449,40 +562,30 @@ def _context_forward(self, input_ids, infer_state: InferStateInfo):
     @final
     def _token_forward(self, input_ids, infer_state: InferStateInfo):
         run_mode_index = 1 if self.enable_tpsp_mix_mode else 0
-        g_cache_manager.cache_env_in(
-            is_cuda_graph=infer_state.is_cuda_graph,
-            cur_batch_size=infer_state.batch_size,
-            cuda_graph_max_batch_size=self.graph_max_batch_size,
-        )
         cuda_input_ids = input_ids
         pre_method = (self.pre_infer.token_forward, self.pre_infer.tpsp_token_forward)[run_mode_index]
         input_embs = pre_method(cuda_input_ids, infer_state, self.pre_post_weight)
         for i in range(self.layers_num):
             layer = self.layers_infer[i]
             layer_method = (layer.token_forward, layer.tpsp_token_forward)[run_mode_index]
-            input_embs = layer_method(input_embs, infer_state, self.trans_layers_weight[i])
+            input_embs: torch.Tensor = layer_method(input_embs, infer_state, self.trans_layers_weight[i])
 
         post_method = (self.post_infer.token_forward, self.post_infer.tpsp_token_forward)[run_mode_index]
-        predict_logits = post_method(input_embs, infer_state, self.pre_post_weight)
+        predict_logits: torch.Tensor = post_method(input_embs, infer_state, self.pre_post_weight)
 
         if self.is_deepseekv3_mtp_mode:
-            graph_out_hiddens = g_cache_manager.alloc_tensor(
-                input_embs.shape,
-                data_type=input_embs.dtype,
-                is_graph_out=True,
-                microbatch_index=infer_state.microbatch_index,
-                graph_out_key=520,
-            )
-            graph_out_hiddens.copy_(input_embs)
-
-        g_cache_manager.cache_env_out()
+            graph_out_hiddens = input_embs.contiguous()
 
-        model_output = ModelOutput(logits=predict_logits)
+        model_output = ModelOutput(logits=predict_logits.contiguous())
 
         # extra outputs for special models / special modes
         if self.is_deepseekv3_mtp_mode:
             model_output.deepseekv3_mtp_main_output_hiddens = graph_out_hiddens
 
+        # in cuda graph mode, convert outputs to no-ref tensors to improve mem pool reuse and reduce GPU memory usage
+        if infer_state.is_cuda_graph:
+            model_output.to_no_ref_tensor()
+
         return model_output
 
     @torch.no_grad()
@@ -642,24 +745,19 @@ def _overlap_tpsp_context_forward(
         )
         g_cache_manager.cache_env_out()
 
-        model_output = ModelOutput(logits=predict_logits)
-        model_output1 = ModelOutput(logits=predict_logits1)
+        model_output = ModelOutput(logits=predict_logits.contiguous())
+        model_output1 = ModelOutput(logits=predict_logits1.contiguous())
 
         if self.is_deepseekv3_mtp_mode:
-            model_output.deepseekv3_mtp_main_output_hiddens = input_embs
-            model_output1.deepseekv3_mtp_main_output_hiddens = input_embs1
+            model_output.deepseekv3_mtp_main_output_hiddens = input_embs.contiguous()
+            model_output1.deepseekv3_mtp_main_output_hiddens = input_embs1.contiguous()
 
         return model_output, model_output1
 
     @final
     def _overlap_tpsp_token_forward(
         self, input_ids, infer_state: InferStateInfo, input_ids1, infer_state1: InferStateInfo
     ):
-        g_cache_manager.cache_env_in(
-            is_cuda_graph=infer_state.is_cuda_graph,
-            cur_batch_size=infer_state.batch_size,
-            cuda_graph_max_batch_size=self.graph_max_batch_size,
-        )
         input_embs, input_embs1 = self.pre_infer.overlap_tpsp_token_forward(
             input_ids, input_ids1, infer_state, infer_state1, self.pre_post_weight
         )
@@ -674,32 +772,20 @@ def _overlap_tpsp_token_forward(
         )
 
         if self.is_deepseekv3_mtp_mode:
-            graph_out_hiddens = g_cache_manager.alloc_tensor(
-                input_embs.shape,
-                data_type=input_embs.dtype,
-                is_graph_out=True,
-                microbatch_index=0,
-                graph_out_key=520,
-            )
-            graph_out_hiddens.copy_(input_embs)
-            graph_out_hiddens1 = g_cache_manager.alloc_tensor(
-                input_embs1.shape,
-                data_type=input_embs1.dtype,
-                is_graph_out=True,
-                microbatch_index=1,
-                graph_out_key=520,
-            )
-            graph_out_hiddens1.copy_(input_embs1)
+            graph_out_hiddens = input_embs.contiguous()
+            graph_out_hiddens1 = input_embs1.contiguous()
 
-        g_cache_manager.cache_env_out()
-
-        model_output = ModelOutput(logits=predict_logits)
-        model_output1 = ModelOutput(logits=predict_logits1)
+        model_output = ModelOutput(logits=predict_logits.contiguous())
+        model_output1 = ModelOutput(logits=predict_logits1.contiguous())
 
         if self.is_deepseekv3_mtp_mode:
             model_output.deepseekv3_mtp_main_output_hiddens = graph_out_hiddens
             model_output1.deepseekv3_mtp_main_output_hiddens = graph_out_hiddens1
 
+        if infer_state.is_cuda_graph:
+            model_output.to_no_ref_tensor()
+            model_output1.to_no_ref_tensor()
+
         return model_output, model_output1
 
     @final

lightllm/common/basemodel/batch_objs.py

Lines changed: 6 additions & 0 deletions
@@ -3,6 +3,7 @@
 from typing import Optional
 from typing import List
 from lightllm.utils.envs_utils import enable_diverse_mode_gqa_decode_fast_kernel
+from lightllm.utils.tensor_utils import tensor_to_no_ref_tensor
 
 
 @dataclass
@@ -88,3 +89,8 @@ class ModelOutput:
     # hidden states of the last layer, used as the deepseekv3_mtp_draft_input_hiddens
     # input of the draft model
    deepseekv3_mtp_main_output_hiddens: Optional[torch.Tensor] = None
+
+    def to_no_ref_tensor(self):
+        self.logits = tensor_to_no_ref_tensor(self.logits)
+        if self.deepseekv3_mtp_main_output_hiddens is not None:
+            self.deepseekv3_mtp_main_output_hiddens = tensor_to_no_ref_tensor(self.deepseekv3_mtp_main_output_hiddens)
