Commit 5c5b1d8

Author: wangzaijun (committed)

fix prefill input padded error

1 parent 4b7c4a1

File tree

1 file changed: lightllm/common/basemodel/basemodel.py (+18, -17 lines)
```diff
@@ -348,12 +348,10 @@ def _create_padded_decode_model_input(self, model_input: ModelInput, new_batch_s
         return new_model_input
 
     def _create_padded_prefill_model_input(self, model_input: ModelInput, new_handle_token_num: int):
-        if model_input.total_token_num - model_input.prefix_total_token_num == new_handle_token_num:
-            return model_input
-
         assert model_input.total_token_num - model_input.prefix_total_token_num < new_handle_token_num
 
         padded_token_num = new_handle_token_num - (model_input.total_token_num - model_input.prefix_total_token_num)
+        assert padded_token_num > 0
         new_model_input = copy.copy(model_input)
         new_model_input.batch_size = model_input.batch_size + 1
         new_model_input.total_token_num += padded_token_num
```
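The padding arithmetic is easiest to see in isolation. Below is a minimal, hypothetical stand-alone sketch of the same rule (lightllm's real ModelInput carries many more fields, such as input_ids and mem_indexes). Note that the deleted "equal token count" early return moves up into _prefill, so after this commit the helper may assume it is only called when padding is actually needed.

```python
# Hedged sketch of the padding rule above, with a simplified stand-in for
# lightllm's ModelInput (the real class has many more fields).
import copy
from dataclasses import dataclass

@dataclass
class ModelInput:
    batch_size: int
    total_token_num: int
    prefix_total_token_num: int

def create_padded_prefill_model_input(model_input: ModelInput, new_handle_token_num: int) -> ModelInput:
    # The caller now guarantees padding is needed, so strictly-less must hold.
    assert model_input.total_token_num - model_input.prefix_total_token_num < new_handle_token_num
    padded_token_num = new_handle_token_num - (model_input.total_token_num - model_input.prefix_total_token_num)
    assert padded_token_num > 0
    new_model_input = copy.copy(model_input)
    new_model_input.batch_size = model_input.batch_size + 1  # pad tokens live in one extra dummy req
    new_model_input.total_token_num += padded_token_num
    return new_model_input

padded = create_padded_prefill_model_input(ModelInput(2, 100, 0), 128)
assert (padded.batch_size, padded.total_token_num) == (3, 128)
```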
```diff
@@ -405,16 +403,12 @@ def _create_unpad_decode_model_output(self, model_output: ModelOutput, origin_ba
 
         return new_model_output
 
-    def _create_unpad_prefill_model_output(self, model_output: ModelOutput, origin_handle_token_num: int):
-        handle_token_num = model_output.logits.shape[0]
-        if handle_token_num == origin_handle_token_num:
-            return model_output
-
+    def _create_unpad_prefill_model_output(self, padded_model_output: ModelOutput, origin_handle_token_num: int):
         if self.return_all_prompt_logics:
-            new_model_output = copy.copy(model_output)
+            new_model_output = copy.copy(padded_model_output)
             new_model_output.logits = new_model_output.logits[0:origin_handle_token_num]
         else:
-            new_model_output = copy.copy(model_output)
+            new_model_output = copy.copy(padded_model_output)
             # remove the logits that correspond to the extra padded req
             new_model_output.logits = new_model_output.logits[0:-1]
 
```
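The unpad side mirrors the pad side. The sketch below (tensor shapes are assumptions, not lightllm's actual ModelOutput layout) shows the two slicing modes the diff distinguishes: with return_all_prompt_logics the logits tensor has one row per prompt token, otherwise one row per request, so the padded req contributes exactly one trailing row.

```python
# Hedged sketch of the two unpad modes; tensor shapes are assumptions.
import torch

def unpad_prefill_logits(padded_logits: torch.Tensor,
                         origin_handle_token_num: int,
                         return_all_prompt_logics: bool) -> torch.Tensor:
    if return_all_prompt_logics:
        # one row per prompt token: keep only the original tokens
        return padded_logits[0:origin_handle_token_num]
    # one row per request: drop the single row of the extra padded req
    return padded_logits[0:-1]

vocab = 32000
assert unpad_prefill_logits(torch.randn(128, vocab), 100, True).shape[0] == 100
assert unpad_prefill_logits(torch.randn(3, vocab), 100, False).shape[0] == 2
```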

```diff
@@ -429,14 +423,18 @@ def _prefill(
         self,
         model_input: ModelInput,
     ):
-        handle_token_num = model_input.total_token_num - model_input.prefix_total_token_num
-        if self.prefill_graph is not None and self.prefill_graph.can_run(handle_token_num=handle_token_num):
+        origin_handle_token_num = model_input.total_token_num - model_input.prefix_total_token_num
+
+        is_padded_model_input = False
+        if self.prefill_graph is not None and self.prefill_graph.can_run(handle_token_num=origin_handle_token_num):
             finded_handle_token_num = self.prefill_graph.find_closest_graph_handle_token_num(
-                handle_token_num=handle_token_num
-            )
-            model_input = self._create_padded_prefill_model_input(
-                model_input=model_input, new_handle_token_num=finded_handle_token_num
+                handle_token_num=origin_handle_token_num
             )
+            if finded_handle_token_num != origin_handle_token_num:
+                is_padded_model_input = True
+                model_input = self._create_padded_prefill_model_input(
+                    model_input=model_input, new_handle_token_num=finded_handle_token_num
+                )
 
         infer_state = self._create_inferstate(model_input)
         init_req_to_token_indexes(
```
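find_closest_graph_handle_token_num is not shown in this diff. A plausible reading (an assumption, not lightllm's actual implementation) is "smallest captured graph size that fits the pending tokens", which is what makes the new finded_handle_token_num != origin_handle_token_num check meaningful: an exact hit means a captured graph can run the input as-is, with no pad req.

```python
# Assumed semantics of find_closest_graph_handle_token_num (hypothetical
# implementation): pick the smallest captured size that can hold the
# pending prefill tokens, so padding stays minimal.
import bisect

def find_closest_graph_handle_token_num(captured_sizes: list[int], handle_token_num: int) -> int:
    # captured_sizes must be sorted ascending, e.g. [128, 256, 512]
    i = bisect.bisect_left(captured_sizes, handle_token_num)
    assert i < len(captured_sizes), "can_run() should have ruled this case out"
    return captured_sizes[i]

assert find_closest_graph_handle_token_num([128, 256, 512], 200) == 256  # pad 200 -> 256
assert find_closest_graph_handle_token_num([128, 256, 512], 256) == 256  # exact hit: no padding
```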
```diff
@@ -453,7 +451,10 @@ def _prefill(
 
         infer_state.init_some_extra_state(self, model_input.input_ids)
         model_output = self._context_forward(model_input.input_ids, infer_state)
-        model_output = self._create_unpad_prefill_model_output(model_output, origin_handle_token_num=handle_token_num)
+        if is_padded_model_input:
+            model_output = self._create_unpad_prefill_model_output(
+                model_output, origin_handle_token_num=origin_handle_token_num
+            )
         model_output.prefill_mem_indexes_ready_event = prefill_mem_indexes_ready_event
         return model_output
```
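This gating appears to be the substance of the fix, as far as the diff shows: previously _create_unpad_prefill_model_output ran unconditionally and inferred "was this padded?" from model_output.logits.shape[0]. That inference only holds when return_all_prompt_logics is set; in the per-request case the shape is the batch size, which practically never equals the original token count, so an unpadded batch could lose a real request's logits to the [0:-1] slice. Tracking is_padded_model_input explicitly makes pad and unpad symmetric. A toy illustration of the misfiring check, with assumed shapes:

```python
# Toy illustration (assumed shapes) of why the deleted shape-based check in
# _create_unpad_prefill_model_output could misfire: per-request logits rows
# compared against a per-token origin count.
import torch

batch_size, origin_handle_token_num, vocab = 4, 100, 32000
logits = torch.randn(batch_size, vocab)  # no padding was applied

# the old "equal means unpadded" test: here the two never match, so the old
# else-branch would have sliced off a real request's logits row
assert logits.shape[0] != origin_handle_token_num
# the fixed flow skips unpadding entirely because is_padded_model_input is False
```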
