22import os
33import torch
44
5+ # 规范 rank 的含义,在 llm 推理的相关代码中下述的 rank 的含义如下:
6+ # global_rank 全局 rank 序列id, 如两节点 8卡,会存在 0 - 15 16个global_rank
7+ # global_world_size 全局的 world size 大小, 如两节点 8 卡,该值为 16
8+ # dp_size 如果部署形态是一个推理实列包含几个数据并行的推理实列,则 dp size 为整个系统中的 dp 并行数量
9+ # dp_world_size 每一个dp 数据并行占用的卡数
10+ # dp_rank 指每个dp 数据并行在整个推理实列中dp的rank号, 如果 16卡部署,4 dp size, 则存在 0 - 3 4个dp_rank
11+ # 值,其中 0-3号卡为 dp_rank 0, 4-8 为 dp_rank 1, 9-12 为dp_rank 2, 13-15为dp_rank 3
12+ # rank_in_dp 指在一个dp内的rank序号。
13+ # node_world_size 指一个推理节点的使用的卡数,如两机 tp 推理,如果两机器8卡,则 node_world_size 为 8.
14+ # rank_in_node 指在一个node内的rank序号,如两机8卡推理,每机上的rank序号都是0-8
515
616def set_environ (environ_name , value ):
717 os .environ [environ_name ] = str (value )
@@ -15,15 +25,19 @@ def get_environ(environ_name):
1525
1626
1727def _init_distributed_env (kvargs ):
28+ assert kvargs ["world_size" ] % kvargs ["args" ].nnodes == 0 , "world_size should be divided by nnodes"
29+ node_world_size = kvargs ["world_size" ] // kvargs ["args" ].nnodes
30+
1831 set_global_rank (kvargs ["rank_id" ])
1932 set_global_world_size (kvargs ["world_size" ])
2033 set_dp_size (kvargs .get ("dp_size" , 1 ))
2134 set_dp_world_size (get_global_world_size () // get_dp_size ())
2235 set_current_dp_rank (get_global_rank () // get_dp_world_size ())
23- set_current_dp_inner_rank (get_global_rank () % get_dp_world_size ())
36+ set_current_rank_in_dp (get_global_rank () % get_dp_world_size ())
37+ set_current_rank_in_node (get_global_rank () % node_world_size )
38+ set_node_world_size (node_world_size )
39+
2440
25- assert kvargs ["world_size" ] % kvargs ["args" ].nnodes == 0 , "world_size should be divided by nnodes"
26- size_per_node = kvargs ["world_size" ] // kvargs ["args" ].nnodes
2741 device_id = kvargs ["rank_id" ] % size_per_node
2842 set_current_device_id (device_id )
2943 torch .cuda .set_device (device_id )
@@ -83,12 +97,12 @@ def get_current_dp_rank():
8397 return int (get_environ ("LIGHTLLM_CURRENT_DP_RANK" ))
8498
8599
86- def set_current_dp_inner_rank (rank : int ):
87- set_environ ("LIGHTLLM_CURRENT_DP_INNER_RANK " , rank )
100+ def set_current_rank_in_dp (rank : int ):
101+ set_environ ("LIGHTLLM_CURRENT_RANK_IN_DP " , rank )
88102
89103
90- def get_current_dp_inner_rank ():
91- return get_environ ("LIGHTLLM_CURRENT_DP_INNER_RANK" )
104+ def get_current_rank_in_dp ():
105+ return int ( get_environ ("LIGHTLLM_CURRENT_RANK_IN_DP" ) )
92106
93107
94108def set_current_device_id (device_id : int ):
@@ -97,3 +111,19 @@ def set_current_device_id(device_id: int):
97111
98112def get_current_device_id ():
99113 return int (get_environ ("LIGHTLLM_CURRENT_DEVICE_ID" ))
114+
115+
116+ def set_current_rank_in_node (rank :int ):
117+ set_environ ("LIGHTLLM_CURRENT_RANK_IN_NODE" , rank )
118+
119+
120+ def get_current_rank_in_node ():
121+ return int (get_environ ("LIGHTLLM_CURRENT_RANK_IN_NODE" ))
122+
123+
124+ def set_node_world_size (node_world_size : int ):
125+ set_environ ("LIGHTLLM_NODE_WORLD_SIZE" , node_world_size )
126+
127+
128+ def get_node_world_size ():
129+ return int (get_environ ("LIGHTLLM_NODE_WORLD_SIZE" ))
0 commit comments