
Commit 2d46245

Merge branch 'wzj_router' of https://github.com/ModelTC/lightllm into wzj_router

2 parents 678bb5f + 241ec63

15 files changed (+195 −414 lines)

lightllm/server/api_start.py

Lines changed: 1 addition & 0 deletions

@@ -159,6 +159,7 @@ def normal_or_p_d_start(args):
     args.visual_nccl_ports = args.visual_nccl_ports[: args.visual_dp]
 
     if args.disable_chunked_prefill:
+        args.chunked_prefill_size = args.max_req_total_len
         # normal (non-chunked-prefill) mode
         if args.batch_max_tokens is None:
            args.batch_max_tokens = args.max_req_total_len

lightllm/server/httpserver/manager.py

Lines changed: 8 additions & 33 deletions

@@ -57,7 +57,6 @@ def __init__(
         self._shm_lock_pool = AtomicShmArrayLock(f"{get_unique_server_name()}_lightllm_resource_lock", 1)
         self._resource_lock = AsyncLock(self._shm_lock_pool.get_lock_context(0))
         self.node_rank = args.node_rank
-        self.transfer_lock = asyncio.Lock()  # the lock for transfer to next module in multi node mode.
         self.disable_abort = args.nnodes > 1 and args.dp == 1  # multinode dp=1 mode, disable abort
         self.is_multinode_tp = args.dp == 1 and args.nnodes > 1
         self.is_multinode_tp_master = args.dp == 1 and args.nnodes > 1 and args.node_rank == 0

@@ -202,25 +201,20 @@ def tokens(self, prompt, multimodal_params, samping_params: SamplingParams, kwar
 
     async def loop_for_request(self):
         assert self.args.node_rank > 0
-        tasks = []
-        self.request_order_queue = []
         while True:
             (
                 prompt,
                 sampling_params,
                 multimodal_params,
             ) = await self.multinode_req_manager.recv_pyobj()
-            self.request_order_queue.append(sampling_params.group_request_id)
             results_generator = self.generate(prompt, sampling_params, multimodal_params, None)
 
             async def generate_wrapper(results_generator):
                 async for _, _, _, _ in results_generator:
                     pass
 
-            tasks.append(asyncio.create_task(generate_wrapper(results_generator)))
-            # cleanup
-            while len(tasks) > 0 and tasks[0].done():
-                tasks.pop(0)
+            asyncio.create_task(generate_wrapper(results_generator))
+        return
 
     def alloc_req_id(self, sampling_params, is_health_req: bool = False):
         # The request id can be supplied externally or generated internally; an externally supplied id must be globally unique.

@@ -413,32 +407,13 @@ async def transfer_to_next_module_or_node(
         original_multimodal_params: MultimodalParams,
         group_req_objs: Optional[GroupReqObjs] = None,
     ):
-        # In multi-node pure tp mode, the master node must forward requests to the slave nodes in a controlled order,
-        # and while forwarding to the slave nodes it must also hand them to next_module in that same order,
-        # so lock-based control is required.
+        # In multi-node pure tp mode, the master node forwards requests to the slave nodes.
         if self.is_multinode_tp_master:
-            async with self.transfer_lock:
-                for sender in self.multinode_req_manager:
-                    sender.send_pyobj(
-                        (prompt, sampling_params, original_multimodal_params),
-                        protocol=pickle.HIGHEST_PROTOCOL,
-                    )
-                await self.transfer_to_next_module(group_req_objs)
-            return
-        # The slave nodes in multi-node pure tp mode must forward requests in the order they were received,
-        # which requires the lock plus a queueing mechanism. self.request_order_queue implements a simple FIFO
-        # hand-off so that requests reach the router of the master and slave nodes in the same order, which is what makes synchronized, identical scheduling possible.
-        if self.is_multinode_tp_slave:
-            while True:
-                if self.request_order_queue and self.request_order_queue[0] != group_req_objs.group_req_id:
-                    await asyncio.sleep(0.002)
-                    continue
-                else:
-                    async with self.transfer_lock:
-                        await self.transfer_to_next_module(group_req_objs)
-                    self.request_order_queue.pop(0)
-                    break
-            return
+            for sender in self.multinode_req_manager:
+                sender.send_pyobj(
+                    (prompt, sampling_params, original_multimodal_params),
+                    protocol=pickle.HIGHEST_PROTOCOL,
+                )
 
         await self.transfer_to_next_module(group_req_objs)
         return
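Note on the simplified loop_for_request above: the per-request bookkeeping (the tasks list and request_order_queue) is gone, and a slave node now just drains the generator of every incoming request in a fire-and-forget asyncio task while the loop immediately goes back to recv_pyobj. A minimal, self-contained sketch of that pattern (toy names, not lightllm APIs):

import asyncio

async def fake_generate(n: int):
    # Stand-in for self.generate(...) on a slave node: the outputs are not
    # needed locally, the generator only has to be driven to completion.
    for i in range(n):
        await asyncio.sleep(0.01)
        yield i

async def drain(agen):
    # Consume the async generator and discard its results on purpose.
    async for _ in agen:
        pass

async def recv_loop():
    # Fire-and-forget: schedule a consumer per request and immediately go
    # back to waiting for the next one, as the new loop_for_request does.
    for request_id in range(3):
        asyncio.create_task(drain(fake_generate(request_id + 1)))
    await asyncio.sleep(0.2)  # demo only; the real loop never waits here

asyncio.run(recv_loop())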

lightllm/server/router/batch.py

Lines changed: 3 additions & 3 deletions

@@ -54,10 +54,10 @@ def filter_out_finished_req(self, shm_req_manager: ShmReqManager):
         self.id_to_reqs = {req.request_id: req for req in self.reqs}
         return
 
-    def pop_req(self, req_id):
+    def pop_req(self, req_id) -> Req:
         self.reqs = [req for req in self.reqs if req.request_id != req_id]
-        self.id_to_reqs.pop(req_id)
-        return
+        req = self.id_to_reqs.pop(req_id)
+        return req
 
     def is_clear(self):
         return len(self.reqs) == 0
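The new return value is what the multi-node tp scheduling path in router/manager.py (next section) relies on: candidate requests that other nodes did not select are popped from the batch and pushed back onto the waiting list. A small runnable sketch of that contract, using toy stand-in classes rather than the real Batch/Req types:

from typing import Dict, List

class ToyReq:
    def __init__(self, request_id: int):
        self.request_id = request_id

class ToyBatch:
    # Minimal stand-in mirroring the updated pop_req contract.
    def __init__(self, reqs: List[ToyReq]):
        self.reqs = reqs
        self.id_to_reqs: Dict[int, ToyReq] = {r.request_id: r for r in reqs}

    def pop_req(self, req_id: int) -> ToyReq:
        self.reqs = [r for r in self.reqs if r.request_id != req_id]
        return self.id_to_reqs.pop(req_id)

batch = ToyBatch([ToyReq(1), ToyReq(2), ToyReq(3)])
waiting: List[ToyReq] = []
for req_id, selected in [(1, 1), (2, 0), (3, 1)]:
    if selected == 0:
        # pop_req now hands the removed request back so it can be re-queued.
        waiting.insert(0, batch.pop_req(req_id))
print([r.request_id for r in batch.reqs], [r.request_id for r in waiting])  # [1, 3] [2]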

lightllm/server/router/manager.py

Lines changed: 86 additions & 32 deletions

@@ -18,7 +18,6 @@
 from lightllm.server.core.objs.io_objs import GroupReqIndexes, AbortedReqCmd
 from lightllm.server.core.objs import ShmReqManager, StartArgs
 from .dynamic_prompt.radix_cache import RadixCacheReadOnlyClient
-from .stats import Stats
 from .shm_reqs_io_buffer import ShmReqsIOBuffer
 from lightllm.utils.log_utils import init_logger, log_time_ready
 from lightllm.server.router.token_load import TokenLoad

@@ -45,6 +44,8 @@ def __init__(self, args: StartArgs, router_port, detokenization_port, metric_por
         # Compatibility with the multi-node pure tp mode, where 1 // 2 == 0 has to be handled
         self.dp_size_in_node = max(1, args.dp // self.nnodes)
         self.is_multinode_tp = args.nnodes > 1 and args.dp == 1
+        self.is_multinode_tp_master = self.is_multinode_tp and args.node_rank == 0
+        self.is_multinode_tp_slave = self.is_multinode_tp and args.node_rank != 0
         self.is_multinode_and_multidp = args.nnodes > 1 and args.dp > 1
         # Whether scheduling is conservative; conservative scheduling never pauses requests, but can hurt throughput in some scenarios
         self.is_safe_schedule = args.router_token_ratio == 0.0

@@ -254,6 +255,8 @@ async def _step(self):
         """
         Event-processing loop
         """
+        # Receive new requests and try to schedule them
+        await self._recv_new_reqs_and_schedule()
         # Check whether new requests should join inference.
         # Under aggressive scheduling, a new inference batch is added as soon as one is available,
         # or when the number of delayed steps satisfies the current condition.

@@ -357,44 +360,96 @@ def _add_req(self, group_req_indexes: GroupReqIndexes):
         return
 
     def _generate_new_batch(self):
-        limit_router_queue_length = None
-        if self.is_multinode_tp:
-            # Use all_reduce to obtain the minimum value
-            limit_router_queue_length = len(self.req_queue.waiting_req_list)
-            limit_router_queue_length_tensor = torch.tensor(limit_router_queue_length, dtype=torch.int32, device="cpu")
-            dist.all_reduce(limit_router_queue_length_tensor, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
-            limit_router_queue_length = limit_router_queue_length_tensor.item()
-
         # Scheduling must account for the currently running batch and for requests that were scheduled but have not started inference yet.
         new_batch = self.req_queue.generate_new_batch(
-            Batch.merge_two_batch(self.running_batch, self.schedule_new_batch), limit_router_queue_length
+            Batch.merge_two_batch(self.running_batch, self.schedule_new_batch)
         )
         self.schedule_new_batch = Batch.merge_two_batch(self.schedule_new_batch, new_batch)
         return
 
-    async def loop_for_netio_req(self):
-        recv_max_count = 64
+    def _multinode_tp_generate_new_batch(self):
+        dist.barrier(group=self.mulitnode_group)
 
-        while True:
-            try:
-                # Take at most recv_max_count requests from zmq per pass, so a flooded zmq queue does not block the main loop.
-                for _ in range(recv_max_count):
-                    recv_req: GroupReqIndexes = self.recv_from_httpserver.recv_pyobj(zmq.NOBLOCK)
-                    if isinstance(recv_req, GroupReqIndexes):
-                        self._add_req(recv_req)
-                    else:
-                        assert False, f"Error Req Inf {recv_req}"
-
-                # When many requests are queued, raise the per-pass receive count
-                recv_max_count = min(int(recv_max_count * 1.3), 256)
-
-            except zmq.ZMQError:
-                # When the queue starts to drain, lower the per-pass receive count
-                recv_max_count = 64
+        # Scheduling must account for the currently running batch and for requests that were scheduled but have not started inference yet.
+        if self.is_multinode_tp_master:
+            new_batch = self.req_queue.generate_new_batch(
+                Batch.merge_two_batch(self.running_batch, self.schedule_new_batch)
+            )
+            if new_batch is not None:
+                req_ids = [req.request_id for req in new_batch.reqs]
+            else:
+                req_ids = []
+            dist.broadcast_object_list([len(req_ids)], src=0, group=self.mulitnode_group)
+            dist.broadcast_object_list(req_ids, src=0, group=self.mulitnode_group)
+            req_id_select_mark = [1 for _ in range(len(req_ids))]
+            req_id_select_mark = torch.tensor(req_id_select_mark, dtype=torch.int32, device="cpu")
+            dist.all_reduce(req_id_select_mark, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
+            back_req_list = []
+            for req_id, select in zip(req_ids, req_id_select_mark.numpy()):
+                if select == 0:
+                    req = new_batch.pop_req(req_id)
+                    back_req_list.append(req)
+            self.req_queue.waiting_req_list = back_req_list + self.req_queue.waiting_req_list
+            if new_batch.is_clear():
+                new_batch = None
+        else:
+            req_nums = [None]
+            dist.broadcast_object_list(req_nums, src=0, group=self.mulitnode_group)
+            req_num = req_nums[0]
+            req_ids = [None for _ in range(req_num)]
+            dist.broadcast_object_list(req_ids, src=0, group=self.mulitnode_group)
+            all_req_id_set = set([req.request_id for req in self.req_queue.waiting_req_list])
+            req_id_select_mark = []
+            for req_id in req_ids:
+                req_id_select_mark.append(1 if req_id in all_req_id_set else 0)
+            req_id_select_mark = torch.tensor(req_id_select_mark, dtype=torch.int32, device="cpu")
+            dist.all_reduce(req_id_select_mark, op=dist.ReduceOp.MIN, group=self.mulitnode_group)
+            select_req_ids = []
+            for req_id, select in zip(req_ids, req_id_select_mark.numpy()):
+                if select == 1:
+                    select_req_ids.append(req_id)
+
+            select_reqs = []
+            for req_id in select_req_ids:
+                for req in self.req_queue.waiting_req_list:
+                    if req.request_id == req_id:
+                        select_reqs.append(req)
+
+            for req in select_reqs:
+                self.req_queue.waiting_req_list.remove(req)
+            if select_reqs:
+                new_batch = Batch(-1, reqs=select_reqs, dp_size_in_node=self.dp_size_in_node)
+            else:
+                new_batch = None
 
-            await asyncio.sleep(0.02)
+        self.schedule_new_batch = Batch.merge_two_batch(self.schedule_new_batch, new_batch)
+
+        dist.barrier(group=self.mulitnode_group)
+        return
+
+    async def _recv_new_reqs_and_schedule(self):
+        if not hasattr(self, "recv_max_count"):
+            self.recv_max_count = 64
+
+        try:
+            # Take at most recv_max_count requests from zmq per pass, so a flooded zmq queue does not block the main loop.
+            for _ in range(self.recv_max_count):
+                recv_req: GroupReqIndexes = self.recv_from_httpserver.recv_pyobj(zmq.NOBLOCK)
+                if isinstance(recv_req, GroupReqIndexes):
+                    self._add_req(recv_req)
+                else:
+                    assert False, f"Error Req Inf {recv_req}"
 
-            # Only run new scheduling when no request is paused on the inference side
+            # When many requests are queued, raise the per-pass receive count
+            self.recv_max_count = min(int(self.recv_max_count * 1.3), 256)
+
+        except zmq.ZMQError:
+            # When the queue starts to drain, lower the per-pass receive count
+            self.recv_max_count = 64
+
+        if self.is_multinode_tp:
+            self._multinode_tp_generate_new_batch()
+        else:
             if self._get_paused_req_num() == 0:
                 self._generate_new_batch()
         return

@@ -436,6 +491,5 @@ def handle_exception(loop, context):
         raise
 
     pipe_writer.send("init ok")
-    loop.create_task(router.loop_for_fwd())
-    loop.run_until_complete(router.loop_for_netio_req())
+    loop.run_until_complete(router.loop_for_fwd())
     return
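The core of _multinode_tp_generate_new_batch is an agreement round: the master proposes the request ids of a candidate batch, broadcasts them, every node votes 1/0 on whether it holds each request locally, and a MIN all_reduce keeps only the ids that every node can schedule; the rest go back to the waiting queue. A standalone sketch of that round, assuming torch.distributed is already initialized (for example with a gloo group for CPU tensors; the real code uses self.mulitnode_group):

import torch
import torch.distributed as dist

def agree_on_req_ids(candidate_ids, local_id_set, group, is_master: bool):
    # Rank 0 proposes candidate_ids; every rank votes 1 for ids it also holds
    # (the master always votes 1 for its own proposal); a MIN all_reduce then
    # keeps only the unanimously held ids.
    meta = [len(candidate_ids)] if is_master else [None]
    dist.broadcast_object_list(meta, src=0, group=group)
    ids = list(candidate_ids) if is_master else [None] * meta[0]
    dist.broadcast_object_list(ids, src=0, group=group)

    votes = torch.tensor(
        [1 if (is_master or rid in local_id_set) else 0 for rid in ids],
        dtype=torch.int32,
    )
    dist.all_reduce(votes, op=dist.ReduceOp.MIN, group=group)
    return [rid for rid, v in zip(ids, votes.tolist()) if v == 1]

On the master, ids missing from the returned list are popped from the new batch (via the pop_req change above) and pushed back onto waiting_req_list; on the other nodes, the returned ids are pulled out of their own waiting lists to build an identical batch.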

lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 58 additions & 12 deletions

@@ -48,7 +48,7 @@ def __init__(self) -> None:
         self.enable_decode_microbatch_overlap = get_env_start_args().enable_decode_microbatch_overlap
         self.enable_prefill_microbatch_overlap = get_env_start_args().enable_prefill_microbatch_overlap
 
-        # Parameter variables that control classification
+        # Parameter variables that control the _get_classed_reqs classification; different backends may need different classification conditions.
         self.classed_req_no_decode = False
         self.classed_req_strict_prefill = False
         pass

@@ -74,6 +74,7 @@ def init_model(self, kvargs):
         self.use_dynamic_prompt_cache = not self.args.disable_dynamic_prompt_cache
         self.eos_id: List[int] = kvargs.get("eos_id", [2])
         self.disable_cudagraph = self.args.disable_cudagraph
+        self.is_multinode_tp = self.args.nnodes > 1 and self.args.dp == 1
 
         self.logger = init_logger(__name__)
 

@@ -166,17 +167,29 @@ def init_model(self, kvargs):
             [0 for _ in range(self.global_world_size)], dtype=torch.int32, device="cuda", requires_grad=False
         )
 
+        # Communication tensor and group used to cooperatively read request info from the ShmReqsIOBuffer.
         self.node_broadcast_tensor = torch.tensor([0], dtype=torch.int32, device="cuda", requires_grad=False)
         self.node_nccl_group = create_new_group_for_current_node("nccl")
 
+        # Communication tensor and group used in multi-node tp mode to cooperatively read request info from the ShmReqsIOBuffer.
+        if self.is_multinode_tp:
+            self.multinode_tp_gather_item_tensor = torch.tensor([0], dtype=torch.int32, device="cuda")
+            self.multinode_tp_all_gather_tensor = torch.tensor(
+                [0 for _ in range(self.global_world_size)], dtype=torch.int32, device="cuda", requires_grad=False
+            )
+            self.multinode_tp_nccl_group = dist.new_group(
+                [rank for rank in range(self.global_world_size)], backend="nccl"
+            )
+
         self.init_custom()
         self.shm_reqs_io_buffer = ShmReqsIOBuffer()
 
         # When mtp mode is enabled, the mtp draft model must be initialized
         if self.args.mtp_mode:
             self.init_mtp_draft_model(kvargs)
 
-        # Start the infer_loop thread
+        # Start the infer_loop threads: two threads run inference, so that in scenarios where two batches can be folded (overlapped)
+        # the cpu overhead drops and gpu utilization improves substantially.
         self.infer_loop_thread = threading.Thread(target=self.infer_loop, daemon=True)
         self.infer_loop_thread.start()
         self.infer_loop_thread1 = threading.Thread(target=self.infer_loop, daemon=True)

@@ -239,6 +252,13 @@ def init_mtp_draft_model(self, main_kvargs: dict):
         return
 
     def _try_read_new_reqs(self):
+        if self.is_multinode_tp:
+            self._try_read_new_reqs_multinode_tp()
+        else:
+            self._try_read_new_reqs_normal()
+        return
+
+    def _try_read_new_reqs_normal(self):
         if self.is_master_in_node:
             if self.shm_reqs_io_buffer.is_ready():
                 self.node_broadcast_tensor.fill_(1)

@@ -247,16 +267,42 @@
         dist.broadcast(self.node_broadcast_tensor, src=0, group=self.node_nccl_group, async_op=False)
         new_buffer_is_ready = self.node_broadcast_tensor.detach().item()
         if new_buffer_is_ready:
-            cmds: List = self.shm_reqs_io_buffer.read_obj()
-            self.shm_reqs_io_buffer.sub_state()
-            if cmds:
-                if isinstance(cmds[0], AbortedReqCmd):
-                    for obj in cmds:
-                        if obj.req_id in g_infer_context.requests_mapping:
-                            req: InferReq = g_infer_context.requests_mapping[obj.req_id]
-                            req.infer_aborted = True
-                else:
-                    self._init_reqs(reqs=cmds)
+            self._read_reqs_buffer_and_init_reqs()
+        return
+
+    def _try_read_new_reqs_multinode_tp(self):
+        """
+        In multi-node tp mode, the behavior of all ranks has to be coordinated and kept in sync.
+        """
+        if self.shm_reqs_io_buffer.is_ready():
+            self.multinode_tp_gather_item_tensor.fill_(1)
+        else:
+            self.multinode_tp_gather_item_tensor.fill_(0)
+        dist.all_gather_into_tensor(
+            self.multinode_tp_all_gather_tensor,
+            self.multinode_tp_gather_item_tensor,
+            group=self.multinode_tp_nccl_group,
+            async_op=False,
+        )
+        new_buffer_is_readys = self.multinode_tp_all_gather_tensor.detach().cpu().numpy()
+        new_buffer_is_ready = np.all(new_buffer_is_readys == 1)
+
+        if new_buffer_is_ready:
+            self._read_reqs_buffer_and_init_reqs()
+        return
+
+    def _read_reqs_buffer_and_init_reqs(self):
+        cmds: List = self.shm_reqs_io_buffer.read_obj()
+        self.shm_reqs_io_buffer.sub_state()
+        if cmds:
+            if isinstance(cmds[0], AbortedReqCmd):
+                for obj in cmds:
+                    obj: AbortedReqCmd = obj
+                    if obj.req_id in g_infer_context.requests_mapping:
+                        req: InferReq = g_infer_context.requests_mapping[obj.req_id]
+                        req.infer_aborted = True
+            else:
+                self._init_reqs(reqs=cmds)
         return
 
     # Some reusable general-purpose helper functions
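The new _try_read_new_reqs_multinode_tp gates buffer consumption on every rank being ready: each rank contributes a 0/1 flag, the flags are gathered onto all ranks, and the ShmReqsIOBuffer is read only when every flag is 1. A minimal sketch of that all-ranks-ready gate, assuming an initialized process group (the real code gathers CUDA tensors over a dedicated nccl group with all_gather_into_tensor; this sketch uses plain all_gather on CPU tensors):

import torch
import torch.distributed as dist

def all_ranks_ready(local_ready: bool, group=None) -> bool:
    # Every rank reports 0 or 1; proceed only if all ranks reported 1.
    world_size = dist.get_world_size(group=group)
    flag = torch.tensor([1 if local_ready else 0], dtype=torch.int32)
    gathered = [torch.zeros(1, dtype=torch.int32) for _ in range(world_size)]
    dist.all_gather(gathered, flag, group=group)
    return all(int(t.item()) == 1 for t in gathered)

Only when this gate passes does each rank call the shared _read_reqs_buffer_and_init_reqs path, so no rank reads the buffer ahead of the others.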
