Skip to content

Commit 8ff7ed6

Browse files
committed
fix visualserver
1 parent 877b98f commit 8ff7ed6

File tree

2 files changed

+22
-12
lines changed

2 files changed

+22
-12
lines changed

lightllm/server/visualserver/model_infer/model_rpc.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from lightllm.server.embed_cache.utils import tensor2bytes, read_shm, create_shm, get_shm_name_data, get_shm_name_embed
1717
from lightllm.utils.infer_utils import set_random_seed
1818
from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end
19-
from lightllm.utils.dist_utils import set_current_device_id
19+
from lightllm.utils.dist_utils import _init_vision_distributed_env
2020
from lightllm.utils.graceful_utils import graceful_registry
2121

2222

@@ -31,20 +31,11 @@ def exposed_init_model(self, kvargs):
3131
self.tp_rank_id = kvargs["tp_rank_id"]
3232
self.cache_port = kvargs["cache_port"]
3333
weight_dir = kvargs["weight_dir"]
34-
visual_gpu_ids = kvargs["visual_gpu_ids"]
35-
visual_nccl_port = kvargs["visual_nccl_port"]
3634
self.vit_rank_id = kvargs["vit_rank_id"]
3735
self.cache_client = rpyc.connect("localhost", self.cache_port)
3836
self.data_type = kvargs["data_type"]
3937

40-
torch.cuda.set_device(visual_gpu_ids[self.vit_rank_id])
41-
set_current_device_id(visual_gpu_ids[self.vit_rank_id])
42-
dist.init_process_group(
43-
backend="nccl",
44-
init_method=f"tcp://127.0.0.1:{visual_nccl_port}",
45-
rank=self.tp_rank_id,
46-
world_size=self.vit_tp,
47-
)
38+
_init_vision_distributed_env(kvargs)
4839
model_cfg, _ = PretrainedConfig.get_config_dict(weight_dir)
4940

5041
try:

lightllm/utils/dist_utils.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,26 @@ def get_environ(environ_name):
2525
return value
2626

2727

28+
def _init_vision_distributed_env(kvargs):
    """Initialize the NCCL distributed environment for a vision (ViT) worker.

    Expected keys in *kvargs* (all read here):
        vit_tp          -- tensor-parallel world size for the vision model.
        tp_rank_id      -- this worker's rank within the vision TP group.
        vit_rank_id     -- index into ``visual_gpu_ids`` selecting this
                           worker's physical GPU.
        visual_gpu_ids  -- list of CUDA device ids assigned to vision workers.
        visual_nccl_port-- TCP port used for the NCCL rendezvous on localhost.

    Side effects: sets process-wide rank/world-size/device globals, pins the
    current CUDA device, creates the default ``torch.distributed`` process
    group, and issues one warm-up all-reduce. Must be called exactly once per
    worker process, before any other collective op.

    # NOTE(review): the process-group rank comes from ``tp_rank_id`` while the
    # GPU is chosen via ``vit_rank_id`` — presumably these agree per worker;
    # confirm against the caller (exposed_init_model) if they ever diverge.
    """
    world_size = kvargs["vit_tp"]
    # Record rank / world size in module-level globals so later helpers
    # (e.g. get_current_device_id-style accessors) see consistent values.
    set_global_rank(kvargs["tp_rank_id"])
    set_global_world_size(world_size)
    visual_gpu_ids = kvargs["visual_gpu_ids"]
    device_id = visual_gpu_ids[kvargs["vit_rank_id"]]
    set_current_device_id(device_id)
    # Pin the device BEFORE init_process_group so NCCL binds its communicator
    # to the correct GPU for this rank.
    torch.cuda.set_device(device_id)
    dist.init_process_group(
        "nccl",
        init_method=f'tcp://127.0.0.1:{kvargs["visual_nccl_port"]}',
        rank=kvargs["tp_rank_id"],
        world_size=world_size,
    )
    # warmup nccl communicator
    # A throwaway all-reduce forces eager creation of the NCCL communicator
    # so the first real collective doesn't pay the setup latency.
    _a = torch.zeros([1]).to(f"cuda:{device_id}")
    dist.all_reduce(_a)
    del _a
46+
47+
2848
def _init_distributed_env(kvargs):
2949
assert kvargs["world_size"] % kvargs["args"].nnodes == 0, "world_size should be divided by nnodes"
3050
node_world_size = kvargs["world_size"] // kvargs["args"].nnodes
@@ -47,7 +67,6 @@ def _init_distributed_env(kvargs):
4767
rank=kvargs["rank_id"],
4868
world_size=kvargs["world_size"],
4969
)
50-
# if kvargs["world_size"] > 1:
5170
# warmup nccl communicator
5271
_a = torch.zeros([1]).to(f"cuda:{device_id}")
5372
dist.all_reduce(_a)

0 commit comments

Comments
 (0)