Skip to content

Commit 72f4eb3

Browse files
authored
fix
1 parent 46dbbb2 commit 72f4eb3

File tree

1 file changed

+28
-2
lines changed

1 file changed

+28
-2
lines changed

lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@
4242
from lightllm.utils.envs_utils import get_unique_server_name
4343
from lightllm.server.core.objs import ShmReqManager
4444
from lightllm.server.router.model_infer.infer_batch import g_infer_context
45+
from lightllm.utils.dist_utils import get_global_rank, get_global_world_size, get_dp_size
46+
from lightllm.utils.dist_utils import get_dp_world_size, get_current_dp_rank, get_current_rank_in_dp
47+
from lightllm.utils.dist_utils import get_current_device_id, get_current_rank_in_node, get_node_world_size
4548
import torch.distributed as dist
4649

4750

@@ -82,9 +85,8 @@ def init_model(self, kvargs):
8285
assert self.dp_size == self.world_size, "Currently only self-sustaining dp_size == tp_size"
8386
os.environ["ENABLE_DP"] = "1"
8487

85-
size_per_node = (self.world_size + self.nnodes - 1) // self.nnodes
86-
self.local_tp_rank = self.tp_rank - size_per_node * self.node_rank
8788
_init_distributed_env(kvargs)
89+
self.init_rank_infos()
8890

8991
self.shared_token_load = TokenLoad(f"{get_unique_server_name()}_shared_token_load", self.dp_size)
9092

@@ -273,3 +275,27 @@ def preload_prompt_cache_kv_buffer(self, model_cfg):
273275
self.radix_cache.match_prefix(
274276
torch.tensor(model_cfg["prompt_cache_token_ids"], dtype=torch.int64, device="cpu"), update_refs=True
275277
)
278+
279+
def init_rank_infos(self):
    """Cache rank/topology information on the instance and set ``is_master_in_dp``.

    Each attribute is filled from the matching ``lightllm.utils.dist_utils``
    getter. ``self.nnodes`` must already be set before this is called, and
    ``self.dp_size`` is overwritten with the value returned by ``get_dp_size()``.

    ``self.is_master_in_dp`` is set as follows:
      * when ``nnodes > 1`` and ``dp_size == 1``: True iff ``rank_in_node == 0``
      * otherwise: True iff ``rank_in_dp == 0``
    """
    # Per-node view.
    self.node_world_size = get_node_world_size()
    self.rank_in_node = get_current_rank_in_node()
    self.current_device_id = get_current_device_id()

    # Per-dp-group view.
    self.rank_in_dp = get_current_rank_in_dp()
    self.dp_rank = get_current_dp_rank()
    self.dp_world_size = get_dp_world_size()

    # Global view.
    self.global_rank = get_global_rank()
    self.global_world_size = get_global_world_size()
    self.dp_size = get_dp_size()

    # With several nodes and a single dp group, mastership is decided by the
    # rank inside the node; in every other layout, by the rank inside the dp group.
    # NOTE(review): presumably this yields one master per node in the
    # multi-node single-dp case; confirm against the callers.
    decide_by_node_rank = self.nnodes > 1 and self.dp_size == 1
    deciding_rank = self.rank_in_node if decide_by_node_rank else self.rank_in_dp
    self.is_master_in_dp = deciding_rank == 0
301+

0 commit comments

Comments
 (0)