Skip to content

Commit 8ad585a

Browse files
author
Feiyang Wu
committed
modify dist_utils & remove child_ips
1 parent 3fd6e48 commit 8ad585a

File tree

21 files changed

+154
-87
lines changed

21 files changed

+154
-87
lines changed

lightllm/common/basemodel/layer_weights/base_layer_weight.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import numpy as np
33
import threading
44
from lightllm.common.basemodel.layer_weights.meta_weights import BaseWeight
5-
from lightllm.utils.device_utils import get_current_device_id
5+
from lightllm.utils.dist_utils import get_current_device_id
66

77

88
class BaseLayerWeight:

lightllm/common/basemodel/layer_weights/hf_load_utils.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,14 @@
33
import gc
44
from safetensors import safe_open
55
import lightllm.utils.petrel_helper as utils
6+
from lightllm.utils.dist_utils import get_current_device_id
67

78

8-
def load_func(file_, local_tp_rank, use_safetensors=False, pre_post_layer=None, transformer_layer_list=None, weight_dir=None):
9+
def load_func(file_, use_safetensors=False, pre_post_layer=None, transformer_layer_list=None, weight_dir=None):
910
# fix bug for 多线程加载的时候,每个线程内部的cuda device 会切回 0, 修改后来保证不会出现bug
1011
import torch.distributed as dist
1112

12-
# tp_rank = dist.get_rank()
13-
torch.cuda.set_device(local_tp_rank)
13+
torch.cuda.set_device(get_current_device_id())
1414

