1616from lightllm .server .embed_cache .utils import tensor2bytes , read_shm , create_shm , get_shm_name_data , get_shm_name_embed
1717from lightllm .utils .infer_utils import set_random_seed
1818from lightllm .utils .infer_utils import calculate_time , mark_start , mark_end
19- from lightllm .utils .dist_utils import set_current_device_id
19+ from lightllm .utils .dist_utils import _init_vision_distributed_env
2020from lightllm .utils .graceful_utils import graceful_registry
2121
2222
@@ -31,20 +31,11 @@ def exposed_init_model(self, kvargs):
3131 self .tp_rank_id = kvargs ["tp_rank_id" ]
3232 self .cache_port = kvargs ["cache_port" ]
3333 weight_dir = kvargs ["weight_dir" ]
34- visual_gpu_ids = kvargs ["visual_gpu_ids" ]
35- visual_nccl_port = kvargs ["visual_nccl_port" ]
3634 self .vit_rank_id = kvargs ["vit_rank_id" ]
3735 self .cache_client = rpyc .connect ("localhost" , self .cache_port )
3836 self .data_type = kvargs ["data_type" ]
3937
40- torch .cuda .set_device (visual_gpu_ids [self .vit_rank_id ])
41- set_current_device_id (visual_gpu_ids [self .vit_rank_id ])
42- dist .init_process_group (
43- backend = "nccl" ,
44- init_method = f"tcp://127.0.0.1:{ visual_nccl_port } " ,
45- rank = self .tp_rank_id ,
46- world_size = self .vit_tp ,
47- )
38+ _init_vision_distributed_env (kvargs )
4839 model_cfg , _ = PretrainedConfig .get_config_dict (weight_dir )
4940
5041 try :
0 commit comments