@@ -235,8 +235,8 @@ def prefill_overlap(self, event_pack: OverlapEventPack, prefill_reqs: List[Infer
235235 b_has_out_cpu = (
236236 micro_input0 .b_prefill_has_output_cpu [0 :req_num0 ] + micro_input1 .b_prefill_has_output_cpu [0 :req_num1 ]
237237 )
238- b_mtp_index = torch .cat (micro_input0 .b_mtp_index [0 :req_num0 ], micro_input1 .b_mtp_index [0 :req_num1 ])
239- b_req_idx = torch .cat (micro_input0 .b_req_idx [0 :req_num0 ], micro_input1 .b_req_idx [0 :req_num1 ])
238+ b_mtp_index = torch .cat (( micro_input0 .b_mtp_index [0 :req_num0 ], micro_input1 .b_mtp_index [0 :req_num1 ]), dim = 0 )
239+ b_req_idx = torch .cat (( micro_input0 .b_req_idx [0 :req_num0 ], micro_input1 .b_req_idx [0 :req_num1 ]), dim = 0 )
240240
241241 if (req_num0 + req_num1 ) > 0 :
242242
@@ -291,7 +291,7 @@ def decode_overlap(self, event_pack: OverlapEventPack, decode_reqs: List[InferRe
291291 micro_input1 ,
292292 run_reqs1 ,
293293 padded_req_num1 ,
294- ) = padded_overlap_prepare_decode_inputs (decode_reqs , is_multimodal = self . is_multimodal )
294+ ) = padded_overlap_prepare_decode_inputs (req_objs = decode_reqs )
295295 micro_input0 : ModelInput = micro_input0
296296 micro_input1 : ModelInput = micro_input1
297297
@@ -305,8 +305,8 @@ def decode_overlap(self, event_pack: OverlapEventPack, decode_reqs: List[InferRe
305305
306306 logits [0 :req_num0 , :].copy_ (logits0 [0 :req_num0 , :], non_blocking = True )
307307 logits [req_num0 : (req_num0 + req_num1 ), :].copy_ (logits1 [0 :req_num1 , :], non_blocking = True )
308- b_mtp_index = torch .cat (micro_input0 .b_mtp_index [0 :req_num0 ], micro_input1 .b_mtp_index [0 :req_num1 ])
309- b_req_idx = torch .cat (micro_input0 .b_req_idx [0 :req_num0 ], micro_input1 .b_req_idx [0 :req_num1 ])
308+ b_mtp_index = torch .cat (( micro_input0 .b_mtp_index [0 :req_num0 ], micro_input1 .b_mtp_index [0 :req_num1 ]), dim = 0 )
309+ b_req_idx = torch .cat (( micro_input0 .b_req_idx [0 :req_num0 ], micro_input1 .b_req_idx [0 :req_num1 ]), dim = 0 )
310310
311311 run_reqs = run_reqs0 + run_reqs1
312312 if (req_num0 + req_num1 ) > 0 :
0 commit comments