@@ -66,6 +66,7 @@ def overlap_prefill(
     input_ids,
     mem_indexes,
     b_req_idx,
+    b_mtp_index,
     b_seq_len,
     total_token_num,
     b_ready_cache_len,
@@ -76,16 +77,18 @@ def overlap_prefill(
     _0_input_ids = input_ids[: total_token_num // 2]
     _0_mem_indexes = mem_indexes[: total_token_num // 2]
     _0_b_req_idx = b_req_idx[: batch_size // 2]
+    _0_b_mtp_index = b_mtp_index[: batch_size // 2]
     _0_b_seq_len = b_seq_len[: batch_size // 2]
     _o_b_ready_cache_len = b_ready_cache_len[: batch_size // 2]
     micro_batch1 = ModelInput(
         _0_batch_size,
         _0_total_token_num,
         _0_max_len_in_batch,
         _0_input_ids,
-        _0_mem_indexes,
         _0_b_req_idx,
+        _0_b_mtp_index,
         _0_b_seq_len,
+        _0_mem_indexes,
         True,
         _o_b_ready_cache_len,
         {},
@@ -97,6 +100,7 @@ def overlap_prefill(
     _1_input_ids = input_ids[total_token_num // 2 :]
     _1_mem_indexes = mem_indexes[total_token_num // 2 :]
     _1_b_req_idx = b_req_idx[batch_size // 2 :]
+    _1_b_mtp_index = b_mtp_index[batch_size // 2 :]
     _1_b_seq_len = b_seq_len[batch_size // 2 :]
     _1_b_ready_cache_len = b_ready_cache_len[batch_size // 2 :]
 
@@ -105,9 +109,10 @@ def overlap_prefill(
         _1_total_token_num,
         _1_max_len_in_batch,
         _1_input_ids,
-        _1_mem_indexes,
         _1_b_req_idx,
+        _1_b_mtp_index,
         _1_b_seq_len,
+        _1_mem_indexes,
         True,
         _1_b_ready_cache_len,
         {},
@@ -120,23 +125,25 @@ def overlap_prefill(
 
 
 def overlap_decode(
-    model_part, batch_size, max_len_in_batch, input_ids, mem_indexes, b_req_idx, b_seq_len, total_token_num
+    model_part, batch_size, max_len_in_batch, input_ids, mem_indexes, b_req_idx, b_mtp_index, b_seq_len, total_token_num
 ):
     _0_batch_size = batch_size // 2
     _0_total_token_num = total_token_num // 2
     _0_max_len_in_batch = max_len_in_batch
     _0_input_ids = input_ids[: batch_size // 2]
     _0_mem_indexes = mem_indexes[: batch_size // 2]
     _0_b_req_idx = b_req_idx[: batch_size // 2]
+    _0_b_mtp_index = b_mtp_index[: batch_size // 2]
     _0_b_seq_len = b_seq_len[: batch_size // 2]
     micro_batch1 = ModelInput(
         _0_batch_size,
         _0_total_token_num,
         _0_max_len_in_batch,
         _0_input_ids,
-        _0_mem_indexes,
         _0_b_req_idx,
+        _0_b_mtp_index,
         _0_b_seq_len,
+        _0_mem_indexes,
     )
 
     _1_batch_size = batch_size - batch_size // 2
@@ -145,16 +152,18 @@ def overlap_decode(
     _1_input_ids = input_ids[batch_size // 2 :]
     _1_mem_indexes = mem_indexes[batch_size // 2 :]
     _1_b_req_idx = b_req_idx[batch_size // 2 :]
+    _1_b_mtp_index = b_mtp_index[batch_size // 2 :]
     _1_b_seq_len = b_seq_len[batch_size // 2 :]
 
     micro_batch2 = ModelInput(
         _1_batch_size,
         _1_total_token_num,
         _1_max_len_in_batch,
         _1_input_ids,
-        _1_mem_indexes,
         _1_b_req_idx,
+        _1_b_mtp_index,
         _1_b_seq_len,
+        _1_mem_indexes,
     )
 
     output, output1 = model_part.microbatch_overlap_decode(micro_batch1, micro_batch2)
@@ -170,6 +179,7 @@ def prefill(
     input_ids,
     mem_indexes,
     b_req_idx,
+    b_mtp_index,
     b_seq_len,
     total_token_num,
     b_ready_cache_len,
@@ -179,25 +189,30 @@ def prefill(
         total_token_num,
         max_len_in_batch,
         input_ids,
-        mem_indexes,
         b_req_idx,
+        b_mtp_index,
         b_seq_len,
+        mem_indexes,
         is_prefill=True,
         b_ready_cache_len=b_ready_cache_len,
     )
+
     model_output = model_part.forward(model_input)
     return model_output.logits
 
 
-def decode(model_part, batch_size, max_len_in_batch, input_ids, mem_indexes, b_req_idx, b_seq_len, total_token_num):
+def decode(
+    model_part, batch_size, max_len_in_batch, input_ids, mem_indexes, b_req_idx, b_mtp_index, b_seq_len, total_token_num
+):
     model_input = ModelInput(
         batch_size,
         total_token_num,
         max_len_in_batch,
         input_ids,
-        mem_indexes,
         b_req_idx,
+        b_mtp_index,
         b_seq_len,
+        mem_indexes,
         is_prefill=False,
     )
     model_output = model_part.forward(model_input)
@@ -236,6 +251,7 @@ def run_forward_once(
     b_req_idx = torch.tensor(
         [model_part.req_manager.alloc() for _ in range(batch_size)], dtype=torch.int32, device="cuda"
     )
+    b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
     b_seq_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
     b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
     for i in range(batch_size):
@@ -260,6 +276,7 @@ def run_forward_once(
             test_data,
             mem_indexes,
             b_req_idx,
+            b_mtp_index,
             b_seq_len,
             total_token_num,
             b_ready_cache_len,  # b_ready_cache_len
@@ -288,6 +305,7 @@ def run_forward_once(
             test_data,
             mem_indexes,
             b_req_idx,
+            b_mtp_index,
             b_seq_len,
             total_token_num,
             b_ready_cache_len,  # b_ready_cache_len
@@ -302,6 +320,7 @@ def run_forward_once(
         torch.cuda.synchronize()
         step_start = time.time()
         total_token_num += batch_size
+        b_mtp_index += 1
         b_seq_len += 1
         mem_indexes = model_part.req_manager.mem_manager.alloc(predict_ids.shape[0]).cuda()
         max_len_in_batch = input_len + i + 1
@@ -312,6 +331,7 @@ def run_forward_once(
             predict_ids.view(-1),
             mem_indexes,
             b_req_idx,
+            b_mtp_index,
             b_seq_len,
             total_token_num,
         )
@@ -325,6 +345,7 @@ def run_forward_once(
                 predict_ids.view(-1),
                 mem_indexes,
                 b_req_idx,
+                b_mtp_index,
                 b_seq_len,
                 total_token_num,
             ),