
Commit 241ec63

fix
1 parent b172eaf commit 241ec63

2 files changed: 58 additions & 13 deletions

lightllm/server/router/manager.py

Lines changed: 0 additions & 1 deletion
@@ -18,7 +18,6 @@
 from lightllm.server.core.objs.io_objs import GroupReqIndexes, AbortedReqCmd
 from lightllm.server.core.objs import ShmReqManager, StartArgs
 from .dynamic_prompt.radix_cache import RadixCacheReadOnlyClient
-from .stats import Stats
 from .shm_reqs_io_buffer import ShmReqsIOBuffer
 from lightllm.utils.log_utils import init_logger, log_time_ready
 from lightllm.server.router.token_load import TokenLoad

lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 58 additions & 12 deletions
@@ -47,7 +47,7 @@ def __init__(self) -> None:
         self.enable_decode_microbatch_overlap = get_env_start_args().enable_decode_microbatch_overlap
         self.enable_prefill_microbatch_overlap = get_env_start_args().enable_prefill_microbatch_overlap

-        # Parameter variables controlling request classification.
+        # Parameter variables controlling how _get_classed_reqs classifies requests; different backends may need different classification conditions.
         self.classed_req_no_decode = False
         self.classed_req_strict_prefill = False
         pass
@@ -73,6 +73,7 @@ def init_model(self, kvargs):
         self.use_dynamic_prompt_cache = not self.args.disable_dynamic_prompt_cache
         self.eos_id: List[int] = kvargs.get("eos_id", [2])
         self.disable_cudagraph = self.args.disable_cudagraph
+        self.is_multinode_tp = self.args.nnodes > 1 and self.args.dp == 1

         self.logger = init_logger(__name__)

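Note on the new flag: it marks deployments where a single tensor-parallel group spans more than one node and data parallelism is disabled. A minimal, hedged sketch of the same condition, using an argparse Namespace as a stand-in for the real start args (only the attribute names nnodes and dp are taken from the diff above):

# Illustration only; Namespace stands in for LightLLM's start args.
from argparse import Namespace

args = Namespace(nnodes=2, dp=1)  # two nodes, no data-parallel replicas
is_multinode_tp = args.nnodes > 1 and args.dp == 1
print(is_multinode_tp)  # True: one tensor-parallel group spans both nodes
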
@@ -165,17 +166,29 @@ def init_model(self, kvargs):
             [0 for _ in range(self.global_world_size)], dtype=torch.int32, device="cuda", requires_grad=False
         )

+        # Communication tensor and group used to coordinate reading request info from the ShmReqsIOBuffer.
         self.node_broadcast_tensor = torch.tensor([0], dtype=torch.int32, device="cuda", requires_grad=False)
         self.node_nccl_group = create_new_group_for_current_node("nccl")

+        # Communication tensor and group used to coordinate reading request info from the ShmReqsIOBuffer in multinode tp mode.
+        if self.is_multinode_tp:
+            self.multinode_tp_gather_item_tensor = torch.tensor([0], dtype=torch.int32, device="cuda")
+            self.multinode_tp_all_gather_tensor = torch.tensor(
+                [0 for _ in range(self.global_world_size)], dtype=torch.int32, device="cuda", requires_grad=False
+            )
+            self.multinode_tp_nccl_group = dist.new_group(
+                [rank for rank in range(self.global_world_size)], backend="nccl"
+            )
+
         self.init_custom()
         self.shm_reqs_io_buffer = ShmReqsIOBuffer()

         # When mtp mode is enabled, the mtp draft model must be initialized.
         if self.args.mtp_mode:
             self.init_mtp_draft_model(kvargs)

-        # Start the infer_loop_thread.
+        # Start the infer loop in two threads. In scenarios where the inference of two batches can be
+        # folded (overlapped), this lowers cpu overhead and significantly raises gpu utilization.
         self.infer_loop_thread = threading.Thread(target=self.infer_loop, daemon=True)
         self.infer_loop_thread.start()
         self.infer_loop_thread1 = threading.Thread(target=self.infer_loop, daemon=True)
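
An aside on the two-thread comment above: running the same infer loop in two daemon threads lets one thread do CPU-side scheduling while the other has a batch in flight, which is what makes the dual-batch folding pay off. A generic, hedged sketch of the pattern; the queue, names, and the mock GPU step are assumptions, not the LightLLM code:

import queue
import threading

work_q: "queue.Queue" = queue.Queue()

def infer_loop() -> None:
    # Each thread alternates between pulling CPU-side work and running a (mock) GPU step;
    # with two threads, one can schedule the next batch while the other's batch is in flight.
    while True:
        batch_size = work_q.get()
        if batch_size is None:  # shutdown sentinel
            return
        _ = sum(range(batch_size))  # stand-in for launching the model forward pass

threads = [threading.Thread(target=infer_loop, daemon=True) for _ in range(2)]
for t in threads:
    t.start()

for size in (8, 16, 4):
    work_q.put(size)
for _ in threads:
    work_q.put(None)  # one sentinel per thread
for t in threads:
    t.join()
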
@@ -238,6 +251,13 @@ def init_mtp_draft_model(self, main_kvargs: dict):
         return

     def _try_read_new_reqs(self):
+        if self.is_multinode_tp:
+            self._try_read_new_reqs_multinode_tp()
+        else:
+            self._try_read_new_reqs_normal()
+        return
+
+    def _try_read_new_reqs_normal(self):
         if self.is_master_in_node:
             if self.shm_reqs_io_buffer.is_ready():
                 self.node_broadcast_tensor.fill_(1)
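
For orientation before the next hunk: in the normal path only the node-master rank inspects the shared-memory buffer, then broadcasts a 0/1 flag so every local rank takes the same branch. A runnable approximation of that pattern on CPU with the gloo backend (the real code broadcasts a CUDA tensor over an NCCL group, and the buffer check here is mocked):

# Launch with: torchrun --nproc_per_node=2 broadcast_ready_sketch.py
import torch
import torch.distributed as dist

dist.init_process_group(backend="gloo")  # gloo so the sketch runs without GPUs
rank = dist.get_rank()

flag = torch.zeros(1, dtype=torch.int32)
if rank == 0:
    buffer_is_ready = True  # stand-in for shm_reqs_io_buffer.is_ready()
    flag.fill_(1 if buffer_is_ready else 0)

dist.broadcast(flag, src=0)  # every rank receives rank 0's decision
if flag.item():
    print(f"rank {rank}: read the shared buffer and init the new requests")
dist.destroy_process_group()
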
@@ -246,16 +266,42 @@ def _try_read_new_reqs(self):
         dist.broadcast(self.node_broadcast_tensor, src=0, group=self.node_nccl_group, async_op=False)
         new_buffer_is_ready = self.node_broadcast_tensor.detach().item()
         if new_buffer_is_ready:
-            cmds: List = self.shm_reqs_io_buffer.read_obj()
-            self.shm_reqs_io_buffer.sub_state()
-            if cmds:
-                if isinstance(cmds[0], AbortedReqCmd):
-                    for obj in cmds:
-                        if obj.req_id in g_infer_context.requests_mapping:
-                            req: InferReq = g_infer_context.requests_mapping[obj.req_id]
-                            req.infer_aborted = True
-                else:
-                    self._init_reqs(reqs=cmds)
+            self._read_reqs_buffer_and_init_reqs()
+        return
+
+    def _try_read_new_reqs_multinode_tp(self):
+        """
+        In multinode tp mode, the behavior of all ranks must be kept synchronized.
+        """
+        if self.shm_reqs_io_buffer.is_ready():
+            self.multinode_tp_gather_item_tensor.fill_(1)
+        else:
+            self.multinode_tp_gather_item_tensor.fill_(0)
+        dist.all_gather_into_tensor(
+            self.multinode_tp_all_gather_tensor,
+            self.multinode_tp_gather_item_tensor,
+            group=self.multinode_tp_nccl_group,
+            async_op=False,
+        )
+        new_buffer_is_readys = self.multinode_tp_all_gather_tensor.detach().cpu().numpy()
+        new_buffer_is_ready = np.all(new_buffer_is_readys == 1)
+
+        if new_buffer_is_ready:
+            self._read_reqs_buffer_and_init_reqs()
+        return
+
+    def _read_reqs_buffer_and_init_reqs(self):
+        cmds: List = self.shm_reqs_io_buffer.read_obj()
+        self.shm_reqs_io_buffer.sub_state()
+        if cmds:
+            if isinstance(cmds[0], AbortedReqCmd):
+                for obj in cmds:
+                    obj: AbortedReqCmd = obj
+                    if obj.req_id in g_infer_context.requests_mapping:
+                        req: InferReq = g_infer_context.requests_mapping[obj.req_id]
+                        req.infer_aborted = True
+            else:
+                self._init_reqs(reqs=cmds)
         return

     # Some reusable general-purpose utility functions
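
Why the all-gather: in multinode tp there is no single master whose view all ranks can cheaply share, so every rank reports whether it sees the buffer as ready, and the commands are consumed only when all ranks agree, keeping control flow in lockstep. A hedged sketch of that consensus check; it uses the list-based dist.all_gather on CPU with gloo so it runs without GPUs, whereas the commit uses all_gather_into_tensor on CUDA tensors over an NCCL group:

# Launch with: torchrun --nproc_per_node=2 allgather_ready_sketch.py
import torch
import torch.distributed as dist

dist.init_process_group(backend="gloo")
rank = dist.get_rank()
world_size = dist.get_world_size()

local_ready = torch.tensor([1], dtype=torch.int32)  # pretend this rank sees the buffer as ready
gathered = [torch.zeros(1, dtype=torch.int32) for _ in range(world_size)]

dist.all_gather(gathered, local_ready)  # every rank learns every rank's readiness
if all(int(t.item()) == 1 for t in gathered):
    # Only when all ranks agree does anyone touch the buffer, so no rank reads it alone.
    print(f"rank {rank}: consensus reached, consume the shared buffer")
dist.destroy_process_group()
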
