import json
import torch
import torch.nn.functional as F
-from typing import final
+from typing import final, List
from tqdm import tqdm

from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights
from lightllm.common.basemodel.triton_kernel.copy_kv_index_to_req import copy_kv_index_to_req
from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
from lightllm.common.basemodel.cuda_graph import CudaGraph
+from lightllm.common.basemodel.prefill_cuda_graph import PrefillCudaGraph
from lightllm.common.quantization import Quantcfg
from lightllm.common.basemodel.triton_kernel.gather_token_id import gather_token
from lightllm.utils.log_utils import init_logger
@@ -89,6 +90,7 @@ def __init__(self, kvargs):
        self.enable_tpsp_mix_mode = get_env_start_args().enable_tpsp_mix_mode

        self.is_deepseekv3_mtp_mode = self.args.mtp_mode in ["deepseekv3_vanilla", "deepseekv3_eagle"]
+        self.prefill_graph: PrefillCudaGraph = None

        self._init_config()
        self._verify_must()
@@ -115,6 +117,7 @@ def __init__(self, kvargs):
        # The wait must happen before cudagraph init, to avoid capturing errors
        self._wait_other_modules_ready()
        self._init_cudagraph()
+        self._init_prefill_cuda_graph()
        self._check_max_len_infer()
        torch.cuda.empty_cache()
        set_model_init_status(True)
@@ -240,6 +243,18 @@ def _init_cudagraph(self):
        else:
            self.graph.warmup(self)

+    def _init_prefill_cuda_graph(self):
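+        # Build the prefill cuda graph only when the enable_prefill_cudagraph start arg is set; it is
+        # constructed with the existing decode graph (self.graph) and then warmed up, using the
+        # overlapped warmup variant when microbatch overlap is enabled.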
+        self.prefill_graph = (
+            None
+            if not get_env_start_args().enable_prefill_cudagraph
+            else PrefillCudaGraph(decode_cuda_graph=self.graph)
+        )
+        if self.prefill_graph is not None:
+            if get_env_start_args().enable_prefill_microbatch_overlap:
+                self.prefill_graph.warmup_overlap(self)
+            else:
+                self.prefill_graph.warmup(self)
+
    def _init_custom(self):
        pass

@@ -332,6 +347,48 @@ def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_s

        return new_model_input

+    def _create_padded_prefill_model_input(self, model_input: ModelInput, new_handle_token_num: int):
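+        # Pad the prefill batch up to new_handle_token_num (a token count the prefill cuda graph can
+        # serve) by appending one fake request that owns all of the padding tokens.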
+        assert model_input.total_token_num - model_input.prefix_total_token_num < new_handle_token_num
+
+        padded_token_num = new_handle_token_num - (model_input.total_token_num - model_input.prefix_total_token_num)
+        assert padded_token_num > 0
+        new_model_input = copy.copy(model_input)
+        new_model_input.batch_size = model_input.batch_size + 1
+        new_model_input.total_token_num += padded_token_num
+        new_model_input.max_len_in_batch = max(padded_token_num, model_input.max_len_in_batch)
+        new_model_input.max_q_seq_len = max(padded_token_num, model_input.max_q_seq_len)
+        new_model_input.max_kv_seq_len = max(padded_token_num, model_input.max_kv_seq_len)
+        new_model_input.max_cache_len = max(0, model_input.max_cache_len)
+        new_model_input.input_ids = F.pad(new_model_input.input_ids, (0, padded_token_num), mode="constant", value=1)
+        new_model_input.mem_indexes = F.pad(
+            new_model_input.mem_indexes,
+            (0, padded_token_num),
+            mode="constant",
+            value=self.mem_manager.HOLD_TOKEN_MEMINDEX,
+        )
+        new_model_input.b_req_idx = F.pad(
+            new_model_input.b_req_idx, (0, 1), mode="constant", value=self.req_manager.HOLD_REQUEST_ID
+        )
+        new_model_input.b_mtp_index = F.pad(new_model_input.b_mtp_index, (0, 1), mode="constant", value=0)
+        new_model_input.b_seq_len = F.pad(new_model_input.b_seq_len, (0, 1), mode="constant", value=padded_token_num)
+        new_model_input.b_ready_cache_len = F.pad(new_model_input.b_ready_cache_len, (0, 1), mode="constant", value=0)
+        b_q_seq_len = new_model_input.b_seq_len - new_model_input.b_ready_cache_len
+        new_model_input.b_prefill_start_loc = b_q_seq_len.cumsum(dim=0, dtype=torch.int32) - b_q_seq_len
+        # Build a new list: using append could mutate a list still referenced by callers and cause errors.
+        new_model_input.b_prefill_has_output_cpu = [e for e in new_model_input.b_prefill_has_output_cpu] + [False]
+        new_model_input.prefix_total_token_num = model_input.prefix_total_token_num
+
+        # TODO: check whether multimodal parameters also need to be padded.
+
+        # Special padding for the special variables of special models / special modes.
+        if new_model_input.deepseekv3_mtp_draft_input_hiddens is not None:
+            new_model_input.deepseekv3_mtp_draft_input_hiddens = pad2dim_tensor_to_new_batch(
+                input=new_model_input.deepseekv3_mtp_draft_input_hiddens,
+                new_batch_size=new_handle_token_num,
+            )
+
+        return new_model_input
+
    def _create_unpad_decode_model_output(self, model_output: ModelOutput, origin_batch_size: int):
        padded_batch_size = model_output.logits.shape[0]
        if padded_batch_size == origin_batch_size:
@@ -346,10 +403,39 @@ def _create_unpad_decode_model_output(self, model_output: ModelOutput, origin_ba

        return new_model_output

+    def _create_unpad_prefill_model_output(self, padded_model_output: ModelOutput, origin_handle_token_num: int):
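+        # Undo the prefill padding: keep only the first origin_handle_token_num logits when all prompt
+        # logits are returned, otherwise drop the single logit that belongs to the padded fake request.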
+        if self.return_all_prompt_logics:
+            new_model_output = copy.copy(padded_model_output)
+            new_model_output.logits = new_model_output.logits[0:origin_handle_token_num]
+        else:
+            new_model_output = copy.copy(padded_model_output)
+            # Remove the logits belonging to the extra padded req.
+            new_model_output.logits = new_model_output.logits[0:-1]
+
+        # Special unpadding for the special variables of special models / special modes.
+        if new_model_output.deepseekv3_mtp_main_output_hiddens is not None:
+            _hidden_states = new_model_output.deepseekv3_mtp_main_output_hiddens
+            new_model_output.deepseekv3_mtp_main_output_hiddens = _hidden_states[0:origin_handle_token_num]
+
+        return new_model_output
+
    def _prefill(
        self,
        model_input: ModelInput,
    ):
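+        # origin_handle_token_num counts only the new (non prefix-cached) tokens this prefill must process.
+        # When a prefill cuda graph can serve this size, pad the input up to the closest captured token count.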
+        origin_handle_token_num = model_input.total_token_num - model_input.prefix_total_token_num
+
+        is_padded_model_input = False
+        if self.prefill_graph is not None and self.prefill_graph.can_run(handle_token_num=origin_handle_token_num):
+            finded_handle_token_num = self.prefill_graph.find_closest_graph_handle_token_num(
+                handle_token_num=origin_handle_token_num
+            )
+            if finded_handle_token_num != origin_handle_token_num:
+                is_padded_model_input = True
+                model_input = self._create_padded_prefill_model_input(
+                    model_input=model_input, new_handle_token_num=finded_handle_token_num
+                )
+
        infer_state = self._create_inferstate(model_input)
        init_req_to_token_indexes(
            req_to_token_indexs=self.req_manager.req_to_token_indexs,
@@ -365,6 +451,10 @@ def _prefill(

        infer_state.init_some_extra_state(self, model_input.input_ids)
        model_output = self._context_forward(model_input.input_ids, infer_state)
+        if is_padded_model_input:
+            model_output = self._create_unpad_prefill_model_output(
+                model_output, origin_handle_token_num=origin_handle_token_num
+            )
        model_output.prefill_mem_indexes_ready_event = prefill_mem_indexes_ready_event
        return model_output

@@ -419,22 +509,45 @@ def _decode(
    @final
    def _context_forward(self, input_ids, infer_state: InferStateInfo):
        run_mode_index = 1 if self.enable_tpsp_mix_mode else 0
-        g_cache_manager.cache_env_in()
        cuda_input_ids = input_ids

        pre_method = (self.pre_infer.context_forward, self.pre_infer.tpsp_context_forward)[run_mode_index]
        input_embs = pre_method(cuda_input_ids, infer_state, self.pre_post_weight)
+        input_tensors = [input_embs]

-        for i in range(self.layers_num):
-            layer = self.layers_infer[i]
-            layer_method = (layer.context_forward, layer.tpsp_context_forward)[run_mode_index]
-            input_embs = layer_method(input_embs, infer_state, self.trans_layers_weight[i])
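+        # The transformer layer loop is wrapped in prefill_func so the same code path can be captured
+        # into / replayed from the prefill cuda graph, or executed eagerly as a fallback.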
+        def prefill_func(input_tensors, infer_state):
+            _input_embs = input_tensors[0]
+            for i in range(self.layers_num):
+                layer = self.layers_infer[i]
+                layer_method = (layer.context_forward, layer.tpsp_context_forward)[run_mode_index]
+                _input_embs = layer_method(_input_embs, infer_state, self.trans_layers_weight[i])
+            return [_input_embs]

-        post_method = (self.post_infer.token_forward, self.post_infer.tpsp_token_forward)[run_mode_index]
-        predict_logits = post_method(input_embs, infer_state, self.pre_post_weight)
+        handle_token_num = input_ids.shape[0]

-        g_cache_manager.cache_env_out()
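+        # If the prefill graph can serve this token count, capture it on first use and replay afterwards;
+        # otherwise run prefill_func eagerly under the normal cache tensor environment.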
+        if self.prefill_graph is not None and self.prefill_graph.can_run(handle_token_num=handle_token_num):
+            finded_handle_token_num = self.prefill_graph.find_closest_graph_handle_token_num(
+                handle_token_num=handle_token_num
+            )
+            if self.prefill_graph.need_capture(handle_token_num=finded_handle_token_num):
+                output_tensors: List[torch.Tensor] = self.prefill_graph.capture_prefill(
+                    prefill_func=prefill_func,
+                    input_tensors=input_tensors,
+                    infer_state=infer_state,
+                )
+            else:
+                output_tensors: List[torch.Tensor] = self.prefill_graph.replay(
+                    input_tensors=input_tensors, infer_state=infer_state
+                )

+        else:
+            g_cache_manager.cache_env_in()
+            output_tensors: List[torch.Tensor] = prefill_func(input_tensors, infer_state)
+            g_cache_manager.cache_env_out()
+
+        input_embs = output_tensors[0]
+        post_method = (self.post_infer.token_forward, self.post_infer.tpsp_token_forward)[run_mode_index]
+        predict_logits = post_method(input_embs, infer_state, self.pre_post_weight)
        model_output = ModelOutput(logits=predict_logits)

        # Extra outputs for special models / special modes
@@ -449,40 +562,30 @@ def _context_forward(self, input_ids, infer_state: InferStateInfo):
    @final
    def _token_forward(self, input_ids, infer_state: InferStateInfo):
        run_mode_index = 1 if self.enable_tpsp_mix_mode else 0
-        g_cache_manager.cache_env_in(
-            is_cuda_graph=infer_state.is_cuda_graph,
-            cur_batch_size=infer_state.batch_size,
-            cuda_graph_max_batch_size=self.graph_max_batch_size,
-        )
        cuda_input_ids = input_ids
        pre_method = (self.pre_infer.token_forward, self.pre_infer.tpsp_token_forward)[run_mode_index]
        input_embs = pre_method(cuda_input_ids, infer_state, self.pre_post_weight)
        for i in range(self.layers_num):
            layer = self.layers_infer[i]
            layer_method = (layer.token_forward, layer.tpsp_token_forward)[run_mode_index]
-            input_embs = layer_method(input_embs, infer_state, self.trans_layers_weight[i])
+            input_embs: torch.Tensor = layer_method(input_embs, infer_state, self.trans_layers_weight[i])

        post_method = (self.post_infer.token_forward, self.post_infer.tpsp_token_forward)[run_mode_index]
-        predict_logits = post_method(input_embs, infer_state, self.pre_post_weight)
+        predict_logits: torch.Tensor = post_method(input_embs, infer_state, self.pre_post_weight)

        if self.is_deepseekv3_mtp_mode:
-            graph_out_hiddens = g_cache_manager.alloc_tensor(
-                input_embs.shape,
-                data_type=input_embs.dtype,
-                is_graph_out=True,
-                microbatch_index=infer_state.microbatch_index,
-                graph_out_key=520,
-            )
-            graph_out_hiddens.copy_(input_embs)
-
-        g_cache_manager.cache_env_out()
+            graph_out_hiddens = input_embs.contiguous()

-        model_output = ModelOutput(logits=predict_logits)
+        model_output = ModelOutput(logits=predict_logits.contiguous())

        # Extra outputs for special models / special modes
        if self.is_deepseekv3_mtp_mode:
            model_output.deepseekv3_mtp_main_output_hiddens = graph_out_hiddens

+        # In cuda graph mode, convert the outputs to no-ref tensors to improve mem pool reuse and reduce GPU memory usage.
+        if infer_state.is_cuda_graph:
+            model_output.to_no_ref_tensor()
+
        return model_output

    @torch.no_grad()
@@ -642,24 +745,19 @@ def _overlap_tpsp_context_forward(
        )
        g_cache_manager.cache_env_out()

-        model_output = ModelOutput(logits=predict_logits)
-        model_output1 = ModelOutput(logits=predict_logits1)
+        model_output = ModelOutput(logits=predict_logits.contiguous())
+        model_output1 = ModelOutput(logits=predict_logits1.contiguous())

        if self.is_deepseekv3_mtp_mode:
-            model_output.deepseekv3_mtp_main_output_hiddens = input_embs
-            model_output1.deepseekv3_mtp_main_output_hiddens = input_embs1
+            model_output.deepseekv3_mtp_main_output_hiddens = input_embs.contiguous()
+            model_output1.deepseekv3_mtp_main_output_hiddens = input_embs1.contiguous()

        return model_output, model_output1

    @final
    def _overlap_tpsp_token_forward(
        self, input_ids, infer_state: InferStateInfo, input_ids1, infer_state1: InferStateInfo
    ):
-        g_cache_manager.cache_env_in(
-            is_cuda_graph=infer_state.is_cuda_graph,
-            cur_batch_size=infer_state.batch_size,
-            cuda_graph_max_batch_size=self.graph_max_batch_size,
-        )
        input_embs, input_embs1 = self.pre_infer.overlap_tpsp_token_forward(
            input_ids, input_ids1, infer_state, infer_state1, self.pre_post_weight
        )
@@ -674,32 +772,20 @@ def _overlap_tpsp_token_forward(
        )

        if self.is_deepseekv3_mtp_mode:
-            graph_out_hiddens = g_cache_manager.alloc_tensor(
-                input_embs.shape,
-                data_type=input_embs.dtype,
-                is_graph_out=True,
-                microbatch_index=0,
-                graph_out_key=520,
-            )
-            graph_out_hiddens.copy_(input_embs)
-            graph_out_hiddens1 = g_cache_manager.alloc_tensor(
-                input_embs1.shape,
-                data_type=input_embs1.dtype,
-                is_graph_out=True,
-                microbatch_index=1,
-                graph_out_key=520,
-            )
-            graph_out_hiddens1.copy_(input_embs1)
+            graph_out_hiddens = input_embs.contiguous()
+            graph_out_hiddens1 = input_embs1.contiguous()

-        g_cache_manager.cache_env_out()
-
-        model_output = ModelOutput(logits=predict_logits)
-        model_output1 = ModelOutput(logits=predict_logits1)
+        model_output = ModelOutput(logits=predict_logits.contiguous())
+        model_output1 = ModelOutput(logits=predict_logits1.contiguous())

        if self.is_deepseekv3_mtp_mode:
            model_output.deepseekv3_mtp_main_output_hiddens = graph_out_hiddens
            model_output1.deepseekv3_mtp_main_output_hiddens = graph_out_hiddens1

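+        # As in _token_forward: in cuda graph mode, detach the outputs into no-ref tensors to improve mem pool reuse.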
+        if infer_state.is_cuda_graph:
+            model_output.to_no_ref_tensor()
+            model_output1.to_no_ref_tensor()
+
        return model_output, model_output1

    @final