Skip to content

Commit eac2b60

Browse files
author
wangzaijun
committed
fix
1 parent 7a470d5 commit eac2b60

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

lightllm/common/basemodel/basemodel.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -68,12 +68,15 @@ def __init__(self, kvargs):
6868
assert not (self.is_token_healing and self.return_all_prompt_logics), "can not be true in same time"
6969
self.data_type = kvargs.get("data_type", "float16")
7070
mtp_step = get_env_start_args().mtp_step
71-
self.graph_max_batch_size = kvargs.get("graph_max_batch_size", 16) * (mtp_step + 1)
71+
self.graph_max_batch_size = kvargs.get("graph_max_batch_size", 16)
7272
self.graph_max_batch_size = (
7373
self.graph_max_batch_size // 2
7474
if get_env_start_args().enable_decode_microbatch_overlap
7575
else self.graph_max_batch_size
7676
)
77+
# In MTP mode, the max batch size must be adjusted to be a multiple of (mtp_step + 1)
78+
self.graph_max_batch_size = self.graph_max_batch_size * (mtp_step + 1)
79+
7780
self.graph_max_len_in_batch = kvargs.get("graph_max_len_in_batch", 8192)
7881
self.disable_cudagraph = kvargs.get("disable_cudagraph", False)
7982
self.quant_type = kvargs.get("quant_type", "none")

0 commit comments

Comments
 (0)