Skip to content

Commit 82a756f

Browse files
committed
fix dist_utils
1 parent ecd495c commit 82a756f

File tree

1 file changed

+18
-18
lines changed

1 file changed

+18
-18
lines changed

lightllm/utils/dist_utils.py

Lines changed: 18 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -22,11 +22,11 @@ def _init_distributed_env(kvargs):
2222
set_current_dp_rank(get_global_rank() // get_dp_world_size())
2323
set_current_dp_inner_rank(get_global_rank() % get_dp_world_size())
2424

25-
size_per_node = (kvargs["world_size"] + kvargs["args"].nnodes - 1) // kvargs["args"].nnodes
26-
local_tp_rank = kvargs["rank_id"] - size_per_node * kvargs["args"].node_rank
27-
set_current_device_id(local_tp_rank)
28-
torch.cuda.set_device(local_tp_rank)
29-
print(local_tp_rank)
25+
assert kvargs["world_size"] % kvargs["args"].nnodes == 0, "world_size should be divided by nnodes"
26+
size_per_node = kvargs["world_size"] // kvargs["args"].nnodes
27+
device_id = kvargs["rank_id"] % size_per_node
28+
set_current_device_id(device_id)
29+
torch.cuda.set_device(device_id)
3030
if kvargs["world_size"] > 1:
3131
dist.init_process_group(
3232
"nccl",
@@ -35,25 +35,25 @@ def _init_distributed_env(kvargs):
3535
world_size=kvargs["world_size"],
3636
)
3737
# warmup nccl communicator
38-
_a = torch.zeros([1]).to(f"cuda:{local_tp_rank}")
38+
_a = torch.zeros([1]).to(f"cuda:{device_id}")
3939
dist.all_reduce(_a)
4040
del _a
4141

4242

4343
def set_global_rank(global_rank: int):
44-
set_environ("GLOBAL_RANK", global_rank)
44+
set_environ("LIGHTLLM_GLOBAL_RANK", global_rank)
4545

4646

4747
def get_global_rank():
48-
return int(get_environ("GLOBAL_RANK"))
48+
return int(get_environ("LIGHTLLM_GLOBAL_RANK"))
4949

5050

5151
def set_global_world_size(world_size: int):
52-
set_environ("GLOBAL_WORLD_SIZE", world_size)
52+
set_environ("LIGHTLLM_GLOBAL_WORLD_SIZE", world_size)
5353

5454

5555
def get_global_world_size():
56-
return int(get_environ("GLOBAL_WORLD_SIZE"))
56+
return int(get_environ("LIGHTLLM_GLOBAL_WORLD_SIZE"))
5757

5858

5959
def set_dp_size(dp_size: int):
@@ -68,32 +68,32 @@ def get_dp_size():
6868

6969

7070
def set_dp_world_size(world_size: int):
71-
set_environ("DP_WORLD_SIZE", world_size)
71+
set_environ("LIGHTLLM_DP_WORLD_SIZE", world_size)
7272

7373

7474
def get_dp_world_size():
75-
return int(get_environ("DP_WORLD_SIZE"))
75+
return int(get_environ("LIGHTLLM_DP_WORLD_SIZE"))
7676

7777

7878
def set_current_dp_rank(rank: int):
79-
set_environ("CURRENT_DP_RANK", rank)
79+
set_environ("LIGHTLLM_CURRENT_DP_RANK", rank)
8080

8181

8282
def get_current_dp_rank():
83-
return int(get_environ("CURRENT_DP_RANK"))
83+
return int(get_environ("LIGHTLLM_CURRENT_DP_RANK"))
8484

8585

8686
def set_current_dp_inner_rank(rank: int):
87-
set_environ("CURRENT_DP_INNER_RANK", rank)
87+
set_environ("LIGHTLLM_CURRENT_DP_INNER_RANK", rank)
8888

8989

9090
def get_current_dp_inner_rank():
91-
return get_environ("CURRENT_DP_INNER_RANK")
91+
return get_environ("LIGHTLLM_CURRENT_DP_INNER_RANK")
9292

9393

9494
def set_current_device_id(device_id: int):
95-
set_environ("CURRENT_DEVICE_ID", device_id)
95+
set_environ("LIGHTLLM_CURRENT_DEVICE_ID", device_id)
9696

9797

9898
def get_current_device_id():
99-
return int(get_environ("CURRENT_DEVICE_ID"))
99+
return int(get_environ("LIGHTLLM_CURRENT_DEVICE_ID"))

0 commit comments

Comments
 (0)