Commit 16e67e3

Fix the pd (prefill/decode) disaggregation implementation under the new rank management mechanism.
1 parent 27a7045 commit 16e67e3

14 files changed: +197 -214 lines

lightllm/common/basemodel/infer_lock.py

Lines changed: 16 additions & 15 deletions
@@ -17,28 +17,30 @@


 class InferStateLock:
-    def __init__(self, name):
+    def __init__(self, name, rank_in_dp: int, dp_rank_in_node: int, dp_world_size: int):
         self.infer_lock = threading.Lock()
+        self.dp_rank_in_node = dp_rank_in_node
+        # sync_world_size should be min(dp_world_size, node_world_size)
+        self.dp_world_size = dp_world_size
+        self.rank_in_dp = rank_in_dp
         # By default reserve space for a tp of 128; no current setup should need a tp that large.
-        self.lock_tp_infos = SharedArray(f"{name}_lock_tp_infos", shape=(129,), dtype=np.int64)
+        self.lock_tp_infos = SharedArray(f"{name}_dp_rank_{str(self.dp_rank_in_node)}_lock_tp_infos", shape=(self.dp_world_size + 1,), dtype=np.int64)
         self.lock_tp_infos.arr[:] = 0
-        self.rank_id = dist.get_rank()
-        self.world_size = dist.get_world_size()

     def add_cur_mark(self):
-        self.lock_tp_infos.arr[self.rank_id] += 1
+        self.lock_tp_infos.arr[self.rank_in_dp] += 1

     def get_cur_mark(self):
-        return self.lock_tp_infos.arr[self.rank_id]
+        return self.lock_tp_infos.arr[self.rank_in_dp]

     def get_max_mark_in_group(self):
-        return np.max(self.lock_tp_infos.arr[0 : self.world_size])
+        return np.max(self.lock_tp_infos.arr[0 : self.dp_world_size])

     def judge_cur_mark_equal_max_mark_in_group(self):
         return self.get_cur_mark() == self.get_max_mark_in_group()

     def judge_mark_in_group_all_same(self):
-        marks = self.lock_tp_infos.arr[0 : self.world_size]
+        marks = self.lock_tp_infos.arr[0 : self.dp_world_size]
         return bool(np.all(marks == marks[0]))

     def acquire_lock_and_update_cur_mark(self):

@@ -49,11 +51,11 @@ def release_lock(self):
         self.infer_lock.release()

     def set_group_wait_mark(self):
-        if self.rank_id == 0:
+        if self.rank_in_dp == 0:
             self.lock_tp_infos.arr[-1] = 1

     def unset_group_wait_mark(self):
-        if self.rank_id == 0:
+        if self.rank_in_dp == 0:
             self.lock_tp_infos.arr[-1] = 0

     def get_group_wait_mark(self):

@@ -63,7 +65,7 @@ def get_group_wait_mark(self):
 @dataclass
 class G_Infer_Lock:
     obj: InferStateLock = None
-    dp_size: int = None
+    dp_world_size: int = None

     def acquire(self):
         if self.obj is not None:

@@ -86,9 +88,8 @@ def release(self):

 # The following two functions must be used as a pair.
 def acquire_lock_until_ready(nccl_group):
-    # In deepseekv2's mixed tp/dp running mode there is no need to coordinate and sync across
-    # multiple inference processes, so simply acquiring and releasing the lock is enough.
-    if g_infer_state_lock.dp_size != 1:
+    # With only one rank per dp group there is no need for the heavier locking.
+    if g_infer_state_lock.dp_world_size == 1:
         g_infer_state_lock.obj.infer_lock.acquire()
         return

@@ -118,7 +119,7 @@ def release_acquired_lock():
 @dataclass
 class G_Router_Lock:
     """
-    Protects some data operations in pd disaggregation mode.
+    Protects operations on scheduling-related metadata in pd disaggregation mode.
     """

     obj = None  # process lock object
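
This change scopes the shared lock marks to a single dp group: marks are indexed by rank_in_dp, and the shared array shrinks from a fixed 129 slots to dp_world_size + 1, with the extra slot holding the group wait flag. A minimal standalone sketch of that mark protocol, using a plain numpy array as a simplified stand-in for lightllm's SharedArray (illustrative only, not the project's API):

    import numpy as np

    class MarkBoard:
        """Toy model of the per-dp-group mark array: one counter per rank plus a wait flag."""

        def __init__(self, dp_world_size: int):
            self.dp_world_size = dp_world_size
            # Slots [0, dp_world_size) hold per-rank marks; the last slot is the group wait flag.
            self.arr = np.zeros(dp_world_size + 1, dtype=np.int64)

        def add_cur_mark(self, rank_in_dp: int):
            self.arr[rank_in_dp] += 1

        def judge_mark_in_group_all_same(self) -> bool:
            marks = self.arr[0 : self.dp_world_size]
            return bool(np.all(marks == marks[0]))

    board = MarkBoard(dp_world_size=4)
    for r in range(4):
        board.add_cur_mark(r)
    assert board.judge_mark_in_group_all_same()  # every rank has advanced to the same mark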

lightllm/common/deepseek2_mem_manager.py

Lines changed: 8 additions & 8 deletions
@@ -40,9 +40,9 @@ def alloc_kv_move_buffer(self, max_req_total_len):
         return

     def send_to_decode_node(
-        self, move_tasks: List[KVMoveTask], mem_managers: List["Deepseek2MemoryManager"], dp_size: int
+        self, move_tasks: List[KVMoveTask], mem_managers: List["Deepseek2MemoryManager"], dp_size_in_node: int
     ):
-        assert dp_size == 1
+        assert dp_size_in_node == 1

         # First stage the data into a buffer on one designated card, then send it.
         move_token_indexes = []

@@ -66,9 +66,9 @@ def _get_kv_move_data(self, token_indexes: List[int], layer_index: int):
         return move_buffer

     def receive_from_prefill_node(
-        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size: int
+        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
     ):
-        assert dp_size == 1
+        assert dp_size_in_node == 1

         # First receive the data into a buffer on one designated card, then copy it to the other cards.
         move_token_indexes = []

@@ -97,11 +97,11 @@ def _write_kv_move_data(self, token_indexes: torch.Tensor, buffer_tensor: torch.
         self.kv_buffer[layer_index : layer_index + 1, token_indexes, :, :] = buffer_tensor
         return

-    def send_to_decode_node_p2p(self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size: int):
+    def send_to_decode_node_p2p(self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int):
         """
         Implementation that uses a p2p triton kernel for data copy and transfer.
         """
-        assert dp_size == 1
+        assert dp_size_in_node == 1

         move_token_indexes = []
         for task in move_tasks:

@@ -124,9 +124,9 @@ def _get_kv_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, k
         return move_buffer

     def receive_from_prefill_node_p2p(
-        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size: int
+        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
     ):
-        assert dp_size == 1
+        assert dp_size_in_node == 1

         move_token_indexes = []
         for task in move_tasks:
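
The comments in these methods describe a two-step pattern: stage the KV data for the requested tokens into a transfer buffer on one designated card, then send it (or copy it to the sibling cards). A hedged, toy-shaped illustration of the staging step in plain torch (the tensor layout and helper name are illustrative, not the project's buffers):

    import torch

    # Toy KV cache with layout [num_layers, num_tokens, num_heads, head_dim].
    kv_buffer = torch.randn(2, 16, 4, 8)

    def get_kv_move_data(token_indexes, layer_index):
        # Gather the selected tokens of one layer into a contiguous buffer ready for transfer.
        idx = torch.tensor(token_indexes, dtype=torch.int64)
        return kv_buffer[layer_index, idx].contiguous()

    move_buffer = get_kv_move_data([1, 3, 5], layer_index=0)
    print(move_buffer.shape)  # torch.Size([3, 4, 8])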

lightllm/common/mem_manager.py

Lines changed: 8 additions & 18 deletions
@@ -83,13 +83,8 @@ def alloc_kv_move_buffer(self, max_req_total_len):
         self.kv_move_buf_indexes = torch.arange(0, max_req_total_len + 8, dtype=torch.int64, device="cuda")
         return

-    def send_to_decode_node(self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size: int):
-        """
-        dp_size is a parameter tailored for deepseekv2-style models that can run in mixed dp and tp mode.
-        In plain tp mode dp_size is always 1 and dp_index is always 0; in that mode these two
-        parameters are not actually used.
-        """
-        assert dp_size == 1
+    def send_to_decode_node(self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int):
+        assert dp_size_in_node == 1

         # First stage the data into a buffer on one designated card, then send it.

@@ -123,14 +118,9 @@ def _get_kv_move_data(self, token_indexes: List[int], layer_index: int):
         return move_buffer

     def receive_from_prefill_node(
-        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size: int
+        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
     ):
-        """
-        dp_size is a parameter tailored for deepseekv2-style models that can run in mixed dp and tp mode.
-        In plain tp mode dp_size is always 1; in that mode these two
-        parameters are not actually used.
-        """
-        assert dp_size == 1
+        assert dp_size_in_node == 1

         # First receive the data into a buffer on one designated card, then copy it to the other cards.

@@ -160,11 +150,11 @@ def _write_kv_move_data(self, token_indexes: torch.Tensor, buffer_tensor: torch.
         self.kv_buffer[layer_index : layer_index + 1, token_indexes, :, :] = buffer_tensor
         return

-    def send_to_decode_node_p2p(self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size: int):
+    def send_to_decode_node_p2p(self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int):
         """
         Implementation that uses a p2p triton kernel for data copy and transfer.
         """
-        assert dp_size == 1
+        assert dp_size_in_node == 1

         # First stage the data into a buffer on one designated card, then send it.

@@ -190,9 +180,9 @@ def _get_kv_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, k
         return move_buffer

     def receive_from_prefill_node_p2p(
-        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size: int
+        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
     ):
-        assert dp_size == 1
+        assert dp_size_in_node == 1

         # First receive the data into a buffer on one designated card, then copy it to the other cards.
lightllm/server/api_start.py

Lines changed: 6 additions & 6 deletions
@@ -164,14 +164,14 @@ def normal_or_p_d_start(args):
     # Catch port configuration conflicts.
     ports_locker = PortLocker(already_uesd_ports)
     ports_locker.lock_port()
-
+
+    node_world_size = args.tp // args.nnodes
     can_use_ports = alloc_can_use_network_port(
-        num=6 + args.tp + args.tp + args.visual_dp * args.visual_tp, used_nccl_ports=already_uesd_ports
+        num=6 + node_world_size + args.visual_dp * args.visual_tp, used_nccl_ports=already_uesd_ports
     )
     logger.info(f"alloced ports: {can_use_ports}")
     router_port, detokenization_port, detokenization_pub_port, visual_port, cache_port, metric_port = can_use_ports[0:6]
-    model_rpc_ports = can_use_ports[6 : 6 + args.tp]
-    can_use_ports = can_use_ports[6 + args.tp :]
+    can_use_ports = can_use_ports[6:]

     visual_model_tp_ports = []
     for _ in range(args.visual_dp):

@@ -188,7 +188,7 @@ def normal_or_p_d_start(args):
     args.metric_port = metric_port

     # Allocate the ports used in p d disaggregation mode.
-    args.pd_tp_infer_rpyc_ports = can_use_ports[0 : args.tp]
+    args.pd_node_infer_rpyc_ports = can_use_ports[0 : node_world_size]
     # Id used to identify the node in p d disaggregation mode.
     args.pd_node_id = uuid.uuid4().int
     # Usable port range for the p node to build the torch kv transfer process group.

@@ -231,7 +231,7 @@ def normal_or_p_d_start(args):
     process_manager.start_submodule_processes(
         start_funcs=[start_router_process, start_detokenization_process],
         start_args=[
-            (args, router_port, detokenization_port, model_rpc_ports, metric_port),
+            (args, router_port, detokenization_port, metric_port),
             (args, detokenization_port, detokenization_pub_port),
         ],
     )
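
The port budget now counts ranks per node instead of twice the global tp size, since the model RPC ports are gone and the pd rpyc ports are allocated per local rank. A quick check of the arithmetic with illustrative numbers (not taken from the commit):

    # Hypothetical configuration: tp=8 split across 2 nodes, one visual worker.
    tp, nnodes, visual_dp, visual_tp = 8, 2, 1, 1

    node_world_size = tp // nnodes                          # 4 ranks on this node
    old_num = 6 + tp + tp + visual_dp * visual_tp           # 23 ports: base + model rpc + pd rpyc per global rank
    new_num = 6 + node_world_size + visual_dp * visual_tp   # 11 ports: base + pd rpyc per local rank

    print(old_num, new_num)  # 23 11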

lightllm/server/router/manager.py

Lines changed: 4 additions & 6 deletions
@@ -39,7 +39,7 @@


 class RouterManager:
-    def __init__(self, args, router_port, detokenization_port, model_rpc_ports, metric_port):
+    def __init__(self, args, router_port, detokenization_port, metric_port):
         self.args = args
         self.model_weightdir = args.model_dir
         self.world_size = args.tp

@@ -81,8 +81,7 @@ def __init__(self, args, router_port, detokenization_port, model_rpc_ports, metr

         self.send_to_detokenization = context.socket(zmq.PUSH)
         self.send_to_detokenization.connect(f"{args.zmq_mode}127.0.0.1:{detokenization_port}")
-        self.model_rpc_ports = model_rpc_ports
-
+

         if self.is_multinode_tp:
             self.mulitnode_group = dist.init_process_group(
                 backend="gloo",

@@ -173,7 +172,7 @@ async def wait_to_model_ready(self):
             "batch_max_tokens": self.args.batch_max_tokens,
             "quant_type": self.args.quant_type,
             "quant_cfg": self.args.quant_cfg,
-            "pd_rpyc_ports": self.args.pd_tp_infer_rpyc_ports,  # can be left unset outside pd mode
+            "pd_rpyc_ports": self.args.pd_node_infer_rpyc_ports,  # can be left unset outside pd mode
         }

         await self.model_rpc_client.init_model(kvargs=kvargs)

@@ -416,7 +415,7 @@ def clean_up(self):
         return


-def start_router_process(args, router_port, detokenization_port, model_rpc_ports, metric_port, pipe_writer):
+def start_router_process(args, router_port, detokenization_port, metric_port, pipe_writer):
     # Register graceful shutdown handling.
     graceful_registry(inspect.currentframe().f_code.co_name)
     start_parent_check_thread()

@@ -426,7 +425,6 @@ def start_router_process(args, router_port, detokenization_port, model_rpc_ports
         args,
         router_port=router_port,
         detokenization_port=detokenization_port,
-        model_rpc_ports=model_rpc_ports,
         metric_port=metric_port,
     )
lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 2 additions & 2 deletions
@@ -100,8 +100,8 @@ def init_model(self, kvargs):

         # Global lock management added for the p d disaggregation mode, used for some synchronization
         # operations. It must be called after init_process_group.
-        g_infer_state_lock.obj = InferStateLock(name=get_unique_server_name())
-        g_infer_state_lock.dp_size = self.dp_size
+        g_infer_state_lock.obj = InferStateLock(name=get_unique_server_name(), rank_in_dp=self.rank_in_dp, dp_rank_in_node=self.dp_rank_in_node, dp_world_size=self.dp_world_size)
+        g_infer_state_lock.dp_world_size = self.dp_world_size
         self.infer_state_lock = g_infer_state_lock
         # Do one barrier wait to keep the shared state in InferStateLock from being re-initialized
         # repeatedly and abnormally, which would break synchronization.

lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/decode_node_impl/decode_impl.py

Lines changed: 11 additions & 4 deletions
@@ -18,6 +18,7 @@
 from lightllm.common.basemodel.infer_lock import g_infer_state_lock, g_router_lock
 from .decode_task_cache import g_success_kv_move_task_cache, KVMoveTask
 from lightllm.utils.device_utils import kv_trans_use_p2p
+from lightllm.utils.envs_utils import get_unique_server_name

 logger = init_logger(__name__)


@@ -29,10 +30,16 @@ def __init__(self, info_queue: mp.Queue, mem_queue: mp.Queue) -> None:
         self.mem_queue: mp.Queue = mem_queue

     def init_custom(self):
-        self.lock_nccl_group = dist.new_group(backend="gloo")
+        ranks = []
+        for i in range(self.dp_world_size):
+            ranks.append(i + self.global_dp_rank * self.dp_world_size)
+
+        self.lock_nccl_group = dist.new_group(ranks=ranks, backend="gloo")
+        logger.info(f"lock_nccl_group ranks {self.lock_nccl_group.get_rank()}")
+
         from .decode_infer_rpyc import PDDecodeInferRpcServer

-        socket_path = f"/tmp/decode_node_infer_rpyc_{self.pd_rpyc_ports[self.tp_rank]}"
+        socket_path = f"/tmp/{get_unique_server_name()}_decode_node_infer_rpyc_{self.pd_rpyc_ports[self.rank_in_node]}"
         if os.path.exists(socket_path):
             os.remove(socket_path)

@@ -141,8 +148,8 @@ def post_init(self, uninit_reqs: List[InferReq]):

         if self.is_master_in_dp:
             with g_router_lock.obj:
-                self.shared_token_load.add_frozened_token_count(-remove_count, self.tp_rank)
-                self.shared_token_load.add_estimated_peak_token_count(estimated_peak_token_count, self.tp_rank)
+                self.shared_token_load.add_frozened_token_count(-remove_count, self.dp_rank_in_node)
+                self.shared_token_load.add_estimated_peak_token_count(estimated_peak_token_count, self.dp_rank_in_node)
         return

     def filter_finished_reqs(self, finished_reqs: List[InferReq]):
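
init_custom now builds the gloo lock group only from the global ranks that belong to this dp group, rather than from every rank. A hedged sketch of that rank arithmetic (plain Python, assuming dp groups own contiguous blocks of global ranks, as the loop in the diff implies):

    def dp_group_ranks(global_dp_rank, dp_world_size):
        # dp group g owns the contiguous block [g * dp_world_size, (g + 1) * dp_world_size).
        return [i + global_dp_rank * dp_world_size for i in range(dp_world_size)]

    # e.g. with dp_world_size=2, dp group 3 maps to global ranks [6, 7]
    assert dp_group_ranks(3, 2) == [6, 7]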
