 from lightllm.common.quantization import Quantcfg
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.dist_utils import get_dp_world_size
+from lightllm.utils.envs_utils import get_env_start_args
+from lightllm.distributed.communication_op import CustomProcessGroup, dist_group_manager
+from lightllm.common.basemodel.microbatch_overlap_objs import DecodeMicroBatch
 
 logger = init_logger(__name__)
 
@@ -53,16 +56,15 @@ def __init__(self, kvargs):
         self.return_all_prompt_logics = kvargs.get("return_all_prompt_logics", False)
         assert not (self.is_token_healing and self.return_all_prompt_logics), "can not be true in same time"
         self.use_dynamic_prompt_cache = kvargs.get("use_dynamic_prompt_cache", False)
-        enable_chunked_prefill = kvargs.get("enable_chunked_prefill", False)  # chunked prefill is default on.
-        self.use_dynamic_prompt_cache = self.use_dynamic_prompt_cache or enable_chunked_prefill
         self.data_type = kvargs.get("data_type", "float16")
         self.graph_max_batch_size = kvargs.get("graph_max_batch_size", 16)
         self.graph_max_len_in_batch = kvargs.get("graph_max_len_in_batch", 8192)
         self.disable_cudagraph = kvargs.get("disable_cudagraph", False)
-        self.quant_type = kvargs.get("quant_type", None)
+        self.quant_type = kvargs.get("quant_type", "none")
         self.quant_cfg_path = kvargs.get("quant_cfg", None)
         self.mem_fraction = kvargs.get("mem_fraction", 0.9)
         self.tp_world_size_ = get_dp_world_size()
+        self.enable_tpsp_mix_mode = get_env_start_args().enable_tpsp_mix_mode
 
         self._init_datatype()
         self._init_config()
@@ -98,7 +100,6 @@ def _init_config(self):
         repair_config(self.config, same_names=["num_hidden_layers", "n_layer"])
         if self.finetune_config:
             self.config["vocab_size"] = self.finetune_config.vocab_size
-
         return
 
     @final
@@ -207,7 +208,10 @@ def _init_cudagraph(self):
             None if self.disable_cudagraph else CudaGraph(self.graph_max_batch_size, self.graph_max_len_in_batch)
         )
        if self.graph is not None:
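+            # when decode micro-batch overlap is enabled, warm up the dual micro-batch graph path instead of the regular one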
-            self.graph.warmup(self)
+            if get_env_start_args().enable_decode_microbatch_overlap:
+                self.graph.warmup_overlap(self)
+            else:
+                self.graph.warmup(self)
 
     def _init_custom(self):
         pass
@@ -296,6 +300,7 @@ def _prefill(
             dtype=self.data_type,
             device="cuda",
         )
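+        # attach the default communication group to this request's inference state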
+        infer_state.dist_group = dist_group_manager.get_default_group()
 
         init_req_to_token_indexes(
             self.req_manager.req_to_token_indexs,
@@ -346,6 +351,7 @@ def _decode(
             dtype=self.data_type,
             device="cuda",
         )
+        infer_state.dist_group = dist_group_manager.get_default_group()
         copy_kv_index_to_req(self.req_manager.req_to_token_indexs, b_req_idx, b_seq_len, infer_state.mem_index)
 
         infer_state.init_some_extra_state(self, input_ids)
@@ -359,32 +365,143 @@ def _decode(
         predict_logics = self._token_forward(input_ids, infer_state)
         return predict_logics
 
+    @torch.no_grad()
+    def microbatch_overlap_decode(self, batch: DecodeMicroBatch, batch1: DecodeMicroBatch):
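+        """Decode two micro-batches in one call, giving each its own inference
+        state and communication group so that their work can be overlapped."""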
+        assert batch.batch_size == batch1.batch_size
+        assert batch.mem_indexes.is_cuda
+        assert batch1.mem_indexes.is_cuda
+        input_ids, input_ids1 = batch.input_ids, batch1.input_ids
+
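+        # build an inference state for one micro-batch, mirroring the setup done in _decode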
+        def create_inferstate(cur_batch: DecodeMicroBatch, batch_index):
+            infer_state = self.infer_state_class()
+            infer_state.is_prefill = False
+            infer_state.batch_size = cur_batch.batch_size
+            infer_state.total_token_num = cur_batch.total_token_num
+            infer_state.max_len_in_batch = cur_batch.max_len_in_batch
+            infer_state.use_dynamic_prompt_cache = self.use_dynamic_prompt_cache
+            assert cur_batch.b_req_idx.shape[0] == cur_batch.b_start_loc.shape[0] == cur_batch.b_seq_len.shape[0]
+            infer_state.b_req_idx = cur_batch.b_req_idx
+            infer_state.b_start_loc = cur_batch.b_start_loc
+            infer_state.b_seq_len = cur_batch.b_seq_len
+            infer_state.multimodal_params = None
+            infer_state.microbatch_index = batch_index
+
+            infer_state.mem_manager = self.mem_manager
+            infer_state.req_manager = self.req_manager
+
+            # When the cuda graph feature is used, every inference run must follow the same flow,
+            # so the optimization from allocating contiguous mem is dropped to keep the flow consistent.
+            infer_state.mem_is_contiguous = False
+            infer_state.mem_index = cur_batch.mem_indexes
+            infer_state.kv_buffer = torch.empty(
+                (cur_batch.batch_size, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
+                dtype=self.data_type,
+                device="cuda",
+            )
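+            # each micro-batch gets its own communication group, indexed by batch_index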
+            infer_state.dist_group = dist_group_manager.get_group(batch_index)
+            copy_kv_index_to_req(
+                self.req_manager.req_to_token_indexs, cur_batch.b_req_idx, cur_batch.b_seq_len, infer_state.mem_index
+            )
+            return infer_state
+
+        infer_state = create_inferstate(batch, 0)
+        infer_state1 = create_inferstate(batch1, 1)
+
+        infer_state.init_some_extra_state(self, input_ids)
+        infer_state1.init_some_extra_state(self, input_ids1)
+
+        batch_size = batch.batch_size
+        max_len_in_batch = max(batch.max_len_in_batch, batch1.max_len_in_batch)
+
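+        # capture or replay both micro-batches through the CUDA graph when it can serve this batch shape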
+        if self.graph is not None and self.graph.can_run(batch_size, max_len_in_batch):
+            if self.graph.need_capture(batch_size):
+                infer_state.is_cuda_graph = True
+                infer_state1.is_cuda_graph = True
+
+                predict_logics, predict_logics1 = self.graph.capture_decode(
+                    self._overlap_tpsp_token_forward,
+                    input_ids,
+                    infer_state,
+                    input_ids1=input_ids1,
+                    infer_state1=infer_state1,
+                )
+            else:
+                predict_logics, predict_logics1 = self.graph.replay(
+                    input_ids, infer_state, input_ids1=input_ids1, infer_state1=infer_state1
+                )
+        else:
+            predict_logics, predict_logics1 = self._overlap_tpsp_token_forward(
+                input_ids, infer_state, input_ids1=input_ids1, infer_state1=infer_state1
+            )
+        return predict_logics, predict_logics1
+
     @final
     def _context_forward(self, input_ids, infer_state: InferStateInfo):
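+        # index 0 selects the regular forward path, index 1 the tpsp mixed-mode forward path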
+        run_mode_index = 1 if self.enable_tpsp_mix_mode else 0
         g_cache_manager.cache_env_in()
         cuda_input_ids = input_ids
-        input_embs = self.pre_infer.context_forward(cuda_input_ids, infer_state, self.pre_post_weight)
-        for i in range(0, self.layers_num):
-            input_embs = self.layers_infer[i].context_forward(input_embs, infer_state, self.trans_layers_weight[i])
-        predict_logics = self.post_infer.token_forward(input_embs, infer_state, self.pre_post_weight)
+
+        pre_method = (self.pre_infer.context_forward, self.pre_infer.tpsp_context_forward)[run_mode_index]
+        input_embs = pre_method(cuda_input_ids, infer_state, self.pre_post_weight)
+
+        for i in range(self.layers_num):
+            layer = self.layers_infer[i]
+            layer_method = (layer.context_forward, layer.tpsp_context_forward)[run_mode_index]
+            input_embs = layer_method(input_embs, infer_state, self.trans_layers_weight[i])
+
+        post_method = (self.post_infer.token_forward, self.post_infer.tpsp_token_forward)[run_mode_index]
+        predict_logics = post_method(input_embs, infer_state, self.pre_post_weight)
+
         g_cache_manager.cache_env_out()
         return predict_logics
 
     @final
     def _token_forward(self, input_ids, infer_state: InferStateInfo):
+        run_mode_index = 1 if self.enable_tpsp_mix_mode else 0
         g_cache_manager.cache_env_in(
             is_cuda_graph=infer_state.is_cuda_graph,
             cur_batch_size=infer_state.batch_size,
             cuda_graph_max_batch_size=self.graph_max_batch_size,
         )
         cuda_input_ids = input_ids
-        input_embs = self.pre_infer.token_forward(cuda_input_ids, infer_state, self.pre_post_weight)
-        for i in range(0, self.layers_num):
-            input_embs = self.layers_infer[i].token_forward(input_embs, infer_state, self.trans_layers_weight[i])
-        predict_logics = self.post_infer.token_forward(input_embs, infer_state, self.pre_post_weight)
+        pre_method = (self.pre_infer.token_forward, self.pre_infer.tpsp_token_forward)[run_mode_index]
+        input_embs = pre_method(cuda_input_ids, infer_state, self.pre_post_weight)
+        for i in range(self.layers_num):
+            layer = self.layers_infer[i]
+            layer_method = (layer.token_forward, layer.tpsp_token_forward)[run_mode_index]
+            input_embs = layer_method(input_embs, infer_state, self.trans_layers_weight[i])
+
+        post_method = (self.post_infer.token_forward, self.post_infer.tpsp_token_forward)[run_mode_index]
+        predict_logics = post_method(input_embs, infer_state, self.pre_post_weight)
+
         g_cache_manager.cache_env_out()
         return predict_logics
 
+    @final
+    def _overlap_tpsp_token_forward(
+        self, input_ids, infer_state: InferStateInfo, input_ids1, infer_state1: InferStateInfo
+    ):
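+        """Run both micro-batches through the pre, transformer, and post layers
+        using their overlap_tpsp_token_forward implementations."""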
+        g_cache_manager.cache_env_in(
+            is_cuda_graph=infer_state.is_cuda_graph,
+            cur_batch_size=infer_state.batch_size,
+            cuda_graph_max_batch_size=self.graph_max_batch_size,
+        )
+        input_embs, input_embs1 = self.pre_infer.overlap_tpsp_token_forward(
+            input_ids, input_ids1, infer_state, infer_state1, self.pre_post_weight
+        )
+
+        for i in range(self.layers_num):
+            input_embs, input_embs1 = self.layers_infer[i].overlap_tpsp_token_forward(
+                input_embs, input_embs1, infer_state, infer_state1, self.trans_layers_weight[i]
+            )
+
+        predict_logics, predict_logics1 = self.post_infer.overlap_tpsp_token_forward(
+            input_embs, input_embs1, infer_state, infer_state1, self.pre_post_weight
+        )
+
+        g_cache_manager.cache_env_out()
+        return predict_logics, predict_logics1
+
     @final
     @torch.no_grad()
     def _check_max_len_infer(self):