1515
if use_safetensors:
1616
weights = safe_open(os.path.join(weight_dir, file_), "pt", "cpu")
@@ -27,7 +27,7 @@ def load_func(file_, local_tp_rank, use_safetensors=False, pre_post_layer=None,
2727
gc.collect()
2828

2929

30-
def load_hf_weights(data_type, weight_dir, local_tp_rank, pre_post_layer=None, transformer_layer_list=None, weight_dict=None):
30+
def load_hf_weights(data_type, weight_dir, pre_post_layer=None, transformer_layer_list=None, weight_dict=None):
3131
if isinstance(data_type, str):
3232
data_type = torch.float16 if data_type == "fp16" else torch.float32
3333
if pre_post_layer is not None:
@@ -36,10 +36,10 @@ def load_hf_weights(data_type, weight_dir, local_tp_rank, pre_post_layer=None, t
3636
assert transformer_layer_list[0].data_type_ == data_type, "type is not right"
3737
if weight_dict:
3838
if pre_post_layer is not None:
39-
pre_post_layer.load_hf_weights(weight_dict, local_tp_rank)
39+
pre_post_layer.load_hf_weights(weight_dict)
4040
if transformer_layer_list is not None:
4141
for layer in transformer_layer_list:
42-
layer.load_hf_weights(weight_dict, local_tp_rank)
42+
layer.load_hf_weights(weight_dict)
4343
del weight_dict
4444
return
4545
use_safetensors = True
@@ -54,7 +54,6 @@ def load_hf_weights(data_type, weight_dir, local_tp_rank, pre_post_layer=None, t
5454

5555
partial_func = partial(
5656
load_func,
57-
local_tp_rank=local_tp_rank,
5857
use_safetensors=use_safetensors,
5958
pre_post_layer=pre_post_layer,
6059
transformer_layer_list=transformer_layer_list,

lightllm/common/basemodel/layer_weights/meta_weights/base_weight.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
import torch
22
from abc import ABC, abstractmethod
3-
from lightllm.utils.dist_utils import get_world_size, get_rank
4-
from lightllm.utils.device_utils import get_current_device_id
3+
from lightllm.utils.dist_utils import get_global_world_size, get_global_rank, get_current_device_id
54

65

76
class BaseWeight(ABC):
87
def __init__(self):
98
pass
109

1110
@abstractmethod
12-
def load_hf_weights(self, weights, local_tp_rank):
11+
def load_hf_weights(self, weights):
1312
pass
1413

1514
@abstractmethod
@@ -19,11 +18,11 @@ def verify_load(self):
1918

2019
class BaseWeightTpl(BaseWeight):
2120
def __init__(self):
22-
self.world_size_ = get_world_size()
23-
self.tp_rank_ = get_rank()
21+
self.world_size_ = get_global_world_size()
22+
self.tp_rank_ = get_global_rank()
2423
self.device_id_ = get_current_device_id()
2524

26-
def load_hf_weights(self, weights, local_tp_rank):
25+
def load_hf_weights(self, weights):
2726
pass
2827

2928
def verify_load(self):

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,8 @@
55
from .base_weight import BaseWeight
66
from lightllm.common.quantization import vLLMFP8w8a8QuantizationMethod
77
from lightllm.common.quantization.quantize_method import QuantizationMethod
8-
from lightllm.utils.dist_utils import get_world_size, get_rank
8+
from lightllm.utils.dist_utils import get_global_world_size, get_global_rank, get_current_device_id
99
from lightllm.common.vllm_kernel import _custom_ops as ops
10-
from lightllm.utils.device_utils import get_current_device_id
1110

1211

1312
class FusedMoeWeight(BaseWeight):
@@ -39,7 +38,7 @@ def __init__(
3938
self.n_routed_experts = n_routed_experts
4039
self.split_inter_size = split_inter_size
4140
self.data_type_ = data_type
42-
self.tp_rank_ = get_rank()
41+
self.tp_rank_ = get_global_rank()
4342
self.experts_up_projs = [None] * self.n_routed_experts
4443
self.experts_gate_projs = [None] * self.n_routed_experts
4544
self.experts_up_proj_scales = [None] * self.n_routed_experts
@@ -159,7 +158,7 @@ def _fuse_weight_scale(self):
159158
delattr(self, "experts_gate_proj_scales")
160159

161160
def _load_hf_weights_etp(self, weights):
162-
world_size_ = get_world_size()
161+
world_size_ = get_global_world_size()
163162
assert self.n_routed_experts % world_size_ == 0
164163
n_expert_ep = self.n_routed_experts // world_size_
165164

lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from typing import Optional, Tuple, List, Dict, Any
55
from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
66
from lightllm.common.quantization.quantize_method import QuantizationMethod
7-
from lightllm.utils.device_utils import get_current_device_id
7+
from lightllm.utils.dist_utils import get_current_device_id
88

99

1010
def generate_scale_name(name, weight_scale_suffix, act_scale_suffix):

lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import torch
22
from .base_weight import BaseWeightTpl
3-
from lightllm.utils.device_utils import get_current_device_id
3+
from lightllm.utils.dist_utils import get_current_device_id
44

55

66
class NormWeight(BaseWeightTpl):

lightllm/common/mem_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt
99
from lightllm.utils.profile_max_tokens import get_available_gpu_memory, get_total_gpu_memory
1010
from lightllm.common.kv_trans_kernel.kv_trans import kv_trans
11-
from lightllm.utils.device_utils import get_current_device_id
11+
from lightllm.utils.dist_utils import get_current_device_id
1212

1313
logger = init_logger(__name__)
1414

lightllm/common/quantization/quantize_method.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import torch
22
from abc import ABC, abstractmethod
3-
from lightllm.utils.device_utils import get_current_device_id
3+
from lightllm.utils.dist_utils import get_current_device_id
4+
45

56
class QuantizationMethod(ABC):
67
def __init__(self):

lightllm/models/deepseek2/model.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,6 @@ def _init_weights(self):
7676
]
7777
load_hf_weights(
7878
self.data_type,
79-
local_tp_rank=self.local_tp_rank,
8079
weight_dir=self.weight_dir_,
8180
pre_post_layer=self.pre_post_weight,
8281
transformer_layer_list=self.trans_layers_weight,

lightllm/models/vit/layer_weights/pre_and_post_layer_weight.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
import torch.nn.functional as F
55
from lightllm.common.basemodel import PreAndPostLayerWeight
6-
from lightllm.utils.device_utils import get_current_device_id
6+
from lightllm.utils.dist_utils import get_current_device_id
77

88

99
class ViTPreAndPostLayerWeight(PreAndPostLayerWeight):

0 commit comments

Comments
 (0)