@@ -22,11 +22,11 @@ def _init_distributed_env(kvargs):
2222 set_current_dp_rank (get_global_rank () // get_dp_world_size ())
2323 set_current_dp_inner_rank (get_global_rank () % get_dp_world_size ())
2424
25- size_per_node = ( kvargs ["world_size" ] + kvargs ["args" ].nnodes - 1 ) // kvargs [ "args" ]. nnodes
26- local_tp_rank = kvargs ["rank_id " ] - size_per_node * kvargs ["args" ].node_rank
27- set_current_device_id ( local_tp_rank )
28- torch . cuda . set_device ( local_tp_rank )
29- print ( local_tp_rank )
25+ assert kvargs ["world_size" ] % kvargs ["args" ].nnodes == 0 , "world_size should be divided by nnodes"
26+ size_per_node = kvargs ["world_size " ] // kvargs ["args" ].nnodes
27+ device_id = kvargs [ "rank_id" ] % size_per_node
28+ set_current_device_id ( device_id )
29+ torch . cuda . set_device ( device_id )
3030 if kvargs ["world_size" ] > 1 :
3131 dist .init_process_group (
3232 "nccl" ,
@@ -35,25 +35,25 @@ def _init_distributed_env(kvargs):
3535 world_size = kvargs ["world_size" ],
3636 )
3737 # warmup nccl communicator
38- _a = torch .zeros ([1 ]).to (f"cuda:{ local_tp_rank } " )
38+ _a = torch .zeros ([1 ]).to (f"cuda:{ device_id } " )
3939 dist .all_reduce (_a )
4040 del _a
4141
4242
def set_global_rank(global_rank: int):
    """Store ``global_rank`` through ``set_environ`` under the key literal below."""
    env_key = "LIGHTLLM_GLOBAL_RANK "
    set_environ(env_key, global_rank)
4545
4646
def get_global_rank():
    """Return the value ``get_environ`` yields for the key literal below, converted with ``int``."""
    raw_value = get_environ("LIGHTLLM_GLOBAL_RANK ")
    return int(raw_value)
4949
5050
def set_global_world_size(world_size: int):
    """Store ``world_size`` through ``set_environ`` under the key literal below."""
    env_key = "LIGHTLLM_GLOBAL_WORLD_SIZE "
    set_environ(env_key, world_size)
5353
5454
def get_global_world_size():
    """Return the value ``get_environ`` yields for the key literal below, converted with ``int``."""
    raw_value = get_environ("LIGHTLLM_GLOBAL_WORLD_SIZE ")
    return int(raw_value)
5757
5858
5959def set_dp_size (dp_size : int ):
@@ -68,32 +68,32 @@ def get_dp_size():
6868
6969
def set_dp_world_size(world_size: int):
    """Store ``world_size`` through ``set_environ`` under the key literal below."""
    env_key = "LIGHTLLM_DP_WORLD_SIZE "
    set_environ(env_key, world_size)
7272
7373
def get_dp_world_size():
    """Return the value ``get_environ`` yields for the key literal below, converted with ``int``."""
    raw_value = get_environ("LIGHTLLM_DP_WORLD_SIZE ")
    return int(raw_value)
7676
7777
def set_current_dp_rank(rank: int):
    """Store ``rank`` through ``set_environ`` under the key literal below."""
    env_key = "LIGHTLLM_CURRENT_DP_RANK "
    set_environ(env_key, rank)
8080
8181
def get_current_dp_rank():
    """Return the value ``get_environ`` yields for the key literal below, converted with ``int``."""
    raw_value = get_environ("LIGHTLLM_CURRENT_DP_RANK ")
    return int(raw_value)
8484
8585
def set_current_dp_inner_rank(rank: int):
    """Store ``rank`` through ``set_environ`` under the key literal below."""
    env_key = "LIGHTLLM_CURRENT_DP_INNER_RANK "
    set_environ(env_key, rank)
8888
8989
def get_current_dp_inner_rank():
    """Return the stored DP inner rank as an ``int``.

    Reads the value with ``get_environ`` using the key literal below and
    converts it with ``int``, the same way the other rank/size getters in
    this module do. The matching setter is annotated to take an ``int``, so
    callers get back the same type they stored.

    Raises:
        ValueError / TypeError: propagated from ``int`` if the stored value
            cannot be converted.
    """
    # NOTE(review): previously the raw ``get_environ`` result was returned
    # without conversion — presumably a string; confirm no caller relied on
    # the unconverted value.
    return int(get_environ("LIGHTLLM_CURRENT_DP_INNER_RANK "))
9292
9393
def set_current_device_id(device_id: int):
    """Store ``device_id`` through ``set_environ`` under the key literal below."""
    env_key = "LIGHTLLM_CURRENT_DEVICE_ID "
    set_environ(env_key, device_id)
9696
9797
def get_current_device_id():
    """Return the value ``get_environ`` yields for the key literal below, converted with ``int``."""
    raw_value = get_environ("LIGHTLLM_CURRENT_DEVICE_ID ")
    return int(raw_value)
0 commit comments