@@ -33,8 +33,6 @@ def update_draft_token_mem_indexes(draft_token_memindex_map, run_reqs, mem_index
 class ContinuesBatchWithMTPBackend(ModeBackend):
     def __init__(self) -> None:
         super().__init__()
-        self.accepted_cnt = 0
-        self.all_cnt = 0

     # Support two models (main + MTP draft)
     def init_model(self, kvargs):
@@ -83,8 +81,6 @@ def init_model(self, kvargs):
         self.mtp_draft_token_memindex_map = torch.full(
             (max_req_num,), fill_value=IS_NONE, dtype=torch.int32, device="cpu"
         )
-        self.draft_accept_count = torch.zeros((max_req_num,), dtype=torch.int32, device="cpu")
-        self.main_step = 0

     def prefill(self, reqs: List[Tuple]):
         self._init_reqs(reqs, init_req_obj=False)
@@ -103,8 +99,6 @@ def decode(self):
                 prefill_reqs, is_chuncked_mode=False, is_multimodal=self.is_multimodal
             )
             model_output = self.model.forward(model_input)
-            self.main_step += 1
-            device0_print(f"main_step: {self.main_step}")

             self._overlap_req_init_and_filter(
                 uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True
@@ -134,9 +128,6 @@ def decode(self):
             model_output = self.model.forward(model_input)
             assert model_output.logits.shape[0] % 2 == 0

-            self.main_step += 1
-            device0_print(f"main_step: {self.main_step}")
-
             self._overlap_req_init_and_filter(
                 uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True
             )
@@ -165,7 +156,6 @@ def decode(self):
                 is_chuncked_mode=False,
                 do_filter_finished_reqs=False,
             )
-
             # spec decode: MTP
             draft_model_input = copy.deepcopy(model_input)
             draft_model_input.input_ids = torch.tensor(next_token_ids, dtype=torch.int64, device="cuda")
@@ -191,8 +181,6 @@ def verify(self, next_token_ids0, run_reqs):
             if self.draft_token_id_map[req.req_idx] == next_token_ids0[i]:
                 accepted_reqs.append(req)
                 accepted_index.append(i)
-                self.draft_accept_count[req.req_idx] += 1
-                device0_print(f"draft_accept_count: {self.draft_accept_count[req.req_idx]}")
                 self.main_draft_token_memindex_map[req.req_idx] = IS_NONE
             else:
                 need_free_mem_indexes.append(self.main_draft_token_memindex_map[req.req_idx])
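
The verify hunk above shows the core speculative-decoding acceptance check: a previously speculated draft token is kept only when the main model emits the same token id for that request; otherwise the KV-cache slot reserved for the draft token is queued for release. Below is a minimal, self-contained sketch of that pattern; DraftState, verify_draft_tokens, and the plain-dict bookkeeping are illustrative stand-ins, not the repository's actual API, which tracks the same state in per-request tensors such as draft_token_id_map and main_draft_token_memindex_map with IS_NONE as the empty sentinel.

# Illustrative sketch only -- the names below are hypothetical, not the repository's API.
from dataclasses import dataclass, field
from typing import Dict, List, Tuple

IS_NONE = -1  # sentinel meaning "no pending draft token", mirroring the diff above

@dataclass
class DraftState:
    draft_token_id: Dict[int, int] = field(default_factory=dict)   # req_idx -> speculated token id
    draft_mem_index: Dict[int, int] = field(default_factory=dict)  # req_idx -> KV-cache slot of that token

def verify_draft_tokens(
    state: DraftState, req_idxs: List[int], next_token_ids: List[int]
) -> Tuple[List[int], List[int]]:
    """Return (accepted req_idxs, KV-cache slots to free for rejected drafts)."""
    accepted, to_free = [], []
    for i, req_idx in enumerate(req_idxs):
        if state.draft_token_id.get(req_idx, IS_NONE) == next_token_ids[i]:
            # Main model agreed with the draft: the speculated token and its KV slot are kept.
            accepted.append(req_idx)
            state.draft_mem_index[req_idx] = IS_NONE
        else:
            # Mismatch: the speculated token is discarded and its KV slot must be released.
            slot = state.draft_mem_index.pop(req_idx, IS_NONE)
            if slot != IS_NONE:
                to_free.append(slot)
    return accepted, to_free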