8 changes: 8 additions & 0 deletions lightllm/server/api_cli.py
@@ -198,6 +198,14 @@ def make_argument_parser() -> argparse.ArgumentParser:
help="""aggressive schedule can lead to frequent prefill interruptions during decode.
disabling it allows the router_max_wait_tokens parameter to work more effectively.""",
)
parser.add_argument(
"--dp_prefill_wait_step",
type=int,
default=0,
help="""dp_prefill_wait_step is used to control the pacing of dp chunked prefill mode, aiming to reduce
computational waste during prefill. However, higher values can negatively impact the
first token latency. It is generally recommended to set this value between 0 and 12.""",
)

parser.add_argument(
"--use_dynamic_prompt_cache", action="store_true", help="This argument is deprecated and no longer in use."
2 changes: 1 addition & 1 deletion lightllm/server/core/objs/out_token_circlequeue.py
@@ -3,7 +3,7 @@
from typing import Tuple

LIGHTLLM_TOKEN_MAX_BYTES = int(os.getenv("LIGHTLLM_TOKEN_MAX_BYTES", 128))
LIGHTLLM_OUT_TOKEN_QUEUE_SIZE = int(os.getenv("LIGHTLLM_OUT_TOKEN_QUEUE_SIZE", 6))
LIGHTLLM_OUT_TOKEN_QUEUE_SIZE = int(os.getenv("LIGHTLLM_OUT_TOKEN_QUEUE_SIZE", 8))


class QueueItem(ctypes.Structure):
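The queue depth is still read from the environment at import time, so the new default of 8 can be overridden without touching code. A minimal sketch, assuming lightllm is importable in the current environment and the override is set before the module is first imported:

```python
import os

# Must be set before lightllm imports out_token_circlequeue, because the module
# evaluates int(os.getenv("LIGHTLLM_OUT_TOKEN_QUEUE_SIZE", 8)) at import time.
os.environ["LIGHTLLM_OUT_TOKEN_QUEUE_SIZE"] = "16"

from lightllm.server.core.objs.out_token_circlequeue import LIGHTLLM_OUT_TOKEN_QUEUE_SIZE

print(LIGHTLLM_OUT_TOKEN_QUEUE_SIZE)  # expected: 16
```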
9 changes: 8 additions & 1 deletion lightllm/server/core/objs/req.py
@@ -283,7 +283,14 @@ class ChunkedPrefillReq(Req):

def get_tuple_tokens(self, is_busy, router_max_new_token_len):
args = get_env_start_args()
max_waiting_token = args.router_max_wait_tokens
# During chunked prefill inference there are several modes of delayed-step control,
# used either to keep inter-token latency smooth or to improve prefill efficiency in
# dp mode. When estimating token memory usage, a chunked request must account for the
# longer lifetime that chunking gives it; concretely, this is the
# xxx * (max_waiting_token + 1) term in the b_len computation, which extends the
# request's lifetime during estimation by simulating a longer output length.
# max_waiting_token is computed conservatively: the maximum number of steps a chunked
# prefill can be delayed is the sum of the two modes. Since this does not significantly
# inflate the estimated token usage, it is safe to use.
Comment on lines +286 to +292 (Contributor, medium): This is a very long comment. Consider breaking it up into multiple shorter comments, or rephrasing it to be more concise. Long comments can be difficult to read and understand.

max_waiting_token = args.router_max_wait_tokens + args.dp_prefill_wait_step
has_out_len = self.shm_cur_output_len
if self.sample_params.ignore_eos:
cur_max_new_token_len = self.sample_params.max_new_tokens
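The comment above refers to a `xxx * (max_waiting_token + 1)` term in the b_len computation that is not visible in this hunk. A toy sketch of that idea, purely illustrative (the function name and formula here are hypothetical; only `max_waiting_token` itself comes from the diff):

```python
def simulated_extra_lifetime_tokens(max_waiting_token: int, tokens_per_step: int = 1) -> int:
    """Toy illustration: a chunked-prefill request is treated as if it produced
    (max_waiting_token + 1) extra steps of output, so its memory estimate covers
    the steps during which prefill may be delayed."""
    return tokens_per_step * (max_waiting_token + 1)

# Defaults in this PR: router_max_wait_tokens=6, dp_prefill_wait_step=0.
print(simulated_extra_lifetime_tokens(6 + 0))   # 7 extra estimated tokens per request
# With --dp_prefill_wait_step 12 the conservative bound grows to 19.
print(simulated_extra_lifetime_tokens(6 + 12))
```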
1 change: 1 addition & 0 deletions lightllm/server/core/objs/start_args_type.py
@@ -43,6 +43,7 @@ class StartArgs:
router_token_ratio: float = field(default=0.0)
router_max_new_token_len: int = field(default=1024)
router_max_wait_tokens: int = field(default=6)
dp_prefill_wait_step: int = field(default=0)
disable_aggressive_schedule: bool = field(default=False)
disable_dynamic_prompt_cache: bool = field(default=False)
chunked_prefill_size: int = field(default=8192)
62 changes: 37 additions & 25 deletions lightllm/server/httpserver/manager.py
@@ -22,6 +22,7 @@
from .async_queue import AsyncQueue
from lightllm.server.core.objs import Req, FinishStatus
from lightllm.server.core.objs import SamplingParams
from lightllm.server.core.objs.out_token_circlequeue import LIGHTLLM_OUT_TOKEN_QUEUE_SIZE
from lightllm.server.core.objs.io_objs import GroupReqObjs
from lightllm.server.core.objs.shm_req_manager import ShmReqManager
from lightllm.server.core.objs.atomic_array_lock import AtomicShmArrayLock, AsyncLock, AtomicLockItem
@@ -281,8 +282,12 @@ async def generate(
alloced_req_indexes = []
while len(alloced_req_indexes) < sampling_params.n:
alloc_req_index = await self.shm_req_manager.async_alloc_req_index()
sleep_time = 0.1
while alloc_req_index is None:
await asyncio.sleep(0.1)
await asyncio.sleep(sleep_time)
sleep_time *= 1.1
sleep_time = min(1, sleep_time)

alloc_req_index = await self.shm_req_manager.async_alloc_req_index()
alloced_req_indexes.append(alloc_req_index)
req_objs = []
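The allocation loop now backs off geometrically instead of polling every 100 ms. A standalone sketch of the same backoff policy (factor 1.1, capped at 1 s), assuming a hypothetical async `try_alloc()` that returns None while no request slot is free:

```python
import asyncio
from typing import Optional

async def alloc_with_backoff(try_alloc) -> int:
    """Retry allocation with the same pacing as the PR: start at 0.1 s,
    multiply the sleep by 1.1 on each retry, and never sleep longer than 1 s."""
    index: Optional[int] = await try_alloc()
    sleep_time = 0.1
    while index is None:
        await asyncio.sleep(sleep_time)
        sleep_time = min(1.0, sleep_time * 1.1)
        index = await try_alloc()
    return index
```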
@@ -648,31 +653,38 @@ async def handle_loop(self):
token_list = []
for req in req_status.group_req_objs.shm_req_objs:
req_id = req.request_id
if not req.out_tokens_queue.is_empty():

text, src_index, special, count_output_tokens = req.out_tokens_queue.peek()
req.cumlogprob += float(req.shm_logprobs.arr[src_index])
metadata = {
"id": int(req.shm_prompt_ids.arr[src_index]),
"logprob": float(req.shm_logprobs.arr[src_index]),
"cumlogprob": float(req.cumlogprob) / count_output_tokens,
"special": special,
"count_output_tokens": count_output_tokens,
"prompt_cache_len": req.prompt_cache_len,
"mtp_accepted_token_num": req.mtp_accepted_token_num,
}
if self.args.return_all_prompt_logprobs:
metadata.update(req.get_all_prompt_metadata())
if self.args.use_reward_model:
metadata["score"] = float(req.reward_score)

req.out_tokens_queue.pop_no_ret()

if req.finish_token_index != src_index:
token_list.append((req_id, text, metadata, FinishStatus()))
read_token_count = 1
if req.out_tokens_queue.is_full():
read_token_count = LIGHTLLM_OUT_TOKEN_QUEUE_SIZE

for _ in range(read_token_count):
if not req.out_tokens_queue.is_empty():

text, src_index, special, count_output_tokens = req.out_tokens_queue.peek()
req.cumlogprob += float(req.shm_logprobs.arr[src_index])
metadata = {
"id": int(req.shm_prompt_ids.arr[src_index]),
"logprob": float(req.shm_logprobs.arr[src_index]),
"cumlogprob": float(req.cumlogprob) / count_output_tokens,
"special": special,
"count_output_tokens": count_output_tokens,
"prompt_cache_len": req.prompt_cache_len,
"mtp_accepted_token_num": req.mtp_accepted_token_num,
}
if self.args.return_all_prompt_logprobs:
metadata.update(req.get_all_prompt_metadata())
if self.args.use_reward_model:
metadata["score"] = float(req.reward_score)

req.out_tokens_queue.pop_no_ret()

if req.finish_token_index != src_index:
token_list.append((req_id, text, metadata, FinishStatus()))
else:
finish_status = FinishStatus(req.finish_status.status)
token_list.append((req_id, text, metadata, finish_status))
else:
finish_status = FinishStatus(req.finish_status.status)
token_list.append((req_id, text, metadata, finish_status))
break

async with req_status.lock:
req_status.out_token_info_list.extend(token_list)
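When the per-request output queue is found full, the loop now drains up to LIGHTLLM_OUT_TOKEN_QUEUE_SIZE tokens in one pass instead of one token per iteration, so detokenization does not fall behind a fast producer. A reduced sketch of just that drain policy (the queue interface mirrors the `is_empty`/`is_full`/`peek`/`pop_no_ret` calls used above; metadata and finish handling are stripped out):

```python
def drain_tokens(queue, queue_capacity: int):
    """Yield at most `queue_capacity` items when the queue is full,
    otherwise at most one, matching the read_token_count logic in handle_loop."""
    read_token_count = queue_capacity if queue.is_full() else 1
    for _ in range(read_token_count):
        if queue.is_empty():
            break
        item = queue.peek()
        queue.pop_no_ret()
        yield item
```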
12 changes: 5 additions & 7 deletions lightllm/server/router/manager.py
@@ -110,7 +110,6 @@ def __init__(self, args: StartArgs, router_port, detokenization_port, metric_por
# Thread pool used to overlap scheduling with inference
self.overlap_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
self.schedule_task = None
self.overlap_event = threading.Event()
return

async def wait_to_model_ready(self):
@@ -288,8 +287,12 @@ async def loop_for_fwd(

async def get_schedule_result(self, running_batch: Batch):
if self.schedule_task is None:
_start_time = time.time()

def get_new_batch():
if time.time() - _start_time < 0.001:
time.sleep(0.003)

limit_router_queue_length = None
if self.is_multinode_tp:
# Use all_reduce to obtain the minimum value
@@ -300,9 +303,6 @@ def get_new_batch():
dist.all_reduce(limit_router_queue_length_tensor, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
limit_router_queue_length = limit_router_queue_length_tensor.item()

self.overlap_event.wait(timeout=0.020)
self.overlap_event.clear()
time.sleep(0.003)
new_batch = self.req_queue.generate_new_batch(running_batch, limit_router_queue_length)
return new_batch

@@ -320,7 +320,7 @@ async def _step(self):
# Remove all reqs that have already finished
# When there is no currently running batch
if self.running_batch is None:
new_batch = await self.get_schedule_result(self.running_batch)
new_batch: Batch = await self.get_schedule_result(self.running_batch)
if new_batch is not None:
self.metric_client.histogram_observe("lightllm_batch_next_size", len(new_batch.reqs))
for req in new_batch.reqs:
@@ -383,7 +383,6 @@ async def _prefill_batch(self, batch: Batch):
start_time = time.time()
self.metric_client.counter_inc("lightllm_batch_inference_count", "prefill")
reqs = [r.to_router_rpc_obj() for r in batch.reqs]
self.overlap_event.set()
await self.model_rpc_client.prefill(reqs)
batch.filter_out_finished_req(self.shm_req_manager)
self._send_detokenization_pack()
@@ -397,7 +396,6 @@ async def _decode_batch(self, batch: Batch):
async def _decode_batch(self, batch: Batch):
start_time = time.time()
self.metric_client.counter_inc("lightllm_batch_inference_count", "decode")
self.overlap_event.set()
await self.model_rpc_client.decode()
# When self.is_multinode_and_multidp is True, the batch object passed in may be None.
if batch is not None:
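With overlap_event removed, pacing now lives inside get_new_batch itself: if the schedule task starts less than 1 ms after it was created, it sleeps 3 ms so the forward pass it overlaps with can be issued first. A compressed sketch of this pattern; the run_in_executor wiring is not shown in the hunk and is an assumption here:

```python
import asyncio
import concurrent.futures
import time

overlap_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)

async def get_schedule_result(generate_new_batch, running_batch):
    _start_time = time.time()

    def get_new_batch():
        # Same 1 ms / 3 ms pacing as in the PR; the intent (an assumption here) is to
        # let the overlapped forward pass be launched before scheduling work begins.
        if time.time() - _start_time < 0.001:
            time.sleep(0.003)
        return generate_new_batch(running_batch)

    # Assumption: the real router runs this on its overlap thread pool and awaits
    # the result after kicking off the prefill/decode RPC.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(overlap_thread_pool, get_new_batch)
```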
45 changes: 37 additions & 8 deletions lightllm/server/router/model_infer/mode_backend/base_backend.py
@@ -1,18 +1,13 @@
import os
import asyncio
import numpy as np
import rpyc
import torch
import socket
from datetime import timedelta
from typing import Dict, List, Tuple, Callable, Optional
from typing import List, Tuple, Callable, Optional
from transformers.configuration_utils import PretrainedConfig
from lightllm.utils.infer_utils import set_random_seed
from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end
from lightllm.utils.log_utils import init_logger
from lightllm.models import get_model
from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache
from lightllm.server.router.model_infer.infer_batch import InferReq, InferSamplingParams
from lightllm.server.router.model_infer.infer_batch import InferReq
from lightllm.server.router.token_load import TokenLoad
from lightllm.common.basemodel.infer_lock import g_infer_state_lock, InferStateLock
from lightllm.common.basemodel.basemodel import TpPartBaseModel
@@ -26,12 +21,18 @@
from lightllm.utils.dist_utils import get_current_device_id, get_current_rank_in_node, get_node_world_size
from lightllm.utils.dist_utils import get_dp_rank_in_node
from lightllm.distributed import dist_group_manager
from .chuncked_prefill_state import ChunkedPrefillState
import torch.distributed as dist


class ModeBackend:
def __init__(self) -> None:
self.shm_req_manager = ShmReqManager()

# When a subclass runs in a chunked-prefill-related mode, this state object is used
# to control chunked prefill inference; see the ChunkedPrefillBackend class for how
# it is used. In non-chunked-prefill modes this state variable has no effect.
self.chunked_prefill_state = ChunkedPrefillState()
pass

def init_model(self, kvargs):
@@ -56,7 +57,6 @@ def init_model(self, kvargs):
self.eos_id: List[int] = kvargs.get("eos_id", [2])
self.disable_cudagraph = self.args.disable_cudagraph

self.cache = {}
self.logger = init_logger(__name__)

self.weight_dir = kvargs["weight_dir"]
@@ -140,6 +140,14 @@ def init_model(self, kvargs):
vocab_size=self.model.vocab_size,
)

# Initialize the communication tensors used in dp mode; they are not used in non-dp modes.
if self.dp_size > 1:
self.dp_reduce_tensor = torch.tensor([0], dtype=torch.int32, device="cuda", requires_grad=False)
self.dp_gather_item_tensor = torch.tensor([0], dtype=torch.int32, device="cuda", requires_grad=False)
self.dp_all_gather_tensor = torch.tensor(
[0 for _ in range(self.global_world_size)], dtype=torch.int32, device="cuda", requires_grad=False
)

self.init_custom()
return

@@ -446,6 +454,27 @@ def _update_reqs_mtp_gen_token_ids(self, reqs: List[InferReq], mtp_draft_next_to
req.mtp_gen_token_ids.append(token_id)
return

def _dp_all_gather_prefill_req_num(self, prefill_reqs: List[InferReq]) -> Tuple[np.ndarray, int]:
"""
Gather the number of prefill requests across all DP ranks.
"""
current_dp_prefill_num = len(prefill_reqs)
self.dp_gather_item_tensor.fill_(current_dp_prefill_num)
dist.all_gather_into_tensor(self.dp_all_gather_tensor, self.dp_gather_item_tensor, group=None, async_op=False)
dp_prefill_req_nums = self.dp_all_gather_tensor.cpu().numpy()
max_prefill_num = np.max(dp_prefill_req_nums)
return dp_prefill_req_nums, max_prefill_num
Comment on lines 454 to +466 (Contributor, medium): Consider adding error handling for the dist.all_gather_into_tensor operation. If this operation fails, it could lead to inconsistent data across DP ranks.


def _dp_all_reduce_decode_req_num(self, decode_reqs: List[InferReq]) -> int:
"""
Reduce the number of decode requests across all DP ranks.
"""
current_dp_decode_num = len(decode_reqs)
self.dp_reduce_tensor.fill_(current_dp_decode_num)
dist.all_reduce(self.dp_reduce_tensor, op=dist.ReduceOp.MAX, group=None, async_op=False)
max_decode_num = self.dp_reduce_tensor.item()
Comment on lines +472 to +475 (Contributor, medium): Consider adding error handling for the dist.all_reduce operation. If this operation fails, it could lead to inconsistent data across DP ranks.
return max_decode_num

def preload_prompt_cache_kv_buffer(self, model_cfg):
self.logger.info("Preload prompt cache kv buffer.")
cur_rank = dist.get_rank()
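If the error handling the reviewer asks for were added, one minimal shape for it is a guarded wrapper around the collective call. This is only a sketch of the suggestion, not part of the PR; the tensors passed in are assumed to match those created in init_model:

```python
import numpy as np
import torch
import torch.distributed as dist

def guarded_all_gather_counts(all_gather_tensor: torch.Tensor, item_tensor: torch.Tensor) -> np.ndarray:
    """Sketch of the reviewer's suggestion: surface collective failures loudly
    instead of continuing with stale per-rank counts."""
    try:
        dist.all_gather_into_tensor(all_gather_tensor, item_tensor, group=None, async_op=False)
    except RuntimeError as e:
        # A failed collective can leave ranks inconsistent; failing fast (or triggering
        # a coordinated shutdown) is safer than silently reusing old values.
        raise RuntimeError(f"dp all_gather of prefill request counts failed: {e}") from e
    return all_gather_tensor.cpu().numpy()
```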
89 changes: 89 additions & 0 deletions lightllm/server/router/model_infer/mode_backend/chuncked_prefill_state.py
@@ -0,0 +1,89 @@
import dataclasses
import numpy as np
from typing import List
from lightllm.utils.envs_utils import get_env_start_args
from ..infer_batch import InferReq
from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)


@dataclasses.dataclass
class ChunkedPrefillState:
"""
Holds the control state for chunked prefill scheduling. Different scenarios weigh
first-token latency against inter-token latency differently, so a state object is
needed to drive the decision. The args.router_max_wait_tokens parameter tunes how
aggressively prefill is triggered, balancing first-token and inter-token latency.
"""

prefill_wait_step: int = 0
need_prefill_count: int = 0
current_wait_step: int = 0

# wait-step parameters for dp chunked prefill
dp_prefill_wait_step: int = 0
dp_current_wait_step: int = 0

# world_size
_global_world_size: int = 0

def __post_init__(self):
args = get_env_start_args()
self.prefill_wait_step = args.router_max_wait_tokens
self.dp_prefill_wait_step = args.dp_prefill_wait_step
self._global_world_size = args.tp
return

def need_prefill(self, prefill_reqs: List[InferReq], decode_reqs: List[InferReq]) -> bool:
no_decode_reqs = len(decode_reqs) == 0
step_ok = self.current_wait_step >= self.prefill_wait_step
need_prefill = self.need_prefill_count > 0

if no_decode_reqs or step_ok or need_prefill:
if need_prefill:
self.need_prefill_count -= 1

self.current_wait_step = 0
if prefill_reqs:
return True
else:
return False
else:
if prefill_reqs:
self.current_wait_step += 1
return False

def dp_need_prefill(
self,
prefill_reqs: List[InferReq],
decode_reqs: List[InferReq],
dp_prefill_req_nums: np.ndarray,
dp_max_prefill_num: int,
) -> bool:
"""
dp_need_prefill controls chunked prefill in DP mode and must take the real number of
running requests on each DP rank into account. Consider 8 dp ranks whose prefill
request counts are [1, 1, 0, 0, 0, 0, 0, 0]: the ranks with 0 requests pad a fake req
to take part in the computation, but due to communication synchronization this wastes
a large amount of compute and the effective utilization is very low.
Solution:
Before deciding whether to prefill, first check whether the request counts across all
dp ranks are balanced and whether the waste is acceptable. If the waste is too high,
delay prefill until it drops across all dp ranks; the maximum delay is bounded by the
dp_prefill_wait_step parameter.
"""
assert dp_prefill_req_nums.shape[0] == self._global_world_size

use_ratio = np.count_nonzero(dp_prefill_req_nums) / dp_prefill_req_nums.shape[0]
step_ok = self.dp_current_wait_step >= self.dp_prefill_wait_step

if dp_max_prefill_num > 0 and (use_ratio > 0.6 or step_ok):
if use_ratio < 0.2:
self.dp_current_wait_step = 0
logger.info(f"dp chuncked prefill effective GPU Utilization Rate {use_ratio}")

return True
else:
if dp_max_prefill_num > 0:
self.dp_current_wait_step += 1
else:
self.dp_current_wait_step = 0
return False
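Putting the pieces together, a backend in dp mode would first gather per-rank prefill counts with the helper added in base_backend.py, then ask dp_need_prefill whether running prefill now would waste too much compute. A condensed sketch of that call pattern (the surrounding backend step loop is simplified and the wrapper function itself is hypothetical; the method names on `self` are the ones added in this PR):

```python
# Sketch of the decision flow in a dp-mode backend step; error handling,
# batch construction and scheduling details are omitted.
def decide_prefill(self, prefill_reqs, decode_reqs) -> bool:
    dp_prefill_req_nums, dp_max_prefill_num = self._dp_all_gather_prefill_req_num(prefill_reqs)
    return self.chunked_prefill_state.dp_need_prefill(
        prefill_reqs, decode_reqs, dp_prefill_req_nums, dp_max_prefill_num
    )
```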