Skip to content

Commit 7061bfb

Browse files
authored
fix
1 parent 2afd14b commit 7061bfb

File tree

1 file changed

+37
-7
lines changed

1 file changed

+37
-7
lines changed

lightllm/utils/dist_utils.py

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,16 @@
22
import os
33
import torch
44

5+
# 规范 rank 的含义,在 llm 推理的相关代码中下述的 rank 的含义如下:
6+
# global_rank 全局 rank 序列id, 如两节点 8卡,会存在 0 - 15 16个global_rank
7+
# global_world_size 全局的 world size 大小, 如两节点 8 卡,该值为 16
8+
# dp_size 如果部署形态是一个推理实列包含几个数据并行的推理实列,则 dp size 为整个系统中的 dp 并行数量
9+
# dp_world_size 每一个dp 数据并行占用的卡数
10+
# dp_rank 指每个dp 数据并行在整个推理实列中dp的rank号, 如果 16卡部署,4 dp size, 则存在 0 - 3 4个dp_rank
11+
# 值,其中 0-3号卡为 dp_rank 0, 4-8 为 dp_rank 1, 9-12 为dp_rank 2, 13-15为dp_rank 3
12+
# rank_in_dp 指在一个dp内的rank序号。
13+
# node_world_size 指一个推理节点的使用的卡数,如两机 tp 推理,如果两机器8卡,则 node_world_size 为 8.
14+
# rank_in_node 指在一个node内的rank序号,如两机8卡推理,每机上的rank序号都是0-8
515

616
def set_environ(environ_name, value):
717
os.environ[environ_name] = str(value)
@@ -15,15 +25,19 @@ def get_environ(environ_name):
1525

1626

1727
def _init_distributed_env(kvargs):
28+
assert kvargs["world_size"] % kvargs["args"].nnodes == 0, "world_size should be divided by nnodes"
29+
node_world_size = kvargs["world_size"] // kvargs["args"].nnodes
30+
1831
set_global_rank(kvargs["rank_id"])
1932
set_global_world_size(kvargs["world_size"])
2033
set_dp_size(kvargs.get("dp_size", 1))
2134
set_dp_world_size(get_global_world_size() // get_dp_size())
2235
set_current_dp_rank(get_global_rank() // get_dp_world_size())
23-
set_current_dp_inner_rank(get_global_rank() % get_dp_world_size())
36+
set_current_rank_in_dp(get_global_rank() % get_dp_world_size())
37+
set_current_rank_in_node(get_global_rank() % node_world_size)
38+
set_node_world_size(node_world_size)
39+
2440

25-
assert kvargs["world_size"] % kvargs["args"].nnodes == 0, "world_size should be divided by nnodes"
26-
size_per_node = kvargs["world_size"] // kvargs["args"].nnodes
2741
device_id = kvargs["rank_id"] % size_per_node
2842
set_current_device_id(device_id)
2943
torch.cuda.set_device(device_id)
@@ -83,12 +97,12 @@ def get_current_dp_rank():
8397
return int(get_environ("LIGHTLLM_CURRENT_DP_RANK"))
8498

8599

86-
def set_current_dp_inner_rank(rank: int):
87-
set_environ("LIGHTLLM_CURRENT_DP_INNER_RANK", rank)
100+
def set_current_rank_in_dp(rank: int):
101+
set_environ("LIGHTLLM_CURRENT_RANK_IN_DP", rank)
88102

89103

90-
def get_current_dp_inner_rank():
91-
return get_environ("LIGHTLLM_CURRENT_DP_INNER_RANK")
104+
def get_current_rank_in_dp():
105+
return int(get_environ("LIGHTLLM_CURRENT_RANK_IN_DP"))
92106

93107

94108
def set_current_device_id(device_id: int):
@@ -97,3 +111,19 @@ def set_current_device_id(device_id: int):
97111

98112
def get_current_device_id():
99113
return int(get_environ("LIGHTLLM_CURRENT_DEVICE_ID"))
114+
115+
116+
def set_current_rank_in_node(rank:int):
117+
set_environ("LIGHTLLM_CURRENT_RANK_IN_NODE", rank)
118+
119+
120+
def get_current_rank_in_node():
121+
return int(get_environ("LIGHTLLM_CURRENT_RANK_IN_NODE"))
122+
123+
124+
def set_node_world_size(node_world_size: int):
125+
set_environ("LIGHTLLM_NODE_WORLD_SIZE", node_world_size)
126+
127+
128+
def get_node_world_size():
129+
return int(get_environ("LIGHTLLM_NODE_WORLD_SIZE"))

0 commit comments

Comments
 (0)