@@ -1,6 +1,7 @@
 import os
 
 # os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
+import gc
 import copy
 import json
 import torch
@@ -391,6 +392,71 @@ def _decode(
 
         return model_output
 
+    def _build_prefill_model_input(
+        self, input_len: int, random_token: bool = False, include_special: bool = False
+    ) -> ModelInput:
+        dummy_input_ids = (
+            torch.randint(0, 10000, (input_len,), dtype=torch.int32, device="cuda")
+            if random_token
+            else torch.ones(input_len, dtype=torch.int32, device="cuda")
+        )
+        b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device="cuda")
+        mem_indexes = self.mem_manager.alloc(len(dummy_input_ids)).cuda()
+        b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda")
+        b_seq_len[:] = input_len
+        b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
+        total_token_num = input_len
+        b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda")
+
+        special_kwargs = {}
+        if include_special:
+            special_kwargs.update(self._gen_special_model_input(total_token_num))
+
+        model_input = ModelInput(
+            batch_size=1,
+            total_token_num=total_token_num,
+            max_len_in_batch=input_len,
+            input_ids=dummy_input_ids,
+            mem_indexes=mem_indexes,
+            b_req_idx=b_req_idx,
+            b_seq_len=b_seq_len,
+            b_mtp_index=b_mtp_index,
+            is_prefill=True,
+            b_ready_cache_len=b_ready_cache_len,
+            multimodal_params=[],
+            **special_kwargs,
+        )
+        return model_input
+
+    def _build_padded_prefill_hold_model_input(self, prefill_input_len: int, batch_size: int) -> ModelInput:
+        dummy_input_ids = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
+        b_req_idx = torch.tensor(
+            [self.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda"
+        )
+        mem_indexes = torch.tensor(
+            [self.mem_manager.HOLD_TOKEN_MEMINDEX for _ in range(batch_size)], dtype=torch.int32, device="cuda"
+        )
+        b_seq_len = torch.ones(batch_size, dtype=torch.int32, device="cuda")
+        b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
+        total_token_num = prefill_input_len * batch_size
+        b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
+
+        model_input = ModelInput(
+            batch_size=batch_size,
+            total_token_num=total_token_num,
+            max_len_in_batch=prefill_input_len,
+            input_ids=dummy_input_ids,
+            mem_indexes=mem_indexes,
+            b_req_idx=b_req_idx,
+            b_mtp_index=b_mtp_index,
+            b_seq_len=b_seq_len,
+            b_ready_cache_len=b_ready_cache_len,
+            is_prefill=True,
+            multimodal_params=[],
+            **self._gen_special_model_input(total_token_num),
+        )
+        return model_input
+
     @final
     def _context_forward(self, input_ids, infer_state: InferStateInfo):
         run_mode_index = 1 if self.enable_tpsp_mix_mode else 0
@@ -680,25 +746,8 @@ def _check_max_len_infer(self):
         # Simulate a prefill at the maximum length and check whether an OOM occurs
         try:
             logger.info("begin check max_len infer")
-            dummy_input_ids = torch.ones(self.batch_max_tokens, dtype=torch.int32, device="cuda")
-            b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device="cuda")
-            mem_indexes = self.mem_manager.alloc(len(dummy_input_ids)).cuda()
-            b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda")
-            b_seq_len[:] = self.batch_max_tokens
-            b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
-            total_token_num = self.batch_max_tokens
-            b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda")
-            model_input = ModelInput(
-                batch_size=1,
-                total_token_num=total_token_num,
-                max_len_in_batch=self.batch_max_tokens,
-                input_ids=dummy_input_ids,
-                mem_indexes=mem_indexes,
-                b_req_idx=b_req_idx,
-                b_seq_len=b_seq_len,
-                b_mtp_index=b_mtp_index,
-                is_prefill=True,
-                b_ready_cache_len=b_ready_cache_len,
+            model_input = self._build_prefill_model_input(
+                self.batch_max_tokens, random_token=False, include_special=False
             )
             model_output = self.forward(
                 model_input,
@@ -752,40 +801,21 @@ def _autotune_warmup(self):
         for input_len in warmup_lengths:
             try:
                 logger.info(f"autotune warmup for length {input_len}")
-                dummy_input_ids = torch.randint(0, 10000, (input_len,), dtype=torch.int32, device="cuda")
-                b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device="cuda")
-                mem_indexes = self.mem_manager.alloc(len(dummy_input_ids)).cuda()
-                b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda")
-                b_seq_len[:] = input_len
-                b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
-                total_token_num = input_len
-                b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda")
-                model_input = ModelInput(
-                    batch_size=1,
-                    total_token_num=total_token_num,
-                    max_len_in_batch=input_len,
-                    input_ids=dummy_input_ids,
-                    mem_indexes=mem_indexes,
-                    b_req_idx=b_req_idx,
-                    b_seq_len=b_seq_len,
-                    b_mtp_index=b_mtp_index,
-                    is_prefill=True,
-                    b_ready_cache_len=b_ready_cache_len,
-                    multimodal_params=[],
-                    **self._gen_special_model_input(total_token_num),
-                )
+                model_input = self._build_prefill_model_input(input_len, random_token=True, include_special=True)
                 model_output = self.forward(
                     model_input,
                 )
                 del model_output
                 self.req_manager.free_all()
                 self.mem_manager.free_all()
+                gc.collect()
                 torch.cuda.empty_cache()
                 logger.info(f"autotune warmup for length {input_len} ok")
             except Exception as e:
                 logger.warning(f"autotune warmup for length {input_len} failed: {str(e)}")
                 self.req_manager.free_all()
                 self.mem_manager.free_all()
+                gc.collect()
                 torch.cuda.empty_cache()
         self.layers_num = layer_num_bak
         torch.distributed.barrier()
@@ -803,39 +833,12 @@ def _init_padded_req(self):
         # prefill init padding req.
         prefill_input_len = 1
         batch_size = 1
-        dummy_input_ids = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
-        b_req_idx = torch.tensor(
-            [self.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda"
-        )
-        mem_indexes = torch.tensor(
-            [self.mem_manager.HOLD_TOKEN_MEMINDEX for _ in range(batch_size)], dtype=torch.int32, device="cuda"
-        )
-        b_seq_len = torch.ones(batch_size, dtype=torch.int32, device="cuda")
-        b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
-        total_token_num = prefill_input_len * batch_size
-        b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
-        model_input = ModelInput(
-            batch_size=batch_size,
-            total_token_num=total_token_num,
-            max_len_in_batch=prefill_input_len,
-            input_ids=dummy_input_ids,
-            mem_indexes=mem_indexes,
-            b_req_idx=b_req_idx,
-            b_mtp_index=b_mtp_index,
-            b_seq_len=b_seq_len,
-            b_ready_cache_len=b_ready_cache_len,
-            is_prefill=True,
-            multimodal_params=[],
-            **self._gen_special_model_input(total_token_num),
+        model_input = self._build_padded_prefill_hold_model_input(
+            prefill_input_len=prefill_input_len, batch_size=batch_size
         )
 
         model_output: ModelOutput = self.forward(model_input)
         del model_input
-        del dummy_input_ids
-        del b_req_idx
-        del mem_indexes
-        del b_seq_len
-        del b_ready_cache_len
         del model_output
         torch.cuda.empty_cache()
         return