
Commit ba6e7f3

hiworldwzj, jayfeather9, root, shihaobai, and root authored
Multinode (#752)
Co-authored-by: wufeiyang <[email protected]>
Co-authored-by: root <root@pt-290ac8041d114af0b1647509a5544872-master-0.pt-290ac8041d114af0b1647509a5544872.ns-devoversea-d41e68bd.svc.cluster.local>
Co-authored-by: shihaobai <[email protected]>
Co-authored-by: root <root@pt-511f450a52c24c2d9df9b20f0c8ebdb7-master-0.pt-511f450a52c24c2d9df9b20f0c8ebdb7.ns-devoversea-d41e68bd.svc.cluster.local>
Co-authored-by: Feiyang Wu <[email protected]>
1 parent d403ad7 commit ba6e7f3

40 files changed: +573 -449 lines changed

lightllm/common/basemodel/infer_lock.py

Lines changed: 18 additions & 15 deletions
@@ -17,28 +17,32 @@


class InferStateLock:
-    def __init__(self, name):
+    def __init__(self, name, rank_in_dp: int, dp_rank_in_node: int, dp_world_size: int):
        self.infer_lock = threading.Lock()
+        self.dp_rank_in_node = dp_rank_in_node
+        # sync_world_size should be min(dp_world_size, node_world_size)
+        self.dp_world_size = dp_world_size
+        self.rank_in_dp = rank_in_dp
        # reserve room for 128 tp by default; no card setup today should need a tp that large
-        self.lock_tp_infos = SharedArray(f"{name}_lock_tp_infos", shape=(129,), dtype=np.int64)
+        self.lock_tp_infos = SharedArray(
+            f"{name}_dp_rank_{str(self.dp_rank_in_node)}_lock_tp_infos", shape=(self.dp_world_size + 1,), dtype=np.int64
+        )
        self.lock_tp_infos.arr[:] = 0
-        self.rank_id = dist.get_rank()
-        self.world_size = dist.get_world_size()

    def add_cur_mark(self):
-        self.lock_tp_infos.arr[self.rank_id] += 1
+        self.lock_tp_infos.arr[self.rank_in_dp] += 1

    def get_cur_mark(self):
-        return self.lock_tp_infos.arr[self.rank_id]
+        return self.lock_tp_infos.arr[self.rank_in_dp]

    def get_max_mark_in_group(self):
-        return np.max(self.lock_tp_infos.arr[0 : self.world_size])
+        return np.max(self.lock_tp_infos.arr[0 : self.dp_world_size])

    def judge_cur_mark_equal_max_mark_in_group(self):
        return self.get_cur_mark() == self.get_max_mark_in_group()

    def judge_mark_in_group_all_same(self):
-        marks = self.lock_tp_infos.arr[0 : self.world_size]
+        marks = self.lock_tp_infos.arr[0 : self.dp_world_size]
        return bool(np.all(marks == marks[0]))

    def acquire_lock_and_update_cur_mark(self):

@@ -49,11 +53,11 @@ def release_lock(self):
        self.infer_lock.release()

    def set_group_wait_mark(self):
-        if self.rank_id == 0:
+        if self.rank_in_dp == 0:
            self.lock_tp_infos.arr[-1] = 1

    def unset_group_wait_mark(self):
-        if self.rank_id == 0:
+        if self.rank_in_dp == 0:
            self.lock_tp_infos.arr[-1] = 0

    def get_group_wait_mark(self):

@@ -63,7 +67,7 @@ def get_group_wait_mark(self):
@dataclass
class G_Infer_Lock:
    obj: InferStateLock = None
-    dp_size: int = None
+    dp_world_size: int = None

    def acquire(self):
        if self.obj is not None:

@@ -86,9 +90,8 @@ def release(self):

# the two functions below must be used as a pair
def acquire_lock_until_ready(nccl_group):
-    # in deepseekv2's mixed tp/dp mode there is no need to coordinate and synchronize
-    # across multiple inference processes, so a plain lock/unlock is enough
-    if g_infer_state_lock.dp_size != 1:
+    # with a single-card tp group there is no need for extra locking
+    if g_infer_state_lock.dp_world_size == 1:
        g_infer_state_lock.obj.infer_lock.acquire()
        return

@@ -118,7 +121,7 @@ def release_acquired_lock():
@dataclass
class G_Router_Lock:
    """
-    protects operations on some data in pd-separation mode
+    protects operations on scheduling-related data in pd-separation mode
    """

    obj = None  # process lock object
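
The reworked InferStateLock keeps one mark slot per rank inside a DP group plus a trailing group-wait flag, instead of one slot per global rank. A minimal sketch of that mark protocol, using a plain numpy array as a stand-in for lightllm's SharedArray (the sizes and names below are illustrative, not from the commit):

import numpy as np

dp_world_size = 4
# dp_world_size mark slots plus one trailing group-wait flag, matching the new layout
lock_tp_infos = np.zeros(dp_world_size + 1, dtype=np.int64)

def add_cur_mark(rank_in_dp: int) -> None:
    # each rank only ever bumps its own slot
    lock_tp_infos[rank_in_dp] += 1

def marks_all_same() -> bool:
    marks = lock_tp_infos[0:dp_world_size]
    return bool(np.all(marks == marks[0]))

add_cur_mark(0)
add_cur_mark(1)
print(marks_all_same())  # False: ranks 2 and 3 have not reached the same point yet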

lightllm/common/deepseek2_mem_manager.py

Lines changed: 10 additions & 8 deletions
@@ -40,9 +40,9 @@ def alloc_kv_move_buffer(self, max_req_total_len):
        return

    def send_to_decode_node(
-        self, move_tasks: List[KVMoveTask], mem_managers: List["Deepseek2MemoryManager"], dp_size: int
+        self, move_tasks: List[KVMoveTask], mem_managers: List["Deepseek2MemoryManager"], dp_size_in_node: int
    ):
-        assert dp_size == 1
+        assert dp_size_in_node == 1

        # first stage the data into the buffer on one designated card, then send it.
        move_token_indexes = []

@@ -66,9 +66,9 @@ def _get_kv_move_data(self, token_indexes: List[int], layer_index: int):
        return move_buffer

    def receive_from_prefill_node(
-        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size: int
+        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
    ):
-        assert dp_size == 1
+        assert dp_size_in_node == 1

        # first receive the data into the buffer on one designated card, then copy it to the other cards.
        move_token_indexes = []

@@ -97,11 +97,13 @@ def _write_kv_move_data(self, token_indexes: torch.Tensor, buffer_tensor: torch.
        self.kv_buffer[layer_index : layer_index + 1, token_indexes, :, :] = buffer_tensor
        return

-    def send_to_decode_node_p2p(self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size: int):
+    def send_to_decode_node_p2p(
+        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
+    ):
        """
        implementation that copies and transfers the data using a p2p triton kernel.
        """
-        assert dp_size == 1
+        assert dp_size_in_node == 1

        move_token_indexes = []
        for task in move_tasks:

@@ -124,9 +126,9 @@ def _get_kv_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, k
        return move_buffer

    def receive_from_prefill_node_p2p(
-        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size: int
+        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
    ):
-        assert dp_size == 1
+        assert dp_size_in_node == 1

        move_token_indexes = []
        for task in move_tasks:
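
Both the plain and the p2p KV-move paths still assume a single DP group per node; renaming dp_size to dp_size_in_node only makes the unit explicit. A hedged sketch of that constraint, with dp_size_in_node derived the same way api_http.py does below and a helper name invented purely for illustration:

def check_kv_move_supported(dp: int, nnodes: int) -> None:
    # hypothetical helper mirroring the asserts added in this file
    dp_size_in_node = max(1, dp // nnodes)
    assert dp_size_in_node == 1, "KV move between prefill and decode nodes assumes one DP group per node"

check_kv_move_supported(dp=1, nnodes=2)  # ok: multinode pure tp
check_kv_move_supported(dp=2, nnodes=2)  # ok: one DP group per node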

lightllm/common/mem_manager.py

Lines changed: 30 additions & 38 deletions
@@ -8,7 +8,7 @@
from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
from lightllm.utils.profile_max_tokens import get_available_gpu_memory, get_total_gpu_memory
from lightllm.common.kv_trans_kernel.kv_trans import kv_trans
-from lightllm.utils.dist_utils import get_global_rank
+from lightllm.utils.dist_utils import get_current_rank_in_node
from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args


@@ -37,8 +37,10 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False
        # shared via shared memory; the router module reads it for precise scheduling estimates. the nccl port acts as a per-instance marker within a single machine to avoid conflicts.
        from lightllm.utils.envs_utils import get_unique_server_name

-        rank_id = get_global_rank()
-        self.shared_can_use_token_num = SharedInt(f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_id}")
+        rank_in_node = get_current_rank_in_node()
+        self.shared_can_use_token_num = SharedInt(
+            f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}"
+        )

        self.shared_can_use_token_num.set_value(self.can_use_mem_size)
        self._init_buffers(

@@ -83,13 +85,10 @@ def alloc_kv_move_buffer(self, max_req_total_len):
        self.kv_move_buf_indexes = torch.arange(0, max_req_total_len + 8, dtype=torch.int64, device="cuda")
        return

-    def send_to_decode_node(self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size: int):
-        """
-        dp_size is a parameter tailored for deepseekv2-style models that can run in mixed dp and tp mode;
-        in plain tp mode dp_size is always 1 and dp_index is always 0, and in plain mode these two
-        parameters are never actually used
-        """
-        assert dp_size == 1
+    def send_to_decode_node(
+        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
+    ):
+        assert dp_size_in_node == 1

        # first stage the data into the buffer on one designated card, then send it.

@@ -123,14 +122,9 @@ def _get_kv_move_data(self, token_indexes: List[int], layer_index: int):
        return move_buffer

    def receive_from_prefill_node(
-        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size: int
+        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
    ):
-        """
-        dp_size is a parameter tailored for deepseekv2-style models that can run in mixed dp and tp mode;
-        in plain tp mode dp_size is always 1, and in plain mode these two parameters are never
-        actually used
-        """
-        assert dp_size == 1
+        assert dp_size_in_node == 1

        # first receive the data into the buffer on one designated card, then copy it to the other cards.

@@ -160,11 +154,13 @@ def _write_kv_move_data(self, token_indexes: torch.Tensor, buffer_tensor: torch.
        self.kv_buffer[layer_index : layer_index + 1, token_indexes, :, :] = buffer_tensor
        return

-    def send_to_decode_node_p2p(self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size: int):
+    def send_to_decode_node_p2p(
+        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
+    ):
        """
        implementation that copies and transfers the data using a p2p triton kernel.
        """
-        assert dp_size == 1
+        assert dp_size_in_node == 1

        # first stage the data into the buffer on one designated card, then send it.

@@ -190,9 +186,9 @@ def _get_kv_move_data_p2p(self, token_indexes: torch.Tensor, layer_index: int, k
        return move_buffer

    def receive_from_prefill_node_p2p(
-        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size: int
+        self, move_tasks: List[KVMoveTask], mem_managers: List["MemoryManager"], dp_size_in_node: int
    ):
-        assert dp_size == 1
+        assert dp_size_in_node == 1

        # first receive the data into the buffer on one designated card, then copy it to the other cards.

@@ -303,20 +299,16 @@ class ReadOnlyStaticsMemoryManager:
    def __init__(self) -> None:
        args = get_env_start_args()
        self.global_world_size = args.tp
-        node_world_size = args.tp // args.nnodes
-        rank_start = args.node_rank * node_world_size
-        rank_end = (args.node_rank + 1) * node_world_size
-        self.shared_tp_infos = {
-            rank: SharedInt(f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank}")
-            for rank in range(rank_start, rank_end)
-        }
-
-    def get_unrefed_token_num(self, dp_rank: int):
-        args = get_env_start_args()
-        if args.dp == 1 and args.nnodes > 1:
-            # compatibility for the multinode dp size=1 case
-            rank_id = args.tp // args.nnodes * args.node_rank
-            return self.shared_tp_infos[rank_id].get_value()
-        dp_size = args.dp
-        dp_world_size = self.global_world_size // dp_size
-        return self.shared_tp_infos[dp_rank * dp_world_size].get_value()
+        self.node_world_size = args.tp // args.nnodes
+        self.dp_world_size = self.global_world_size // args.dp
+        # compatibility for the multinode dp size=1 pure-tp case
+        self.is_multinode_tp = args.dp == 1 and args.nnodes > 1
+        self.shared_tp_infos = [
+            SharedInt(f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}")
+            for rank_in_node in range(0, self.node_world_size, self.dp_world_size)
+        ]
+
+    def get_unrefed_token_num(self, dp_rank_in_node: int):
+        if self.is_multinode_tp:
+            return self.shared_tp_infos[0].get_value()
+        return self.shared_tp_infos[dp_rank_in_node].get_value()
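
With the shared counters now keyed by in-node rank, ReadOnlyStaticsMemoryManager opens one counter per DP group on the local node, at in-node ranks 0, dp_world_size, 2 * dp_world_size, and so on. A small worked example with assumed launch arguments (the numbers are not from the commit), just to show which in-node ranks carry a counter:

tp, nnodes, dp = 16, 2, 4       # assumed launch arguments
node_world_size = tp // nnodes  # 8 ranks per node
dp_world_size = tp // dp        # 4 ranks per DP group

counter_ranks_in_node = list(range(0, node_world_size, dp_world_size))
print(counter_ranks_in_node)    # [0, 4]: one counter for dp_rank_in_node 0 and one for 1

# get_unrefed_token_num(dp_rank_in_node) indexes this list directly, except in the
# multinode dp == 1 pure-tp case, where entry 0 is always used.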

lightllm/server/api_http.py

Lines changed: 2 additions & 1 deletion
@@ -101,7 +101,8 @@ def set_args(self, args):
            enable_multimodal=args.enable_multimodal,
            metric_port=args.metric_port,
        )
-        self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", args.dp)
+        dp_size_in_node = max(1, args.dp // args.nnodes)  # compat for multinode pure-tp mode, where 1 // 2 == 0 must be handled
+        self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", dp_size_in_node)


g_objs = G_Objs()
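
The max(1, ...) guard matters because in multinode pure-tp mode args.dp // args.nnodes truncates to zero, which would size the shared token-load state to nothing. A tiny worked example of the same expression, wrapped in a function only for illustration:

def dp_size_in_node(dp: int, nnodes: int) -> int:
    # same expression as in set_args above; the function wrapper is illustrative only
    return max(1, dp // nnodes)

print(dp_size_in_node(1, 2))  # 1: multinode pure tp, where 1 // 2 == 0 without the guard
print(dp_size_in_node(8, 2))  # 4: eight DP groups spread across two nodes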

lightllm/server/api_start.py

Lines changed: 5 additions & 5 deletions
@@ -165,13 +165,13 @@ def normal_or_p_d_start(args):
    ports_locker = PortLocker(already_uesd_ports)
    ports_locker.lock_port()

+    node_world_size = args.tp // args.nnodes
    can_use_ports = alloc_can_use_network_port(
-        num=6 + args.tp + args.tp + args.visual_dp * args.visual_tp, used_nccl_ports=already_uesd_ports
+        num=6 + node_world_size + args.visual_dp * args.visual_tp, used_nccl_ports=already_uesd_ports
    )
    logger.info(f"alloced ports: {can_use_ports}")
    router_port, detokenization_port, detokenization_pub_port, visual_port, cache_port, metric_port = can_use_ports[0:6]
-    model_rpc_ports = can_use_ports[6 : 6 + args.tp]
-    can_use_ports = can_use_ports[6 + args.tp :]
+    can_use_ports = can_use_ports[6:]

    visual_model_tp_ports = []
    for _ in range(args.visual_dp):

@@ -188,7 +188,7 @@ def normal_or_p_d_start(args):
    args.metric_port = metric_port

    # allocate the ports that will be used in p/d separation mode
-    args.pd_tp_infer_rpyc_ports = can_use_ports[0 : args.tp]
+    args.pd_node_infer_rpyc_ports = can_use_ports[0:node_world_size]
    # id used to identify a node in p/d separation mode
    args.pd_node_id = uuid.uuid4().int
    # port range the p node can use to build the torch kv transfer process group

@@ -231,7 +231,7 @@ def normal_or_p_d_start(args):
    process_manager.start_submodule_processes(
        start_funcs=[start_router_process, start_detokenization_process],
        start_args=[
-            (args, router_port, detokenization_port, model_rpc_ports, metric_port),
+            (args, router_port, detokenization_port, metric_port),
            (args, detokenization_port, detokenization_pub_port),
        ],
    )
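
Each node now asks only for the 6 fixed service ports, node_world_size rpyc ports for p/d separation, and the visual-model ports; the old per-rank model_rpc_ports block is gone, so the router no longer receives it. A rough sketch of the new budget and slicing, with made-up numbers and a plain list standing in for alloc_can_use_network_port:

tp, nnodes = 16, 2              # assumed launch arguments
visual_dp, visual_tp = 1, 2
node_world_size = tp // nnodes  # 8

num_ports = 6 + node_world_size + visual_dp * visual_tp
can_use_ports = list(range(10000, 10000 + num_ports))  # stand-in for the real allocator

router, detok, detok_pub, visual, cache, metric = can_use_ports[0:6]
can_use_ports = can_use_ports[6:]
pd_node_infer_rpyc_ports = can_use_ports[0:node_world_size]
visual_model_ports = can_use_ports[node_world_size:]
print(len(pd_node_infer_rpyc_ports), len(visual_model_ports))  # 8 2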

lightllm/server/router/batch.py

Lines changed: 3 additions & 3 deletions
@@ -8,11 +8,11 @@


class Batch:
-    def __init__(self, batch_id, reqs: List[Req], dp_size: int):
+    def __init__(self, batch_id, reqs: List[Req], dp_size_in_node: int):
        self.batch_id = batch_id
        self.reqs = reqs
        self.id_to_reqs = {req.request_id: req for req in reqs}
-        self.dp_size = dp_size
+        self.dp_size_in_node = dp_size_in_node
        return

    def input_tokens(self):

@@ -22,7 +22,7 @@ def input_tokens(self):
        return batch_input_tokens

    def get_batch_decode_need_tokens(self):
-        new_batch_decode_need_tokens = [0 for _ in range(self.dp_size)]  # for chunked prefill
+        new_batch_decode_need_tokens = [0 for _ in range(self.dp_size_in_node)]  # for chunked prefill

        for req in self.reqs:
            req_dp_index = req.sample_params.suggested_dp_index
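
get_batch_decode_need_tokens now sizes its accumulator by the number of DP groups on the node and buckets each request by its suggested DP index. A toy illustration with dicts standing in for lightllm's Req objects; counting one token per request here glosses over the real per-request accounting:

dp_size_in_node = 2
new_batch_decode_need_tokens = [0 for _ in range(dp_size_in_node)]  # for chunked prefill

reqs = [{"suggested_dp_index": 0}, {"suggested_dp_index": 1}, {"suggested_dp_index": 1}]
for req in reqs:
    req_dp_index = req["suggested_dp_index"]
    new_batch_decode_need_tokens[req_dp_index] += 1  # toy count: one pending token per request

print(new_batch_decode_need_tokens)  # [1, 2]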
