Skip to content

Commit 958b83d

Browse files
committed
fix get_dp_size
1 parent b55edd7 commit 958b83d

File tree

1 file changed

+19
-5
lines changed

1 file changed

+19
-5
lines changed

lightllm/utils/dist_utils.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def _init_distributed_env(kvargs):
2121
set_dp_world_size(get_global_world_size() // get_dp_size())
2222
set_current_dp_rank(get_global_rank() // get_dp_world_size())
2323
set_current_dp_inner_rank(get_global_rank() % get_dp_world_size())
24-
24+
2525
size_per_node = (kvargs["world_size"] + kvargs["args"].nnodes - 1) // kvargs["args"].nnodes
2626
local_tp_rank = kvargs["rank_id"] - size_per_node * kvargs["args"].node_rank
2727
set_current_device_id(local_tp_rank)
@@ -39,47 +39,61 @@ def _init_distributed_env(kvargs):
3939
dist.all_reduce(_a)
4040
del _a
4141

42+
4243
def set_global_rank(global_rank: int):
4344
set_environ("GLOBAL_RANK", global_rank)
4445

46+
4547
def get_global_rank():
4648
return int(get_environ("GLOBAL_RANK"))
4749

50+
4851
def set_global_world_size(world_size: int):
4952
set_environ("GLOBAL_WORLD_SIZE", world_size)
5053

54+
5155
def get_global_world_size():
5256
return int(get_environ("GLOBAL_WORLD_SIZE"))
5357

54-
def set_dp_size(dp_size:int):
58+
59+
def set_dp_size(dp_size: int):
5560
"""
5661
total dp num
5762
"""
5863
set_environ("LIGHTLLM_DP_SIZE", dp_size)
59-
64+
65+
6066
def get_dp_size():
61-
return get_environ("LIGHTLLM_DP_SIZE")
67+
return int(get_environ("LIGHTLLM_DP_SIZE"))
68+
6269

6370
def set_dp_world_size(world_size: int):
6471
set_environ("DP_WORLD_SIZE", world_size)
6572

73+
6674
def get_dp_world_size():
6775
return int(get_environ("DP_WORLD_SIZE"))
6876

77+
6978
def set_current_dp_rank(rank: int):
7079
set_environ("CURRENT_DP_RANK", rank)
7180

81+
7282
def get_current_dp_rank():
7383
return int(get_environ("CURRENT_DP_RANK"))
7484

85+
7586
def set_current_dp_inner_rank(rank: int):
7687
set_environ("CURRENT_DP_INNER_RANK", rank)
77-
88+
89+
7890
def get_current_dp_inner_rank():
7991
return get_environ("CURRENT_DP_INNER_RANK")
8092

93+
8194
def set_current_device_id(device_id: int):
8295
set_environ("CURRENT_DEVICE_ID", device_id)
8396

97+
8498
def get_current_device_id():
8599
return int(get_environ("CURRENT_DEVICE_ID"))

0 commit comments

Comments
 (0)