@@ -92,7 +92,7 @@ def normal_decode(self, decode_reqs: List[InferReq], max_decode_num: int, uninit
9292 from .pre_process import padded_prepare_decode_inputs
9393
9494 kwargs , run_reqs , padded_req_num = padded_prepare_decode_inputs (
95- decode_reqs , max_decode_num , is_multimodal = False
95+ decode_reqs , max_decode_num , is_multimodal = self . is_multimodal
9696 )
9797 logits = self .model .forward (** kwargs )
9898
@@ -118,7 +118,7 @@ def overlap_decode(self, decode_reqs: List[InferReq], max_decode_num: int, unini
118118 micro_batch1 ,
119119 run_reqs1 ,
120120 padded_req_num1 ,
121- ) = padded_overlap_prepare_decode_inputs (decode_reqs , max_decode_num , is_multimodal = False )
121+ ) = padded_overlap_prepare_decode_inputs (decode_reqs , max_decode_num , is_multimodal = self . is_multimodal )
122122 logits , logits1 = self .model .microbatch_overlap_decode (micro_batch , micro_batch1 )
123123 self ._overlap_req_init_and_filter (uninit_reqs = uninit_reqs , ok_finished_reqs = ok_finished_reqs , clear_list = True )
124124 req_num , req_num1 = len (run_reqs ), len (run_reqs1 )
@@ -147,7 +147,7 @@ def overlap_prefill_reqs(self, prefill_reqs: List[InferReq], max_prefill_num: in
147147 micro_batch1 ,
148148 run_reqs1 ,
149149 padded_req_num1 ,
150- ) = padded_overlap_prepare_prefill_inputs (prefill_reqs , max_prefill_num , is_multimodal = False )
150+ ) = padded_overlap_prepare_prefill_inputs (prefill_reqs , max_prefill_num , is_multimodal = self . is_multimodal )
151151 logits , logits1 = self .model .microbatch_overlap_prefill (micro_batch , micro_batch1 )
152152 self ._overlap_req_init_and_filter (uninit_reqs = uninit_reqs , ok_finished_reqs = ok_finished_reqs , clear_list = True )
153153 req_num , req_num1 = len (run_reqs ), len (run_reqs1 )
0 commit comments