2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -11,4 +11,4 @@ repos:
hooks:
- id: flake8
additional_dependencies: [flake8-typing-imports==1.9.0]
-args: ['--config=.flake8', '--max-line-length=120', '--ignore=TYP001, E722, C901, E203, E266, E402, E302, E241, E902, E731, F403, E701, F405, F401, W292, W293, W503, W606']
+args: ['--config=.flake8', '--max-line-length=120', '--ignore=TYP001, E722, C901, E203, E266, E402, E302, E241, E902, E731, F403, E701, F405, F401, W292, W293, W503, W606, E231']
18 changes: 12 additions & 6 deletions lightllm/server/router/batch.py
@@ -67,12 +67,6 @@ def is_clear(self):
return len(self.reqs) == 0

def merge(self, mini_batch: "Batch"):
for _req in mini_batch.reqs:
self.reqs.append(_req)
self.id_to_reqs = {req.request_id: req for req in self.reqs}
return

def dp_merge(self, mini_batch: "Batch"):
if mini_batch is None:
return

@@ -81,6 +75,18 @@ def dp_merge(self, mini_batch: "Batch"):
self.id_to_reqs = {req.request_id: req for req in self.reqs}
return

@staticmethod
def merge_two_batch(batch1: "Batch", batch2: "Batch") -> "Batch":
if batch1 is None and batch2 is None:
return None

not_none_batch = batch1 if batch1 is not None else batch2

merge_batch = Batch(-1, [], not_none_batch.dp_size_in_node)
merge_batch.merge(batch1)
merge_batch.merge(batch2)
return merge_batch

def __repr__(self):
return f"batch_id={self.batch_id}, " f"reqs={self.reqs}, "

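A minimal usage sketch (not part of the diff) for the new Batch.merge_two_batch helper. The req object and sizes below are illustrative, and Batch.merge is assumed to ignore a None argument, which merge_two_batch relies on when exactly one input is None:

b1 = Batch(1, [req], 1)  # hypothetical args: (batch_id, reqs, dp_size_in_node)
merged = Batch.merge_two_batch(b1, None)  # fresh Batch holding req; inputs untouched
assert merged is not b1
assert Batch.merge_two_batch(None, None) is None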
211 changes: 89 additions & 122 deletions lightllm/server/router/manager.py
@@ -1,26 +1,20 @@
import copy
import time
import uuid
import uvloop
import asyncio
import torch
import rpyc
import pickle
import threading
import inspect

asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
import concurrent.futures
import zmq
import zmq.asyncio
import torch.multiprocessing as mp
import torch.distributed as dist
import multiprocessing
from typing import Dict, List, Optional
-from .batch import Batch
+from .batch import Batch, Req
from .model_infer.model_rpc import start_model_process, ModelRpcClient
from .req_queue import build_req_queue
from lightllm.utils.infer_utils import calculate_time
from lightllm.server.core.objs.io_objs import GroupReqIndexes
from lightllm.server.core.objs import ShmReqManager, StartArgs
from .dynamic_prompt.radix_cache import RadixCacheReadOnlyClient
@@ -79,7 +73,7 @@ def __init__(self, args: StartArgs, router_port, detokenization_port, metric_por
self.eos_id = args.eos_id
self.has_wait_tokens = 0
self.max_wait_tokens = args.router_max_wait_tokens
-context = zmq.asyncio.Context(2)
+context = zmq.Context(2)
self.recv_from_httpserver = context.socket(zmq.PULL)
self.recv_from_httpserver.bind(f"{args.zmq_mode}127.0.0.1:{router_port}")

@@ -106,13 +100,13 @@ def __init__(self, args: StartArgs, router_port, detokenization_port, metric_por
# Mainly to guard against scheduling mistakes that would cause OOM and similar errors
self.router_lock = mp.Lock()
g_router_lock.obj = self.router_lock

# Thread pool used to fold (overlap) scheduling with inference
self.overlap_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=1)
self.schedule_task = None
return

async def wait_to_model_ready(self):
# Objects used by the scheduler
self.schedule_new_batch: Batch = None
self.schedule_event = asyncio.Event()
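# Note (not part of the diff): schedule_event is the handshake that replaces
# the old overlap_thread_pool. _decode_batch sets the event right before
# awaiting the rpc decode, and loop_for_netio_req waits on it, runs
# generate_new_batch(), then clears it, so scheduling overlaps with decode.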

# Initialize the models
self.model_rpc_servers = []
# Used for exchanging task information between the kv-move manager process and the inference processes.
@@ -140,8 +134,6 @@ async def wait_to_model_ready(self):
self.model_rpc_servers.append(rpc_model)

self.model_rpc_client = ModelRpcClient(
model_infer_servers=self.model_rpc_servers,
world_size=self.world_size,
rpc_event=self.rpc_event,
rpc_finished_event=self.rpc_finished_event,
)
@@ -223,7 +215,6 @@ def add_req(self, group_req_indexes: GroupReqIndexes):
logger.info(f"router recive req id {req.request_id} cost time {time.time() - req.start_time} s")
self.req_queue.extend(req_group)
self.send_to_detokenization.send_pyobj(group_req_indexes, protocol=pickle.HIGHEST_PROTOCOL)

return

async def loop_for_fwd(
@@ -285,81 +276,39 @@ async def loop_for_fwd(
if self.running_batch is None:
await asyncio.sleep(0.01) # 10ms

async def get_schedule_result(self, running_batch: Batch):
if self.schedule_task is None:
_start_time = time.time()

def get_new_batch():
if time.time() - _start_time < 0.001:
time.sleep(0.003)

limit_router_queue_length = None
if self.is_multinode_tp:
# Use all_reduce to obtain the minimum value
limit_router_queue_length = len(self.req_queue.waiting_req_list)
limit_router_queue_length_tensor = torch.tensor(
limit_router_queue_length, dtype=torch.int32, device="cpu"
)
dist.all_reduce(limit_router_queue_length_tensor, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
limit_router_queue_length = limit_router_queue_length_tensor.item()

new_batch = self.req_queue.generate_new_batch(running_batch, limit_router_queue_length)
return new_batch

self.schedule_task = asyncio.get_running_loop().run_in_executor(self.overlap_thread_pool, get_new_batch)
return None
else:
result = await self.schedule_task
self.schedule_task = None
return result
def generate_new_batch(self):
limit_router_queue_length = None
if self.is_multinode_tp:
# Use all_reduce to obtain the minimum value
limit_router_queue_length = len(self.req_queue.waiting_req_list)
limit_router_queue_length_tensor = torch.tensor(limit_router_queue_length, dtype=torch.int32, device="cpu")
dist.all_reduce(limit_router_queue_length_tensor, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
limit_router_queue_length = limit_router_queue_length_tensor.item()

# Scheduling must account for the currently running batch as well as requests that were scheduled but have not yet started inference.
new_batch = self.req_queue.generate_new_batch(
Batch.merge_two_batch(self.running_batch, self.schedule_new_batch), limit_router_queue_length
)
self.schedule_new_batch = Batch.merge_two_batch(self.schedule_new_batch, new_batch)
return

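# A standalone sketch (not part of the diff) of the all_reduce(MIN) pattern
# above; it assumes a gloo process group has already been initialized, e.g.
# via torchrun, so every node can schedule against the group's shortest queue.
import torch
import torch.distributed as dist

def min_waiting_len_across_nodes(local_len: int, group=None) -> int:
    t = torch.tensor(local_len, dtype=torch.int32, device="cpu")
    dist.all_reduce(t, op=dist.ReduceOp.MIN, group=group)  # every rank receives the min
    return int(t.item())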
async def _step(self):
"""
Event-handling loop
"""
# Remove all reqs that have already finished
# When no requests are currently running
if self.running_batch is None:
new_batch: Batch = await self.get_schedule_result(self.running_batch)
if new_batch is not None:
self.metric_client.histogram_observe("lightllm_batch_next_size", len(new_batch.reqs))
for req in new_batch.reqs:
self.metric_client.histogram_observe(
"lightllm_request_queue_duration_bucket", time.time() - req.start_time
)
self.stats_tool.count_prompt_tokens(new_batch)
self.running_batch = new_batch
await self._prefill_batch(self.running_batch)
self._filter_runing_batch()

# Aggressive-scheduling control
if not self.args.disable_aggressive_schedule:
self.has_wait_tokens = self.max_wait_tokens

elif self.is_multinode_and_multidp:
# In multi-node multi-dp mode, decode must keep being called even when the
# current running_batch is None, because dp ranks on other nodes may still
# have running requests; the inference backend pads some fake requests so the
# step can complete normally. This mainly serves large models like deepseekv3,
# whose ep parallel mode needs all nodes to act in concert.
await self._decode_batch(self.running_batch)

return

# There are running requests: proceed when the consecutive decode count reaches a threshold, or when a previous pre-scheduling result exists.
if self.has_wait_tokens >= self.max_wait_tokens or self.schedule_task is not None:
new_mini_batch = await self.get_schedule_result(self.running_batch)
# Decide whether new requests should join inference:
# under aggressive scheduling, any new batch is added immediately;
# otherwise it is added once the delayed-step count meets the threshold.
if (self.schedule_new_batch is not None) and (
(not self.args.disable_aggressive_schedule) or (self.has_wait_tokens >= self.max_wait_tokens)
):
new_batch = self.schedule_new_batch
self.schedule_new_batch = None
self._add_new_batch_to_running_batch(new_batch=new_batch)
await self._prefill_batch(new_batch)
self.stats_tool.count_prompt_tokens(new_batch)
self._filter_reqs_from_running_batch()
self.has_wait_tokens = 0
if new_mini_batch is not None:

# Aggressive-scheduling control
if not self.args.disable_aggressive_schedule:
self.has_wait_tokens = self.max_wait_tokens

self.stats_tool.count_prompt_tokens(new_mini_batch)
await self._prefill_batch(new_mini_batch)
if not new_mini_batch.is_clear():
self.running_batch.merge(new_mini_batch)
return

# Check whether some requests need to be paused for decode.
for dp_index in range(self.dp_size_in_node):
@@ -374,51 +323,46 @@ async def _step(self):

# Decode
self.stats_tool.count_output_tokens(self.running_batch)
await self._decode_batch(self.running_batch)
self._filter_runing_batch()
await self._decode_batch()
self._filter_reqs_from_running_batch()
self.has_wait_tokens += 1
return

async def _prefill_batch(self, batch: Batch):
start_time = time.time()
self.metric_client.counter_inc("lightllm_batch_inference_count", "prefill")
# Add the new requests
reqs = [r.to_router_rpc_obj() for r in batch.reqs]
await self.model_rpc_client.prefill(reqs)
batch.filter_out_finished_req(self.shm_req_manager)
self._send_detokenization_pack()

logger.debug(f"Prefill Batch: {batch.simple_log()} \n")
self.metric_client.histogram_observe(
"lightllm_batch_inference_duration_bucket", time.time() - start_time, "prefill"
)
return

async def _decode_batch(self, batch: Batch):
start_time = time.time()
self.metric_client.counter_inc("lightllm_batch_inference_count", "decode")
async def _decode_batch(self):
self.schedule_event.set()
await self.model_rpc_client.decode()
# When self.is_multinode_and_multidp is True, the batch object passed in may be None.
if batch is not None:
batch.filter_out_finished_req(self.shm_req_manager)

self._send_detokenization_pack()
self.metric_client.histogram_observe(
"lightllm_batch_inference_duration_bucket", time.time() - start_time, "decode"
)
return

async def _pause_reqs(self, pasue_reqs):
async def _pause_reqs(self, pasue_reqs: List[Req]):
pasue_req_ids = [r.request_id for r in pasue_reqs]
await self.model_rpc_client.pause_reqs(pasue_req_ids)
return

def _filter_runing_batch(self):
if self.running_batch is not None and self.running_batch.is_clear():
self.running_batch = None
return
def _add_new_batch_to_running_batch(self, new_batch: Batch):
if self.running_batch is None:
self.running_batch = new_batch
else:
self.running_batch.merge(new_batch)
return

def _filter_reqs_from_running_batch(self):
if self.running_batch is not None:
self.running_batch.filter_out_finished_req(self.shm_req_manager)
if self.running_batch.is_clear():
self.running_batch = None
return

def _can_decode(self, batch: Batch, dp_index: int):
if self.is_pd_run_mode or self.is_safe_schedule:
if self.is_pd_run_mode or self.is_safe_schedule or batch is None:
return True
return (
batch.get_batch_decode_need_tokens()[dp_index] + self.get_used_tokens(dp_index) <= self.max_total_token_num
@@ -443,12 +387,35 @@ def get_used_tokens(self, dp_index):
return self.max_total_token_num - self.read_only_statics_mem_manager.get_unrefed_token_num(dp_index)

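# A worked sketch (illustrative numbers, not part of the diff) of the decode
# admission check performed by _can_decode together with get_used_tokens:
max_total_token_num = 1000
unrefed_token_num = 150                                  # tokens held by no request
used_tokens = max_total_token_num - unrefed_token_num    # 850, as get_used_tokens computes
decode_need_tokens = 120                                 # tokens this decode step would claim
assert decode_need_tokens + used_tokens <= max_total_token_num  # 970 <= 1000: decode may run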
async def loop_for_netio_req(self):
recv_max_count = 64

while True:
recv_req: GroupReqIndexes = await self.recv_from_httpserver.recv_pyobj()
if isinstance(recv_req, GroupReqIndexes):
self.add_req(recv_req)
else:
assert False, f"Error Req Inf {recv_req}"
try:
# Pull at most recv_max_count requests from zmq per pass, so an overly long zmq queue cannot block the main loop.
for _ in range(recv_max_count):
recv_req: GroupReqIndexes = self.recv_from_httpserver.recv_pyobj(zmq.NOBLOCK)
if isinstance(recv_req, GroupReqIndexes):
self.add_req(recv_req)
else:
assert False, f"Error Req Inf {recv_req}"

# When many requests are queued, raise the per-pass receive limit
recv_max_count = min(int(recv_max_count * 1.3), 256)

except zmq.ZMQError:
# Once the queue starts draining, lower the per-pass receive limit
recv_max_count = 64

try:
await asyncio.wait_for(self.schedule_event.wait(), timeout=0.02)
except asyncio.TimeoutError:
pass

if self.schedule_event.is_set():
self.generate_new_batch()
self.schedule_event.clear()

return

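# A standalone sketch (not part of the diff) of the adaptive non-blocking
# drain above; sock is assumed to be a bound zmq PULL socket, and the
# 64 / 1.3 / 256 constants mirror the diff. recv_pyobj(zmq.NOBLOCK) raises
# zmq.ZMQError (zmq.Again) once the queue is empty, which resets the limit.
import zmq

def drain_once(sock: zmq.Socket, recv_max_count: int):
    objs = []
    try:
        for _ in range(recv_max_count):
            objs.append(sock.recv_pyobj(zmq.NOBLOCK))
        recv_max_count = min(int(recv_max_count * 1.3), 256)  # queue still busy: ramp up
    except zmq.ZMQError:
        recv_max_count = 64  # queue drained: reset
    return objs, recv_max_count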
def clean_up(self):
return
@@ -459,6 +426,13 @@ def start_router_process(args, router_port, detokenization_port, metric_port, pi
graceful_registry(inspect.currentframe().f_code.co_name)
start_parent_check_thread()

def handle_exception(loop, context):
logger.exception(f"Router Caught exception: {str(context)}")

loop = asyncio.new_event_loop()
loop.set_exception_handler(handle_exception)
asyncio.set_event_loop(loop)

try:
router = RouterManager(
args,
@@ -467,7 +441,7 @@
metric_port=metric_port,
)

asyncio.run(router.wait_to_model_ready())
loop.run_until_complete(router.wait_to_model_ready())
except:
import traceback
import sys
@@ -480,13 +454,6 @@ def start_router_process(args, router_port, detokenization_port, metric_port, pi
raise

pipe_writer.send("init ok")

def handle_exception(loop, context):
logger.exception(f"Router Caught exception: {str(context)}")

loop = asyncio.new_event_loop()
loop.set_exception_handler(handle_exception)
asyncio.set_event_loop(loop)
loop.create_task(router.loop_for_fwd())
loop.run_until_complete(router.loop_for_netio_req())
return
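A minimal sketch (standard library only, not part of the diff) of the loop-setup pattern above: one loop is created up front with a custom exception handler, then reused for both the blocking startup call and the later long-running tasks, presumably so asyncio objects such as schedule_event bind to the same loop that eventually runs them:

import asyncio
import logging

logger = logging.getLogger(__name__)

def handle_exception(loop, context):
    # context is a dict; "message" (and often "exception") describe the failure
    logger.error("Caught exception: %s", context)

loop = asyncio.new_event_loop()
loop.set_exception_handler(handle_exception)
asyncio.set_event_loop(loop)

async def startup():  # stands in for wait_to_model_ready
    await asyncio.sleep(0)

loop.run_until_complete(startup())  # the same loop later serves create_task calls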