77
88from lightllm .common .basemodel .layer_weights .hf_load_utils import load_hf_weights
99from lightllm .common .basemodel .infer_struct import InferStateInfo
10- from lightllm .common .basemodel .splitfuse_infer_struct import SplitFuseInferStateInfo
1110from lightllm .common .mem_manager import MemoryManager
1211from lightllm .common .req_manager import ReqManager
1312from lightllm .common .infer_utils import init_req_to_token_indexes
1413from lightllm .common .build_utils import repair_config
1514from lightllm .common .basemodel .triton_kernel .copy_kv_index_to_req import copy_kv_index_to_req
16- from lightllm .common .basemodel .triton_kernel .splitfuse_copy_kv_index_to_req import splitfuse_copy_kv_index_to_req
1715from lightllm .common .basemodel .layer_infer .cache_tensor_manager import g_cache_manager
1816from lightllm .common .basemodel .cuda_graph import CudaGraph
1917from lightllm .common .quantization import Quantcfg
@@ -36,7 +34,6 @@ class TpPartBaseModel:
3634
3735 # infer state class
3836 infer_state_class = InferStateInfo
39- splitfuse_infer_state_class = SplitFuseInferStateInfo
4037
4138 def __init__ (self , kvargs ):
4239 self .run_mode = kvargs ["run_mode" ]
@@ -57,6 +54,8 @@ def __init__(self, kvargs):
5754 self .return_all_prompt_logics = kvargs .get ("return_all_prompt_logics" , False )
5855 assert not (self .is_token_healing and self .return_all_prompt_logics ), "can not be true in same time"
5956 self .use_dynamic_prompt_cache = kvargs .get ("use_dynamic_prompt_cache" , False )
57+ enable_chunked_prefill = kvargs .get ("enable_chunked_prefill" , False ) # chunked prefill is default on.
58+ self .use_dynamic_prompt_cache = self .use_dynamic_prompt_cache or enable_chunked_prefill
6059 self .data_type = kvargs .get ("data_type" , "float16" )
6160 self .graph_max_batch_size = kvargs .get ("graph_max_batch_size" , 16 )
6261 self .graph_max_len_in_batch = kvargs .get ("graph_max_len_in_batch" , 8192 )
@@ -368,81 +367,6 @@ def _decode(
368367 predict_logics = self ._token_forward (input_ids , infer_state )
369368 return predict_logics
370369
    @torch.no_grad()
    def splitfuse_forward(
        self,
        input_ids,
        mem_indexes,
        decode_req_num,
        decode_total_token_num,
        decode_b_req_idx: torch.Tensor,
        decode_b_start_loc: torch.Tensor,
        decode_b_seq_len: torch.Tensor,
        decode_max_len_in_batch,
        prefill_req_num,
        prefill_b_req_idx: torch.Tensor,
        prefill_b_split_start_loc: torch.Tensor,
        prefill_b_split_ready_cache_len: torch.Tensor,
        prefill_max_split_seq_len_in_batch,
        prefill_b_seq_len: torch.Tensor,
    ):
        """Run one split-fuse step that fuses a decode sub-batch and a chunked-prefill
        sub-batch into a single forward pass.

        The layout convention visible below: the first ``decode_req_num`` entries of
        ``mem_indexes`` belong to the decode requests, the remaining entries belong to
        the split prefill requests.  Returns the logits from ``_splitfuse_forward``.

        NOTE(review): tensor shapes/dtypes of the ``*_b_*`` arguments are assumed to be
        1-D per-request metadata (one entry per request) — confirm against callers.
        """
        # Fresh fused-inference state carrying both sub-batches.
        infer_state = self.splitfuse_infer_state_class()
        infer_state.use_dynamic_prompt_cache = self.use_dynamic_prompt_cache
        infer_state.batch_size = decode_req_num + prefill_req_num

        # Decode-side bookkeeping.
        infer_state.decode_req_num = decode_req_num
        infer_state.decode_total_token_num = decode_total_token_num
        infer_state.decode_b_req_idx = decode_b_req_idx
        infer_state.decode_b_start_loc = decode_b_start_loc
        infer_state.decode_b_seq_len = decode_b_seq_len
        infer_state.decode_max_len_in_batch = decode_max_len_in_batch

        # Prefill-side bookkeeping (split/chunked prefill).
        infer_state.prefill_req_num = prefill_req_num
        infer_state.prefill_b_req_idx = prefill_b_req_idx
        infer_state.prefill_b_split_start_loc = prefill_b_split_start_loc
        infer_state.prefill_b_split_ready_cache_len = prefill_b_split_ready_cache_len
        infer_state.prefill_max_split_seq_len_in_batch = prefill_max_split_seq_len_in_batch
        infer_state.prefill_b_seq_len = prefill_b_seq_len

        infer_state.mem_manager = self.mem_manager
        infer_state.req_manager = self.req_manager

        # One KV slot per input token; slots are pre-allocated (mem_indexes) and not
        # contiguous, so layers scatter into kv_buffer and copy out via mem_index.
        alloc_size = len(input_ids)
        infer_state.mem_is_contiguous = False
        infer_state.mem_index = mem_indexes
        infer_state.kv_buffer = torch.empty(
            (alloc_size, self.tp_k_head_num_ + self.tp_v_head_num_, self.head_dim_),
            dtype=self.data_type,
            device="cuda",
        )

        # decode part: record each decode token's KV slot in the req->token table.
        if decode_req_num != 0:
            copy_kv_index_to_req(
                self.req_manager.req_to_token_indexs,
                decode_b_req_idx,
                decode_b_seq_len,
                infer_state.mem_index[0:decode_req_num],
            )

        # split prefill part: record the new chunk's KV slots, offset by the tokens
        # already cached for each request (prefill_b_split_ready_cache_len).
        if prefill_req_num != 0:
            splitfuse_copy_kv_index_to_req(
                self.req_manager.req_to_token_indexs,
                prefill_b_req_idx,
                prefill_b_split_ready_cache_len,
                prefill_b_seq_len,
                infer_state.mem_index[decode_req_num:],
            )

        infer_state.init_some_extra_state(self, input_ids)
        infer_state.create_inner_decode_infer_status()
        infer_state.create_inner_prefill_infer_status()
        predict_logics = self._splitfuse_forward(input_ids, infer_state)
        return predict_logics
445-
446370 @final
447371 def _context_forward (self , input_ids , infer_state : InferStateInfo ):
448372 g_cache_manager .cache_env_in ()
@@ -469,17 +393,6 @@ def _token_forward(self, input_ids, infer_state: InferStateInfo):
469393 g_cache_manager .cache_env_out ()
470394 return predict_logics
471395
472- @final
473- def _splitfuse_forward (self , input_ids , infer_state : SplitFuseInferStateInfo ):
474- g_cache_manager .cache_env_in ()
475- cuda_input_ids = input_ids
476- input_embs = self .pre_infer .splitfuse_forward (cuda_input_ids , infer_state , self .pre_post_weight )
477- for i in range (0 , self .layers_num ):
478- input_embs = self .layers_infer [i ].splitfuse_forward (input_embs , infer_state , self .trans_layers_weight [i ])
479- predict_logics = self .post_infer .splitfuse_forward (input_embs , infer_state , self .pre_post_weight )
480- g_cache_manager .cache_env_out ()
481- return predict_logics
482-
483396 @final
484397 @torch .no_grad ()
485398 def _check_max_len_infer (self ):
0 commit comments