Skip to content

Commit 6b80d8a

Browse files
Author: wangzaijun (committed)
Commit message: fix multi_modal
1 parent: 20d927b — commit 6b80d8a

File tree

6 files changed

+18
-4
lines changed

6 files changed

+18
-4
lines changed

lightllm/common/basemodel/basemodel.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -829,6 +829,7 @@ def _check_max_len_infer(self):
829829
is_prefill=True,
830830
b_ready_cache_len=b_ready_cache_len,
831831
b_prefill_start_loc=b_prefill_start_loc,
832+
multimodal_params=[{"images": [], "audios": []}],
832833
)
833834
model_output = self.forward(
834835
model_input,
@@ -905,7 +906,7 @@ def _autotune_warmup(self):
905906
is_prefill=True,
906907
b_ready_cache_len=b_ready_cache_len,
907908
b_prefill_start_loc=b_prefill_start_loc,
908-
multimodal_params=[],
909+
multimodal_params=[{"images": [], "audios": []}],
909910
**self._gen_special_model_input(total_token_num),
910911
)
911912
model_output = self.forward(
@@ -968,7 +969,7 @@ def _init_padded_req(self):
968969
b_ready_cache_len=b_ready_cache_len,
969970
b_prefill_start_loc=b_prefill_start_loc,
970971
is_prefill=True,
971-
multimodal_params=[],
972+
multimodal_params=[{"images": [], "audios": []} for _ in range(batch_size)],
972973
**self._gen_special_model_input(total_token_num),
973974
)
974975

lightllm/common/basemodel/batch_objs.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,9 @@ def to_cuda(self):
7373
else:
7474
self.b_shared_seq_len = self.b_shared_seq_len.cuda(non_blocking=True)
7575

76+
def __post_init__(self):
77+
assert len(self.multimodal_params) == self.batch_size
78+
7679

7780
@dataclass
7881
class ModelOutput:

lightllm/common/basemodel/cuda_graph.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ def warmup(self, model):
216216
b_seq_len=b_seq_len,
217217
b_mtp_index=b_mtp_index,
218218
is_prefill=False,
219+
multimodal_params=[{"images": [], "audios": []} for _ in range(batch_size)],
219220
**model._gen_special_model_input(batch_size),
220221
)
221222
model_output: ModelOutput = model.forward(model_input)
@@ -274,6 +275,7 @@ def warmup_overlap(self, model):
274275
mem_indexes=mem_indexes,
275276
b_req_idx=b_req_idx,
276277
b_seq_len=b_seq_len,
278+
multimodal_params=[{"images": [], "audios": []} for _ in range(batch_size)],
277279
**model._gen_special_model_input(batch_size),
278280
)
279281
decode_batches.append(micro_batch)

lightllm/common/basemodel/prefill_cuda_graph.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ def warmup(self, model):
182182
is_prefill=True,
183183
b_prefill_has_output_cpu=[False],
184184
prefix_total_token_num=0,
185+
multimodal_params=[{"images": [], "audios": []}],
185186
**model._gen_special_model_input(token_num=total_token_num),
186187
)
187188
model_output: ModelOutput = model.forward(model_input)
@@ -242,6 +243,7 @@ def warmup_overlap(self, model):
242243
is_prefill=True,
243244
b_prefill_has_output_cpu=[False],
244245
prefix_total_token_num=0,
246+
multimodal_params=[{"images": [], "audios": []}],
245247
**model._gen_special_model_input(token_num=total_token_num),
246248
)
247249

test/benchmark/static_inference/model_infer.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def overlap_prefill(
9090
b_seq_len=_0_b_seq_len,
9191
is_prefill=True,
9292
b_ready_cache_len=_o_b_ready_cache_len,
93-
multimodal_params={},
93+
multimodal_params=[{"images": [], "audios": []} for _ in range(_0_batch_size)],
9494
mem_indexes_cpu=_0_mem_indexes,
9595
)
9696

@@ -114,7 +114,7 @@ def overlap_prefill(
114114
b_seq_len=_1_b_seq_len,
115115
is_prefill=True,
116116
b_ready_cache_len=_1_b_ready_cache_len,
117-
multimodal_params={},
117+
multimodal_params=[{"images": [], "audios": []} for _ in range(_1_batch_size)],
118118
mem_indexes_cpu=_1_mem_indexes,
119119
)
120120

@@ -144,6 +144,7 @@ def overlap_decode(
144144
b_mtp_index=_0_b_mtp_index,
145145
b_seq_len=_0_b_seq_len,
146146
mem_indexes_cpu=_0_mem_indexes,
147+
multimodal_params=[{"images": [], "audios": []} for _ in range(_0_batch_size)],
147148
)
148149

149150
_1_batch_size = batch_size - batch_size // 2
@@ -164,6 +165,7 @@ def overlap_decode(
164165
b_mtp_index=_1_b_mtp_index,
165166
b_seq_len=_1_b_seq_len,
166167
mem_indexes_cpu=_1_mem_indexes,
168+
multimodal_params=[{"images": [], "audios": []} for _ in range(_1_batch_size)],
167169
)
168170

169171
output, output1 = model_part.microbatch_overlap_decode(micro_batch1, micro_batch2)
@@ -202,6 +204,7 @@ def prefill(
202204
b_ready_cache_len=b_ready_cache_len, # b_ready_cache_len
203205
b_prefill_start_loc=b_prefill_start_loc,
204206
prefix_total_token_num=0, # the default kvcache len is zero.
207+
multimodal_params=[{"images": [], "audios": []} for _ in range(batch_size)],
205208
)
206209

207210
model_output = model_part.forward(model_input)
@@ -223,6 +226,7 @@ def decode(
223226
b_mtp_index=b_mtp_index,
224227
mem_indexes_cpu=mem_indexes,
225228
is_prefill=False,
229+
multimodal_params=[{"images": [], "audios": []} for _ in range(batch_size)],
226230
)
227231
model_output = model_part.forward(model_input)
228232
return model_output.logits

test/benchmark/static_inference/model_infer_mtp.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_
136136
b_seq_len=b_seq_len,
137137
is_prefill=True,
138138
b_ready_cache_len=b_ready_cache_len,
139+
multimodal_params=[{"images": [], "audios": []} for _ in range(batch_size)],
139140
)
140141

141142
model_output: ModelOutput = main_model.forward(model_input)
@@ -202,6 +203,7 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_
202203
b_req_idx=nopad_b_seq_idx,
203204
b_seq_len=nopad_b_seq_len,
204205
is_prefill=False,
206+
multimodal_params=[{"images": [], "audios": []} for _ in range(batch_size * (len(draft_models) + 1))],
205207
)
206208

207209
# Main decode

0 commit comments

Comments (0)