@@ -81,7 +81,7 @@ def __init__(self, kvargs):
         self.tp_world_size_ = get_dp_world_size()
         self.enable_tpsp_mix_mode = get_env_start_args().enable_tpsp_mix_mode
 
-        self.is_deepseekv3_mtp_mode = self.args.mtp_mode == "deepseekv3"
+        self.is_deepseekv3_mtp_mode = self.args.mtp_mode in ["deepseekv3_vanilla", "deepseekv3_eagle"]
 
         self._init_datatype()
         self._init_config()
@@ -262,10 +262,8 @@ def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0)
         infer_state.b_req_idx = model_input.b_req_idx
         infer_state.b_seq_len = model_input.b_seq_len
         if model_input.is_prefill:
-            if model_input.b_ready_cache_len is not None:
-                infer_state.b_ready_cache_len = model_input.b_ready_cache_len
-            else:
-                infer_state.b_ready_cache_len = torch.zeros_like(input=infer_state.b_seq_len)
+            assert model_input.b_ready_cache_len is not None
+            infer_state.b_ready_cache_len = model_input.b_ready_cache_len
 
         infer_state.multimodal_params = model_input.multimodal_params
 
@@ -337,14 +335,14 @@ def _prefill(
         infer_state = self._create_inferstate(model_input)
         init_req_to_token_indexes(
             self.req_manager.req_to_token_indexs,
-            model_input.b_req_idx,
-            model_input.b_seq_len,
-            infer_state.b_ready_cache_len,
+            model_input.b_req_idx_cpu,
+            model_input.b_seq_len_cpu,
+            model_input.b_ready_cache_len_cpu,
             model_input.max_len_in_batch,
             infer_state.mem_index,
         )
 
-        infer_state.init_some_extra_state(self, model_input.input_ids)
+        infer_state.init_some_extra_state(self, model_input)
         return self._context_forward(model_input.input_ids, infer_state)
 
     def _decode(
@@ -369,7 +367,7 @@ def _decode(
                 infer_state.b_seq_len,
                 infer_state.mem_index,
             )
-            infer_state.init_some_extra_state(self, padded_model_input.input_ids)
+            infer_state.init_some_extra_state(self, padded_model_input)
 
             if self.graph.need_capture(find_graph_batch_size):
                 infer_state.is_cuda_graph = True
@@ -390,7 +388,7 @@ def _decode(
                 infer_state.b_seq_len,
                 infer_state.mem_index,
             )
-            infer_state.init_some_extra_state(self, model_input.input_ids)
+            infer_state.init_some_extra_state(self, model_input)
             model_output = self._token_forward(model_input.input_ids, infer_state)
 
         return model_output
@@ -540,15 +538,15 @@ def microbatch_overlap_decode(self, model_input0: ModelInput, model_input1: Mode
                 infer_state0.b_seq_len,
                 infer_state0.mem_index,
             )
-            infer_state0.init_some_extra_state(self, padded_model_input0.input_ids)
+            infer_state0.init_some_extra_state(self, padded_model_input0)
             infer_state1 = self._create_inferstate(padded_model_input1, 1)
             copy_kv_index_to_req(
                 self.req_manager.req_to_token_indexs,
                 infer_state1.b_req_idx,
                 infer_state1.b_seq_len,
                 infer_state1.mem_index,
             )
-            infer_state1.init_some_extra_state(self, padded_model_input1.input_ids)
+            infer_state1.init_some_extra_state(self, padded_model_input1)
 
             if self.graph.need_capture(find_graph_batch_size):
                 infer_state0.is_cuda_graph = True
@@ -684,25 +682,25 @@ def _check_max_len_infer(self):
         # Run a prefill at the maximum length to check whether it triggers OOM
         try:
             logger.info("begin check max_len infer")
-            dummy_input_ids = torch.ones(self.batch_max_tokens, dtype=torch.int32, device="cuda")
-            b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device="cuda")
-            mem_indexes = self.mem_manager.alloc(len(dummy_input_ids)).cuda()
-            b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda")
+            dummy_input_ids = torch.ones(self.batch_max_tokens, dtype=torch.int32, device="cpu")
+            b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device="cpu")
+            mem_indexes = self.mem_manager.alloc(len(dummy_input_ids))
+            b_seq_len = torch.ones(1, dtype=torch.int32, device="cpu")
             b_seq_len[:] = self.batch_max_tokens
-            b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
+            b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cpu")
             total_token_num = self.batch_max_tokens
-            b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda")
+            b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cpu")
             model_input = ModelInput(
                 batch_size=1,
                 total_token_num=total_token_num,
                 max_len_in_batch=self.batch_max_tokens,
-                input_ids=dummy_input_ids,
-                mem_indexes=mem_indexes,
-                b_req_idx=b_req_idx,
-                b_seq_len=b_seq_len,
-                b_mtp_index=b_mtp_index,
+                input_ids_cpu=dummy_input_ids,
+                mem_indexes_cpu=mem_indexes,
+                b_req_idx_cpu=b_req_idx,
+                b_seq_len_cpu=b_seq_len,
+                b_mtp_index_cpu=b_mtp_index,
                 is_prefill=True,
-                b_ready_cache_len=b_ready_cache_len,
+                b_ready_cache_len_cpu=b_ready_cache_len,
             )
             model_output = self.forward(
                 model_input,
@@ -750,29 +748,29 @@ def _autotune_warmup(self):
         self.layers_num = self.autotune_layers()
         for input_len in tqdm(warmup_lengths, desc="warming up"):
             try:
-                rand_gen = torch.Generator(device="cuda")
+                rand_gen = torch.Generator(device="cpu")
                 rand_gen.manual_seed(input_len)
                 dummy_input_ids = torch.randint(
-                    0, 10000, (input_len,), dtype=torch.int32, device="cuda", generator=rand_gen
+                    0, 10000, (input_len,), dtype=torch.int32, device="cpu", generator=rand_gen
                 )
-                b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device="cuda")
-                mem_indexes = self.mem_manager.alloc(len(dummy_input_ids)).cuda()
-                b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda")
+                b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device="cpu")
+                mem_indexes = self.mem_manager.alloc(len(dummy_input_ids))
+                b_seq_len = torch.ones(1, dtype=torch.int32, device="cpu")
                 b_seq_len[:] = input_len
-                b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda")
+                b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cpu")
                 total_token_num = input_len
-                b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda")
+                b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cpu")
                 model_input = ModelInput(
                     batch_size=1,
                     total_token_num=total_token_num,
                     max_len_in_batch=input_len,
-                    input_ids=dummy_input_ids,
-                    mem_indexes=mem_indexes,
-                    b_req_idx=b_req_idx,
-                    b_seq_len=b_seq_len,
-                    b_mtp_index=b_mtp_index,
+                    input_ids_cpu=dummy_input_ids,
+                    mem_indexes_cpu=mem_indexes,
+                    b_req_idx_cpu=b_req_idx,
+                    b_seq_len_cpu=b_seq_len,
+                    b_mtp_index_cpu=b_mtp_index,
                     is_prefill=True,
-                    b_ready_cache_len=b_ready_cache_len,
+                    b_ready_cache_len_cpu=b_ready_cache_len,
                     multimodal_params=[],
                     **self._gen_special_model_input(total_token_num),
                 )
@@ -807,27 +805,27 @@ def _init_padded_req(self):
         # prefill init padding req.
         prefill_input_len = 1
         batch_size = 1
-        dummy_input_ids = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
+        dummy_input_ids = torch.ones((batch_size,), dtype=torch.int32, device="cpu")
         b_req_idx = torch.tensor(
-            [self.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda"
+            [self.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cpu"
         )
         mem_indexes = torch.tensor(
-            [self.mem_manager.HOLD_TOKEN_MEMINDEX for _ in range(batch_size)], dtype=torch.int32, device="cuda"
+            [self.mem_manager.HOLD_TOKEN_MEMINDEX for _ in range(batch_size)], dtype=torch.int32, device="cpu"
         )
-        b_seq_len = torch.ones(batch_size, dtype=torch.int32, device="cuda")
-        b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
+        b_seq_len = torch.ones(batch_size, dtype=torch.int32, device="cpu")
+        b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cpu")
         total_token_num = prefill_input_len * batch_size
-        b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
+        b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cpu")
         model_input = ModelInput(
             batch_size=batch_size,
             total_token_num=total_token_num,
             max_len_in_batch=prefill_input_len,
-            input_ids=dummy_input_ids,
-            mem_indexes=mem_indexes,
-            b_req_idx=b_req_idx,
-            b_mtp_index=b_mtp_index,
-            b_seq_len=b_seq_len,
-            b_ready_cache_len=b_ready_cache_len,
+            input_ids_cpu=dummy_input_ids,
+            mem_indexes_cpu=mem_indexes,
+            b_req_idx_cpu=b_req_idx,
+            b_mtp_index_cpu=b_mtp_index,
+            b_seq_len_cpu=b_seq_len,
+            b_ready_cache_len_cpu=b_ready_cache_len,
             is_prefill=True,
             multimodal_params=[],
             **self._gen_special_model_input(total_token_num),
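
Taken together, these hunks move the dummy-input construction in `_check_max_len_infer`, `_autotune_warmup`, and `_init_padded_req` to CPU tensors passed through the new `*_cpu` keyword arguments of `ModelInput`, make `b_ready_cache_len` mandatory for prefill (the new assert in `_create_inferstate`), and change `init_some_extra_state` to take the whole `ModelInput` instead of bare `input_ids`. The sketch below is a minimal illustration of that calling convention inferred from this diff; the helper name `build_dummy_prefill_input` is hypothetical, and `ModelInput` is passed in explicitly because its import path is not shown here.

```python
import torch


def build_dummy_prefill_input(model, model_input_cls, input_len: int):
    """Hypothetical helper: build a single-request prefill ModelInput the way the
    warmup/check paths above do. `model` is assumed to expose req_manager,
    mem_manager, and _gen_special_model_input, mirroring the diff."""
    # Per-request metadata now lives on the CPU; no explicit .cuda() copies remain.
    dummy_input_ids = torch.ones(input_len, dtype=torch.int32, device="cpu")
    b_req_idx = torch.tensor([model.req_manager.alloc()], dtype=torch.int32, device="cpu")
    mem_indexes = model.mem_manager.alloc(len(dummy_input_ids))  # kept on its native device
    b_seq_len = torch.full((1,), input_len, dtype=torch.int32, device="cpu")
    # Prefill now asserts that b_ready_cache_len is provided, so it is always built here.
    b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cpu")
    b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cpu")

    return model_input_cls(
        batch_size=1,
        total_token_num=input_len,
        max_len_in_batch=input_len,
        input_ids_cpu=dummy_input_ids,  # *_cpu keyword names introduced by this change
        mem_indexes_cpu=mem_indexes,
        b_req_idx_cpu=b_req_idx,
        b_seq_len_cpu=b_seq_len,
        b_mtp_index_cpu=b_mtp_index,
        b_ready_cache_len_cpu=b_ready_cache_len,
        is_prefill=True,
        multimodal_params=[],
        **model._gen_special_model_input(input_len),
    )
```

A caller would then hand the result to `model.forward(model_input, ...)` as in `_check_max_len_infer`; downstream, `_prefill` reads the `*_cpu` fields for `init_req_to_token_indexes` and passes the whole object to `init_some_extra_state`.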