Commit fd49ab3

Improve the management of rank information.
1 parent ce4d3eb commit fd49ab3

File tree

12 files changed (+201, -187 lines)

lightllm/common/mem_manager.py

Lines changed: 16 additions & 20 deletions

@@ -8,7 +8,7 @@
 from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
 from lightllm.utils.profile_max_tokens import get_available_gpu_memory, get_total_gpu_memory
 from lightllm.common.kv_trans_kernel.kv_trans import kv_trans
-from lightllm.utils.dist_utils import get_global_rank
+from lightllm.utils.dist_utils import get_current_rank_in_node
 from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args
 
 
@@ -37,8 +37,8 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False
         # Shared via shared memory; the router module reads it for precise scheduling estimation. The nccl port marks a single instance on one machine, to avoid conflicts.
         from lightllm.utils.envs_utils import get_unique_server_name
 
-        rank_id = get_global_rank()
-        self.shared_can_use_token_num = SharedInt(f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_id}")
+        rank_in_node = get_current_rank_in_node()
+        self.shared_can_use_token_num = SharedInt(f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}")
 
         self.shared_can_use_token_num.set_value(self.can_use_mem_size)
         self._init_buffers(
@@ -303,20 +303,16 @@ class ReadOnlyStaticsMemoryManager:
     def __init__(self) -> None:
         args = get_env_start_args()
         self.global_world_size = args.tp
-        node_world_size = args.tp // args.nnodes
-        rank_start = args.node_rank * node_world_size
-        rank_end = (args.node_rank + 1) * node_world_size
-        self.shared_tp_infos = {
-            rank: SharedInt(f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank}")
-            for rank in range(rank_start, rank_end)
-        }
-
-    def get_unrefed_token_num(self, dp_rank: int):
-        args = get_env_start_args()
-        if args.dp == 1 and args.nnodes > 1:
-            # handle the multi-node case with dp size = 1
-            rank_id = args.tp // args.nnodes * args.node_rank
-            return self.shared_tp_infos[rank_id].get_value()
-        dp_size = args.dp
-        dp_world_size = self.global_world_size // dp_size
-        return self.shared_tp_infos[dp_rank * dp_world_size].get_value()
+        self.node_world_size = args.tp // args.nnodes
+        self.dp_world_size = self.global_world_size // args.dp
+        # handle the multi-node pure-tp mode with dp size = 1
+        self.is_multinode_tp = args.dp == 1 and args.nnodes > 1
+        self.shared_tp_infos = [
+            SharedInt(f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}")
+            for rank_in_node in range(0, self.node_world_size, self.dp_world_size)
+        ]
+
+    def get_unrefed_token_num(self, dp_rank_in_node: int):
+        if self.is_multinode_tp:
+            return self.shared_tp_infos[0].get_value()
+        return self.shared_tp_infos[dp_rank_in_node].get_value()
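Note on the hunk above: the per-token counter is now created per rank-in-node rather than per global rank, and ReadOnlyStaticsMemoryManager keeps one SharedInt per dp replica hosted on the local node, stepping through in-node ranks in strides of dp_world_size. Below is a minimal, self-contained sketch of that indexing only; FakeSharedInt and build_per_dp_counters are illustration-only stand-ins, not part of the codebase.

class FakeSharedInt:
    # stand-in for lightllm's SharedInt; it only records the shm segment name
    def __init__(self, name):
        self.name = name
        self.value = 0

    def get_value(self):
        return self.value


def build_per_dp_counters(server_name, tp, nnodes, dp):
    node_world_size = tp // nnodes  # tp ranks hosted on one node
    dp_world_size = tp // dp        # tp ranks inside one dp replica
    # one counter per dp replica on this node, named after the replica's first rank_in_node
    return [
        FakeSharedInt(f"{server_name}_mem_manger_can_use_token_num_{rank_in_node}")
        for rank_in_node in range(0, node_world_size, dp_world_size)
    ]


# tp=8, nnodes=1, dp=4 -> counters for rank_in_node 0, 2, 4, 6
print([c.name for c in build_per_dp_counters("demo", tp=8, nnodes=1, dp=4)])

With dp == 1 across several nodes, dp_world_size exceeds node_world_size, so the comprehension yields a single counter; that is why get_unrefed_token_num falls back to index 0 in the is_multinode_tp branch.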

lightllm/server/api_http.py

Lines changed: 2 additions & 1 deletion

@@ -101,7 +101,8 @@ def set_args(self, args):
             enable_multimodal=args.enable_multimodal,
             metric_port=args.metric_port,
         )
-        self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", args.dp)
+        dp_size_in_node = max(1, args.dp // args.nnodes)  # handle multi-node pure-tp mode, where 1 // 2 == 0
+        self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", dp_size_in_node)
 
 
 g_objs = G_Objs()
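Note: the max(1, args.dp // args.nnodes) guard above exists because integer division truncates. A tiny sketch of the same computation (the helper name dp_size_in_node is used here only for illustration):

def dp_size_in_node(dp: int, nnodes: int) -> int:
    # dp=1 with nnodes=2 would give 1 // 2 == 0 without the clamp
    return max(1, dp // nnodes)

assert dp_size_in_node(dp=1, nnodes=2) == 1   # multi-node pure-tp mode
assert dp_size_in_node(dp=4, nnodes=2) == 2   # two dp replicas per node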

lightllm/server/router/dynamic_prompt/radix_cache.py

Lines changed: 15 additions & 114 deletions

@@ -96,7 +96,7 @@ class RadixCache:
     unique_name is mainly used to avoid shm conflicts when multiple instances are deployed on one machine
     """
 
-    def __init__(self, unique_name, total_token_num, tp_id, mem_manager: MemoryManager = None):
+    def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager: MemoryManager = None):
         self.mem_manager = mem_manager
         self._key_dtype = torch.int64
         self._value_dtype = torch.int64
@@ -109,9 +109,9 @@ def __init__(self, unique_name, total_token_num, tp_id, mem_manager: MemoryManag
         self.evict_tree_set: Set[TreeNode] = SortedSet(key=lambda x: x.get_compare_key())  # custom comparator
         self.evict_tree_set.add(self.root_node)
 
-        self.refed_tokens_num = SharedArray(f"{unique_name}_refed_tokens_num_{tp_id}", (1,), dtype=np.int64)
+        self.refed_tokens_num = SharedArray(f"{unique_name}_refed_tokens_num_{rank_in_node}", (1,), dtype=np.int64)
         self.refed_tokens_num.arr[0] = 0
-        self.tree_total_tokens_num = SharedArray(f"{unique_name}_tree_total_tokens_num_{tp_id}", (1,), dtype=np.int64)
+        self.tree_total_tokens_num = SharedArray(f"{unique_name}_tree_total_tokens_num_{rank_in_node}", (1,), dtype=np.int64)
         self.tree_total_tokens_num.arr[0] = 0
 
     def insert(self, key, value=None):
@@ -345,9 +345,9 @@ class _RadixCacheReadOnlyClient:
     Read-only client on the router side, used to read tree information from shared memory for prompt cache scheduling estimation.
     """
 
-    def __init__(self, unique_name, total_token_num, tp_id):
-        self.refed_tokens_num = SharedArray(f"{unique_name}_refed_tokens_num_{tp_id}", (1,), dtype=np.int64)
-        self.tree_total_tokens_num = SharedArray(f"{unique_name}_tree_total_tokens_num_{tp_id}", (1,), dtype=np.int64)
+    def __init__(self, unique_name, total_token_num, rank_in_node):
+        self.refed_tokens_num = SharedArray(f"{unique_name}_refed_tokens_num_{rank_in_node}", (1,), dtype=np.int64)
+        self.tree_total_tokens_num = SharedArray(f"{unique_name}_tree_total_tokens_num_{rank_in_node}", (1,), dtype=np.int64)
 
     def get_refed_tokens_num(self):
         return self.refed_tokens_num.arr[0]
@@ -360,115 +360,16 @@ def get_unrefed_tokens_num(self):
 
 
 class RadixCacheReadOnlyClient:
-    def __init__(self, unique_name, total_token_num, tp_size):
-        self.tp_clients: List[_RadixCacheReadOnlyClient] = [
-            _RadixCacheReadOnlyClient(unique_name, total_token_num, tp_id) for tp_id in range(tp_size)
+    def __init__(self, unique_name, total_token_num, node_world_size, dp_world_size):
+        self.dp_rank_clients: List[_RadixCacheReadOnlyClient] = [
+            _RadixCacheReadOnlyClient(unique_name, total_token_num, rank_in_node) for rank_in_node in range(0, node_world_size, dp_world_size)
         ]
 
-    def get_refed_tokens_num(self, index):
-        return self.tp_clients[index].get_refed_tokens_num()
+    def get_refed_tokens_num(self, dp_rank_in_node):
+        return self.dp_rank_clients[dp_rank_in_node].get_refed_tokens_num()
 
-    def get_tree_total_tokens_num(self, index):
-        return self.tp_clients[index].get_tree_total_tokens_num()
+    def get_tree_total_tokens_num(self, dp_rank_in_node):
+        return self.dp_rank_clients[dp_rank_in_node].get_tree_total_tokens_num()
 
-    def get_unrefed_tokens_num(self, index):
-        return self.tp_clients[index].get_unrefed_tokens_num()
-
-
-# ///////////////////////////////////////////////////////////////////////////////
-
-if __name__ == "__main__":
-    # test 1
-    def test1():
-        tree = RadixCache("unique_name", 100, 0)
-        ans = tree.insert(torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int64, device="cpu"))
-        assert ans == 0
-        tree.print_self()
-        ans = tree.insert(torch.tensor([0, 1, 2, 3, 4, 7, 8, 9], dtype=torch.int64, device="cpu"))
-        assert ans == 5
-        tree.print_self()
-        ans = tree.insert(torch.tensor([0, 1, 2, 3, 4, 7, 8, 9], dtype=torch.int64, device="cpu"))
-        assert ans == 8
-        tree.print_self()
-
-        assert tree.get_refed_tokens_num() == 0
-        assert tree.get_tree_total_tokens_num() == 13
-
-        # print("evict")
-        tree.evict(9, lambda x: x)
-        tree.print_self()
-        assert tree.get_refed_tokens_num() == 0 and tree.get_tree_total_tokens_num() == 0
-
-    test1()
-
-    # test 2
-    def test2():
-        tree = RadixCache("unique_name", 100, 1)
-        ans = tree.insert(torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int64, device="cpu"))
-        ans = tree.insert(torch.tensor([0, 1, 2, 3, 4, 7, 8, 9], dtype=torch.int64, device="cpu"))
-        tree.print_self()
-
-        tree_node, size, values = tree.match_prefix(
-            torch.tensor([0, 1, 2, 3, 4], dtype=torch.int64, device="cpu"), update_refs=False
-        )
-        assert tree_node.node_prefix_total_len == 5 and size == 5 and len(values) == 5
-        tree_node, size, values = tree.match_prefix(
-            torch.tensor([0, 1, 2, 3, 4, 9], dtype=torch.int64, device="cpu"), update_refs=False
-        )
-        assert tree_node.node_prefix_total_len == 5 and size == 5 and len(values) == 5
-        tree_node, size, values = tree.match_prefix(
-            torch.tensor([0, 1, 2, 3, 4, 7, 8], dtype=torch.int64, device="cpu"), update_refs=False
-        )
-        assert tree_node.node_prefix_total_len == 7 and size == 7 and len(values) == 7
-        tree_node, size, values = tree.match_prefix(
-            torch.tensor([0, 1, 2, 3, 4, 7, 9], dtype=torch.int64, device="cpu"), update_refs=False
-        )
-        assert tree_node.node_prefix_total_len == 6 and size == 6 and len(values) == 6
-        print(ans)
-        return
-
-    # test2()
-
-    # test 3
-    def test3():
-        tree = RadixCache("unique_name", 100, 2)
-        ans = tree.insert(torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int64, device="cpu"))
-        ans = tree.insert(torch.tensor([0, 1, 2, 3, 4, 7, 8, 9], dtype=torch.int64, device="cpu"))
-        tree.print_self()
-
-        tree_node, size, values = tree.match_prefix(
-            torch.tensor([0, 1, 2, 3, 4], dtype=torch.int64, device="cpu"), update_refs=True
-        )
-        assert tree_node.node_prefix_total_len == 5 and size == 5 and len(values) == 5
-        assert tree.get_refed_tokens_num() == 5 and tree.get_tree_total_tokens_num() == 13
-
-        tree_node, size, values = tree.match_prefix(
-            torch.tensor([0, 1, 2, 3, 4, 7, 9], dtype=torch.int64, device="cpu"), update_refs=True
-        )
-        assert tree_node.node_prefix_total_len == 6 and size == 6 and len(values) == 6
-        assert tree.get_refed_tokens_num() == 6 and tree.get_tree_total_tokens_num() == 13
-
-        tree.print_self()
-        tree.evict(2, lambda x: x)
-        assert tree.get_refed_tokens_num() == 6 and tree.get_tree_total_tokens_num() == 8
-        tree.print_self()
-
-        tree.dec_node_ref_counter(tree_node)
-        tree.print_self()
-        print(ans)
-        return
-
-    test3()
-
-    def test4():
-
-        tree = RadixCache("unique_name", 100, 2)
-        ans = tree.insert(torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int64, device="cpu"))
-        ans = tree.insert(torch.tensor([0, 1, 2, 3, 4, 7, 8, 9], dtype=torch.int64, device="cpu"))
-        tree.print_self()
-
-        tree.clear_tree_nodes()
-        print(ans)
-        return
-
-    test4()
+    def get_unrefed_tokens_num(self, dp_rank_in_node):
+        return self.dp_rank_clients[dp_rank_in_node].get_unrefed_tokens_num()
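Note: RadixCacheReadOnlyClient is now indexed by the dp rank within the node instead of by tp id. A small sketch of the resulting mapping, using a plain dict in place of the SharedArray-backed clients (the helper dp_clients is illustration-only):

def dp_clients(node_world_size: int, dp_world_size: int):
    # entry i serves dp_rank_in_node == i and reads the shm segments created by
    # rank_in_node == i * dp_world_size
    return {
        dp_rank_in_node: rank_in_node
        for dp_rank_in_node, rank_in_node in enumerate(range(0, node_world_size, dp_world_size))
    }

# tp=8 on a single node with dp=4 -> the dp replicas start at in-node ranks 0, 2, 4, 6
print(dp_clients(node_world_size=8, dp_world_size=2))  # {0: 0, 1: 2, 2: 4, 3: 6}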

lightllm/server/router/manager.py

Lines changed: 15 additions & 9 deletions

@@ -43,9 +43,15 @@ def __init__(self, args, router_port, detokenization_port, model_rpc_ports, metr
         self.args = args
         self.model_weightdir = args.model_dir
         self.world_size = args.tp
+        self.node_world_size = self.world_size // args.nnodes
         self.nnodes = args.nnodes
         self.node_rank = args.node_rank
         self.dp_size = args.dp
+        # handle multi-node pure-tp mode, where 1 // 2 == 0
+        self.dp_size_in_node = max(1, args.dp // self.nnodes)
+        self.is_multinode_tp = args.nnodes > 1 and args.dp == 1
+        # whether conservative scheduling is in use; conservative scheduling never pauses requests, but may affect throughput in some scenarios
+        self.is_safe_schedule = args.router_token_ratio == 0.0
         self.load_way = args.load_way
         self.mode = args.mode
         self.max_total_token_num = args.max_total_token_num
@@ -56,8 +62,8 @@ def __init__(self, args, router_port, detokenization_port, model_rpc_ports, metr
         self.radix_cache_client = None
 
         # shared variables storing the machine load information produced by router-side scheduling analysis
-        self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", self.dp_size)
-        for dp_index in range(self.dp_size):
+        self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", self.dp_size_in_node)
+        for dp_index in range(self.dp_size_in_node):
             self.shared_token_load.set_estimated_peak_token_count(0, dp_index)
             self.shared_token_load.set_frozened_token_count(0, dp_index)
             self.shared_token_load.set_current_load(0.0, dp_index)
@@ -77,7 +83,7 @@ def __init__(self, args, router_port, detokenization_port, model_rpc_ports, metr
         self.send_to_detokenization.connect(f"{args.zmq_mode}127.0.0.1:{detokenization_port}")
         self.model_rpc_ports = model_rpc_ports
 
-        if args.nnodes > 1 and args.dp == 1:
+        if self.is_multinode_tp:
             self.mulitnode_group = dist.init_process_group(
                 backend="gloo",
                 init_method=f"tcp://{args.nccl_host}:{args.multinode_router_gloo_port}",
@@ -177,9 +183,9 @@ async def wait_to_model_ready(self):
         self.args.max_total_token_num = self.max_total_token_num
         if self.args.use_dynamic_prompt_cache:
             self.radix_cache_client = RadixCacheReadOnlyClient(
-                get_unique_server_name(), self.max_total_token_num, tp_size=self.world_size
+                get_unique_server_name(), self.max_total_token_num, node_world_size=self.node_world_size, dp_world_size=self.world_size // self.dp_size
             )
-        self.req_queue = build_req_queue(self.args, self, self.dp_size)
+        self.req_queue = build_req_queue(self.args, self, self.dp_size_in_node)
         logger.info(f"use req queue {self.req_queue.__class__.__name__}")
 
         if self.args.run_mode == "prefill":
@@ -223,7 +229,7 @@ async def loop_for_fwd(
             counter_count += 1
             if self.running_batch is not None:
                 if counter_count % 50 == 0:
-                    for dp_index in range(self.dp_size):
+                    for dp_index in range(self.dp_size_in_node):
                         token_ratio1 = self.get_used_tokens(dp_index) / self.max_total_token_num
                         token_ratio2 = (
                             self.max_total_token_num
@@ -244,7 +250,7 @@ async def loop_for_fwd(
                     self.metric_client.gauge_set(
                         "lightllm_batch_current_max_tokens",
                         int(
-                            sum(self.shared_token_load.get_dynamic_max_load(d_i) for d_i in range(self.dp_size))
+                            sum(self.shared_token_load.get_dynamic_max_load(d_i) for d_i in range(self.dp_size_in_node))
                            * self.max_total_token_num
                        ),
                    )
@@ -264,7 +270,7 @@ async def get_schedule_result(self, running_batch: Batch):
 
         def get_new_batch():
             limit_router_queue_length = None
-            if self.nnodes > 1 and self.args.dp == 1:
+            if self.is_multinode_tp:
                 # use all_reduce to obtain the minimum value
                 limit_router_queue_length = len(self.req_queue.waiting_req_list)
                 limit_router_queue_length_tensor = torch.tensor(
@@ -381,7 +387,7 @@ def _can_decode(self, batch: Batch):
         # In p/d disaggregated mode, only conservative scheduling can currently be used, so that when a request
        # is admitted for decode, the KV-cache tokens are guaranteed to be sufficient.
        # In deepseekv2 dp mode, conservative scheduling is also used, so tokens are likewise guaranteed to be sufficient.
-        if self.is_pd_run_mode or self.dp_size > 1:
+        if self.is_pd_run_mode or self.dp_size > 1 or self.is_safe_schedule:
             return True
 
         # the checks below are only enabled when dp == 1
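Note: the router now derives several per-node quantities and flags from the start args. The sketch below condenses that derivation into one place, assuming the same argument names as the diff (tp, dp, nnodes, router_token_ratio); RouterRankInfo is a hypothetical holder, not a class from the repository.

from dataclasses import dataclass

@dataclass
class RouterRankInfo:
    tp: int
    dp: int
    nnodes: int
    router_token_ratio: float

    @property
    def node_world_size(self) -> int:
        return self.tp // self.nnodes

    @property
    def dp_size_in_node(self) -> int:
        # multi-node pure-tp mode gives 1 // 2 == 0, so clamp to 1
        return max(1, self.dp // self.nnodes)

    @property
    def is_multinode_tp(self) -> bool:
        return self.nnodes > 1 and self.dp == 1

    @property
    def is_safe_schedule(self) -> bool:
        # conservative scheduling never pauses requests but may cost some throughput
        return self.router_token_ratio == 0.0

info = RouterRankInfo(tp=16, dp=1, nnodes=2, router_token_ratio=0.0)
print(info.node_world_size, info.dp_size_in_node, info.is_multinode_tp, info.is_safe_schedule)
# -> 8 1 True True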

lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 13 additions & 9 deletions

@@ -38,13 +38,14 @@
 from lightllm.server.router.model_infer.infer_batch import InferReq, InferSamplingParams
 from lightllm.server.router.token_load import TokenLoad
 from lightllm.common.basemodel.infer_lock import g_infer_state_lock, InferStateLock
-from lightllm.utils.dist_utils import _init_distributed_env
+from lightllm.utils.dist_utils import init_distributed_env
 from lightllm.utils.envs_utils import get_unique_server_name
 from lightllm.server.core.objs import ShmReqManager
 from lightllm.server.router.model_infer.infer_batch import g_infer_context
 from lightllm.utils.dist_utils import get_global_rank, get_global_world_size, get_dp_size
-from lightllm.utils.dist_utils import get_dp_world_size, get_current_dp_rank, get_current_rank_in_dp
+from lightllm.utils.dist_utils import get_dp_world_size, get_global_dp_rank, get_current_rank_in_dp
 from lightllm.utils.dist_utils import get_current_device_id, get_current_rank_in_node, get_node_world_size
+from lightllm.utils.dist_utils import get_dp_rank_in_node
 import torch.distributed as dist
 
 
@@ -64,6 +65,8 @@ def init_model(self, kvargs):
         self.tp_rank = kvargs["rank_id"]
         self.world_size = kvargs["world_size"]
         self.dp_size = kvargs.get("dp_size", 1)
+        # the dp_size_in_node computation handles multi-node pure-tp mode, where 1 // 2 == 0
+        self.dp_size_in_node = max(1, self.dp_size // self.nnodes)
         self.load_way = kvargs["load_way"]
         self.mode = kvargs["mode"]
         self.enable_chunked_prefill = kvargs.get("enable_chunked_prefill", False)
@@ -85,10 +88,10 @@ def init_model(self, kvargs):
             assert self.dp_size == self.world_size, "Currently only self-sustaining dp_size == tp_size"
             os.environ["ENABLE_DP"] = "1"
 
-        _init_distributed_env(kvargs)
+        init_distributed_env(kvargs)
         self.init_rank_infos()
 
-        self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", self.dp_size)
+        self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", self.dp_size_in_node)
 
         from lightllm.distributed import custom_comm_ops
 
@@ -239,17 +242,17 @@ def decode(self):
         raise NotImplementedError()
 
     def pause_reqs(self, req_ids):
-        if self.dp_size != 1:
+        if self.dp_size_in_node != 1:
             req_ids = [req_id for req_id in req_ids if req_id in g_infer_context.requests_mapping]
 
         g_infer_context.pause_reqs(req_ids)
         return
 
     # some reusable unit helper functions
     def _init_reqs(self, reqs: List[Tuple], init_req_obj=True):
-        if self.dp_size != 1:
-            cur_dp_index = self.tp_rank
-            reqs = [req for req in reqs if req[3] == cur_dp_index]
+        if self.dp_size_in_node != 1:
+            dp_rank_in_node = self.dp_rank_in_node
+            reqs = [req for req in reqs if req[3] == dp_rank_in_node]
 
         g_infer_state_lock.acquire()
         g_infer_context.add_reqs(reqs, init_req_obj=init_req_obj)
@@ -280,7 +283,8 @@ def init_rank_infos(self):
         self.rank_in_node = get_current_rank_in_node()
         self.current_device_id = get_current_device_id()
         self.rank_in_dp = get_current_rank_in_dp()
-        self.dp_rank = get_current_dp_rank()
+        self.global_dp_rank = get_global_dp_rank()
+        self.dp_rank_in_node = get_dp_rank_in_node()
         self.dp_world_size = get_dp_world_size()
         self.global_rank = get_global_rank()
         self.global_world_size = get_global_world_size()
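Note: init_rank_infos now records both a global dp rank and a dp rank within the node. The arithmetic below shows how these views could relate under the usual layout assumption (global ranks contiguous per node and per dp replica); the real relationships are encapsulated by the get_* helpers imported in the diff, so this is an illustration only.

def rank_views(global_rank: int, tp: int, dp: int, nnodes: int):
    node_world_size = tp // nnodes
    dp_world_size = tp // dp
    return {
        "rank_in_node": global_rank % node_world_size,
        "rank_in_dp": global_rank % dp_world_size,
        "global_dp_rank": global_rank // dp_world_size,
        "dp_rank_in_node": (global_rank % node_world_size) // dp_world_size,
    }

# tp=8, dp=4, nnodes=2: global rank 5 is the second rank of global dp replica 2,
# which is the first replica hosted on node 1
print(rank_views(global_rank=5, tp=8, dp=4, nnodes=2))
# -> {'rank_in_node': 1, 'rank_in_dp': 1, 'global_dp_rank': 2, 'dp_rank_in_node': 0}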

lightllm/server/router/req_queue/__init__.py

Lines changed: 4 additions & 4 deletions

@@ -5,7 +5,7 @@
 from .dp_base_queue import DpQueue
 
 
-def build_req_queue(args, router, dp_size: int):
+def build_req_queue(args, router, dp_size_in_node: int):
     queue_class = None
     if args.run_mode == "decode":
         queue_class = ContinuesBatchQueueForPDDecode
@@ -22,7 +22,7 @@ def build_req_queue(args, router, dp_size: int):
     if queue_class is None:
         queue_class = ContinuesBatchQueue
 
-    if dp_size == 1:
-        return queue_class(args, router, 0, dp_size)
+    if dp_size_in_node == 1:
+        return queue_class(args, router, 0, dp_size_in_node)
     else:
-        return DpQueue(args, router, queue_class, dp_size)
+        return DpQueue(args, router, queue_class, dp_size_in_node)
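Note: build_req_queue keeps its dispatch shape; only the meaning of the size argument changes to "dp replicas on this node". A sketch of the pattern with hypothetical stand-ins (SimpleQueue, PerDpQueue) for the real queue classes:

class SimpleQueue:
    # stand-in for a single request queue serving one dp replica
    def __init__(self, dp_rank_in_node: int):
        self.dp_rank_in_node = dp_rank_in_node

class PerDpQueue:
    # stand-in for DpQueue: one inner queue per dp replica hosted on this node
    def __init__(self, dp_size_in_node: int):
        self.inner = [SimpleQueue(i) for i in range(dp_size_in_node)]

def build_queue(dp_size_in_node: int):
    if dp_size_in_node == 1:
        return SimpleQueue(0)
    return PerDpQueue(dp_size_in_node)

print(type(build_queue(1)).__name__)  # SimpleQueue
print(type(build_queue(4)).__name__)  # PerDpQueue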
