Skip to content

Commit 4b7c4a1

Browse files
author
wangzaijun
committed
fix
1 parent a53bbd4 commit 4b7c4a1

File tree

2 files changed

+10
-3
lines changed

2 files changed

+10
-3
lines changed

lightllm/common/basemodel/basemodel.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -410,8 +410,13 @@ def _create_unpad_prefill_model_output(self, model_output: ModelOutput, origin_h
410410
if handle_token_num == origin_handle_token_num:
411411
return model_output
412412

413-
new_model_output = copy.copy(model_output)
414-
new_model_output.logits = new_model_output.logits[0:origin_handle_token_num]
413+
if self.return_all_prompt_logics:
414+
new_model_output = copy.copy(model_output)
415+
new_model_output.logits = new_model_output.logits[0:origin_handle_token_num]
416+
else:
417+
new_model_output = copy.copy(model_output)
418+
# Drop the logits row that belongs to the extra padded req (the last row).
419+
new_model_output.logits = new_model_output.logits[0:-1]
415420

416421
# 特殊模型,特殊模式的特殊变量的特殊 unpad
417422
if new_model_output.deepseekv3_mtp_main_output_hiddens is not None:

lightllm/common/basemodel/triton_kernel/gather_token_id.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,9 @@ def scatter_token(
6161
b_req_idx: (batch_size,)
6262
b_mtp_index: (batch_size,)
6363
"""
64-
assert next_token_ids.shape[0] == b_req_idx.shape[0]
64+
assert (
65+
next_token_ids.shape[0] == b_req_idx.shape[0]
66+
), f"batch size not match, {next_token_ids.shape[0]} != {b_req_idx.shape[0]}"
6567
batch_size = b_req_idx.shape[0]
6668
BLOCK = 256
6769

0 commit comments

Comments
 (0)