@@ -33,13 +33,17 @@ def padded_prepare_prefill_inputs(
     b_seq_len = []
     batch_multimodal_params = []
     b_ready_cache_len = []
+    b_mtp_index = []
+    b_prefill_has_output = []
+
     for req in req_objs:

         run_reqs.append(req)
         batch_multimodal_params.append(req.multimodal_params)
         b_req_idx.append(req.req_idx)

         input_token_ids = req.get_chuncked_input_token_ids()
+        b_prefill_has_output.append(False if len(input_token_ids) < req.get_cur_total_len() else True)
         seq_len = len(input_token_ids)
         input_token_len = seq_len - req.cur_kv_len
         input_id = input_token_ids[req.cur_kv_len :]
@@ -49,27 +53,32 @@ def padded_prepare_prefill_inputs(
         total_token_num += seq_len
         max_len_in_batch = max(max_len_in_batch, input_token_len)
         b_ready_cache_len.append(req.cur_kv_len)
+        b_mtp_index.append(0)

     # padding fake req for prefill
     for _ in range(padded_req_num):
         input_ids.append([1])
         b_req_idx.append(g_infer_context.req_manager.HOLD_REQUEST_ID)
         b_seq_len.append(1)
+        b_mtp_index.append(0)
+        b_prefill_has_output.append(False)
         b_ready_cache_len.append(0)
         total_token_num += 1
         max_len_in_batch = max(max_len_in_batch, 1)

     input_ids = np.concatenate(input_ids, dtype=np.int64)
-    input_ids = torch.tensor(input_ids, dtype=torch.int64, device="cuda")
-    b_req_idx = torch.tensor(b_req_idx, dtype=torch.int32, device="cuda")
-    b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cuda")
-    b_ready_cache_len = torch.tensor(b_ready_cache_len, dtype=torch.int32, device="cuda")
+
+    input_ids = torch.tensor(input_ids, dtype=torch.int64, device="cpu")
+    b_req_idx = torch.tensor(b_req_idx, dtype=torch.int32, device="cpu")
+    b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cpu")
+    b_mtp_index = torch.tensor(b_mtp_index, dtype=torch.int32, device="cpu")
+    b_ready_cache_len = torch.tensor(b_ready_cache_len, dtype=torch.int32, device="cpu")

     # prepare tokens for the dynamic prompt cache
     g_infer_state_lock.acquire()
     if g_infer_context.radix_cache is not None:
         g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0] - padded_req_num)
-    mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0] - padded_req_num).cuda()
+    mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0] - padded_req_num)
     g_infer_state_lock.release()

     if padded_req_num > 0:
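
Note on the new b_prefill_has_output flag built above: with chunked prefill, get_chuncked_input_token_ids() can return only a prefix of the request's tokens, and presumably only the chunk that completes the prompt should yield an output logit, which is what the length comparison encodes. A minimal sketch of that predicate (illustrative only, not code from this patch):

    def chunk_has_output(chunk_len: int, cur_total_len: int) -> bool:
        # earlier chunks cover a strict prefix of the request's tokens, so nothing is sampled for them;
        # the final chunk reaches the full current length and therefore produces an output
        return chunk_len >= cur_total_len
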
@@ -85,11 +94,13 @@ def padded_prepare_prefill_inputs(
         total_token_num=total_token_num,
         max_len_in_batch=max_len_in_batch,
         input_ids=input_ids,
-        mem_indexes=mem_indexes,
+        mem_indexes_cpu=mem_indexes,
         b_req_idx=b_req_idx,
+        b_mtp_index=b_mtp_index,
         b_seq_len=b_seq_len,
         b_ready_cache_len=b_ready_cache_len,
         is_prefill=True,
+        b_prefill_has_output_cpu=b_prefill_has_output,
     )
     if is_multimodal:
         model_input.multimodal_params = batch_multimodal_params
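
All of these batch tensors are now created on the CPU and the allocation result is passed through mem_indexes_cpu, so presumably the model runner performs the host-to-device copies in one place instead of each prepare function calling .cuda() itself. A rough sketch of that consumer-side pattern, using a hypothetical to_cuda helper that is not part of this patch:

    def to_cuda(model_input):
        # hypothetical helper: batch the host-to-device copies once per ModelInput
        for name in ("input_ids", "b_req_idx", "b_seq_len", "b_mtp_index", "b_ready_cache_len"):
            t = getattr(model_input, name, None)
            if t is not None:
                setattr(model_input, name, t.cuda(non_blocking=True))
        model_input.mem_indexes = model_input.mem_indexes_cpu.cuda(non_blocking=True)
        return model_input
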
@@ -98,64 +109,62 @@ def padded_prepare_prefill_inputs(


 def padded_prepare_decode_inputs(
-    req_objs: List[InferReq], dest_batch_size: Optional[int] = None, is_multimodal=False
+    req_objs: List[InferReq], dest_batch_size: Optional[int] = None
 ) -> Tuple[ModelInput, List[InferReq], int]:
+    mtp_step_num = get_env_start_args().mtp_step
     run_reqs = []
     total_token_num = 0
     max_len_in_batch = 0
-    input_ids = []
     b_req_idx = []
+    b_mtp_index = []
     b_seq_len = []
-
     for req in req_objs:
         run_reqs.append(req)
         b_req_idx.append(req.req_idx)
-        input_token_ids = req.get_input_token_ids()
-        input_id = input_token_ids[-1]
-        seq_len = len(input_token_ids)
+        seq_len = req.get_cur_total_len()
         assert req.cur_kv_len == seq_len - 1
         b_seq_len.append(seq_len)
-        input_ids.append(input_id)
         total_token_num += seq_len
         max_len_in_batch = max(max_len_in_batch, seq_len)
+        b_mtp_index.append(0)
         # process the draft tokens.
-        for step in range(len(req.mtp_gen_token_ids)):
+        for step in range(req.mtp_step):
             run_reqs.append(req)
             b_req_idx.append(req.req_idx)
             seq_len += 1
             b_seq_len.append(seq_len)
-            input_ids.append(req.mtp_gen_token_ids[step])
             total_token_num += seq_len
             max_len_in_batch = max(max_len_in_batch, seq_len)
+            b_mtp_index.append(step + 1)

     if dest_batch_size is None:
         if len(run_reqs) == 0:
             dest_batch_size = 1
         else:
-            dest_batch_size = len(run_reqs)
+            dest_batch_size = len(run_reqs) * (1 + mtp_step_num)
     else:
-        assert len(run_reqs) <= dest_batch_size
+        assert len(run_reqs) * (1 + mtp_step_num) <= dest_batch_size

-    padded_req_num = dest_batch_size - len(run_reqs)
+    padded_req_num = dest_batch_size - len(run_reqs) * (1 + mtp_step_num)

     # padding fake req for decode
     for _ in range(padded_req_num):
-        input_ids.append(1)
         seq_len = 2
         b_req_idx.append(g_infer_context.req_manager.HOLD_REQUEST_ID)
         b_seq_len.append(seq_len)
+        b_mtp_index.append(0)
         total_token_num += seq_len
         max_len_in_batch = max(max_len_in_batch, seq_len)

-    input_ids = torch.tensor(input_ids, dtype=torch.int64, device="cuda")
-    b_req_idx = torch.tensor(b_req_idx, dtype=torch.int32, device="cuda")
-    b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cuda")
+    b_req_idx = torch.tensor(b_req_idx, dtype=torch.int32, device="cpu")
+    b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cpu")
+    b_mtp_index = torch.tensor(b_mtp_index, dtype=torch.int32, device="cpu")

     # prepare tokens for the dynamic prompt cache
     g_infer_state_lock.acquire()
     if g_infer_context.radix_cache is not None:
-        g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0] - padded_req_num)
-    mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0] - padded_req_num).cuda()
+        g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(b_seq_len.shape[0] - padded_req_num)
+    mem_indexes = g_infer_context.req_manager.mem_manager.alloc(b_seq_len.shape[0] - padded_req_num)
     g_infer_state_lock.release()

     if padded_req_num > 0:
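
For the decode layout above, each request now contributes one row for its verified token (b_mtp_index 0) plus one row per MTP draft step (b_mtp_index 1..mtp_step), with b_seq_len growing by one for each extra row. A small illustration of the rows emitted for a single request (assumed values, not code from this patch):

    def decode_rows(cur_total_len: int, mtp_step: int):
        # (b_seq_len, b_mtp_index) pairs for one request
        return [(cur_total_len + i, i) for i in range(mtp_step + 1)]

    # a request with cur_total_len=10 and mtp_step=2 yields three rows
    assert decode_rows(10, 2) == [(10, 0), (11, 1), (12, 2)]
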
@@ -170,36 +179,35 @@ def padded_prepare_decode_inputs(
         batch_size=b_seq_len.shape[0],
         total_token_num=total_token_num,
         max_len_in_batch=max_len_in_batch,
-        input_ids=input_ids,
-        mem_indexes=mem_indexes,
+        input_ids=None,
+        mem_indexes_cpu=mem_indexes,
         b_req_idx=b_req_idx,
+        b_mtp_index=b_mtp_index,
         b_seq_len=b_seq_len,
         is_prefill=False,
     )
     return model_input, run_reqs, padded_req_num


-def padded_overlap_prepare_decode_inputs(req_objs: List[InferReq], is_multimodal=False):
+def padded_overlap_prepare_decode_inputs(req_objs: List[InferReq]):
     split_req_bound = triton.cdiv(len(req_objs), 2)
     req_objs_0 = req_objs[0:split_req_bound]
     req_objs_1 = req_objs[split_req_bound:]

     enable_mtp = get_env_start_args().mtp_mode is not None
     if enable_mtp:
         micro_batch_size = max(
-            sum([len(req.mtp_gen_token_ids) + 1 for req in req_objs_0]),
-            sum([len(req.mtp_gen_token_ids) + 1 for req in req_objs_1]),
+            sum([req.mtp_step + 1 for req in req_objs_0]),
+            sum([req.mtp_step + 1 for req in req_objs_1]),
         )
     else:
         micro_batch_size = triton.cdiv(len(req_objs), 2)

     micro_batch_size = max(1, micro_batch_size)

-    micro_input, run_reqs, padded_req_num = padded_prepare_decode_inputs(
-        req_objs_0, dest_batch_size=micro_batch_size, is_multimodal=is_multimodal
-    )
+    micro_input, run_reqs, padded_req_num = padded_prepare_decode_inputs(req_objs_0, dest_batch_size=micro_batch_size)
     micro_input1, run_reqs1, padded_req_num1 = padded_prepare_decode_inputs(
-        req_objs_1, dest_batch_size=micro_batch_size, is_multimodal=is_multimodal
+        req_objs_1, dest_batch_size=micro_batch_size
     )
     return micro_input, run_reqs, padded_req_num, micro_input1, run_reqs1, padded_req_num1

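Sizing note for the overlap split above: with MTP enabled the two micro batches are measured in decode rows (one verified token plus mtp_step drafts per request) rather than in requests, and both halves are padded up to the larger row count so their shapes stay identical. A worked example with assumed per-request mtp_step values:

    steps_0 = [1, 2]  # micro batch 0: two requests
    steps_1 = [2]     # micro batch 1: one request
    micro_batch_size = max(sum(s + 1 for s in steps_0),  # 2 + 3 = 5 rows
                           sum(s + 1 for s in steps_1))  # 3 rows
    assert micro_batch_size == 5  # both micro batches are padded to 5 decode rows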