@@ -21,7 +21,7 @@ def _init_distributed_env(kvargs):
2121 set_dp_world_size (get_global_world_size () // get_dp_size ())
2222 set_current_dp_rank (get_global_rank () // get_dp_world_size ())
2323 set_current_dp_inner_rank (get_global_rank () % get_dp_world_size ())
24-
24+
2525 size_per_node = (kvargs ["world_size" ] + kvargs ["args" ].nnodes - 1 ) // kvargs ["args" ].nnodes
2626 local_tp_rank = kvargs ["rank_id" ] - size_per_node * kvargs ["args" ].node_rank
2727 set_current_device_id (local_tp_rank )
@@ -39,47 +39,61 @@ def _init_distributed_env(kvargs):
3939 dist .all_reduce (_a )
4040 del _a
4141
42+
4243def set_global_rank (global_rank : int ):
4344 set_environ ("GLOBAL_RANK" , global_rank )
4445
46+
4547def get_global_rank ():
4648 return int (get_environ ("GLOBAL_RANK" ))
4749
50+
4851def set_global_world_size (world_size : int ):
4952 set_environ ("GLOBAL_WORLD_SIZE" , world_size )
5053
54+
5155def get_global_world_size ():
5256 return int (get_environ ("GLOBAL_WORLD_SIZE" ))
5357
54- def set_dp_size (dp_size :int ):
58+
59+ def set_dp_size (dp_size : int ):
5560 """
5661 total dp num
5762 """
5863 set_environ ("LIGHTLLM_DP_SIZE" , dp_size )
59-
64+
65+
6066def get_dp_size ():
61- return get_environ ("LIGHTLLM_DP_SIZE" )
67+ return int (get_environ ("LIGHTLLM_DP_SIZE" ))
68+
6269
6370def set_dp_world_size (world_size : int ):
6471 set_environ ("DP_WORLD_SIZE" , world_size )
6572
73+
6674def get_dp_world_size ():
6775 return int (get_environ ("DP_WORLD_SIZE" ))
6876
77+
6978def set_current_dp_rank (rank : int ):
7079 set_environ ("CURRENT_DP_RANK" , rank )
7180
81+
7282def get_current_dp_rank ():
7383 return int (get_environ ("CURRENT_DP_RANK" ))
7484
85+
7586def set_current_dp_inner_rank (rank : int ):
7687 set_environ ("CURRENT_DP_INNER_RANK" , rank )
77-
88+
89+
7890def get_current_dp_inner_rank ():
7991 return get_environ ("CURRENT_DP_INNER_RANK" )
8092
93+
8194def set_current_device_id (device_id : int ):
8295 set_environ ("CURRENT_DEVICE_ID" , device_id )
8396
97+
8498def get_current_device_id ():
8599 return int (get_environ ("CURRENT_DEVICE_ID" ))
0 commit comments