
Commit 7c03dde

add auto redundancy base function.
1 parent 13a6bb7 commit 7c03dde

File tree

2 files changed: +168 −0 lines changed

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py

Lines changed: 1 addition & 0 deletions

@@ -43,6 +43,7 @@ def __init__(
     ) -> None:
         super().__init__()
 
+        self.layer_num = layer_num
         self.quant_method = quant_cfg.get_quant_method(layer_num, "fused_moe")
         self.quantized_weight = quant_cfg.quantized_weight
         if self.quant_method is not None:
Lines changed: 167 additions & 0 deletions

@@ -0,0 +1,167 @@
import numpy as np
import torch
from .fused_moe_weight_ep import FusedMoeWeightEP
from lightllm.utils.log_utils import init_logger
from typing import Dict

logger = init_logger(__name__)


class FusedMoeWeightEPAutoRedundancy:
    def __init__(
        self,
        ep_fused_moe_weight: FusedMoeWeightEP,
    ) -> None:
        super().__init__()
        self._ep_w = ep_fused_moe_weight
        self.redundancy_expert_num = self._ep_w.redundancy_expert_num

    def prepare_redundancy_experts(
        self,
    ):
        expert_counter = self._ep_w.routed_expert_counter_tensor.detach().cpu().numpy()
        logger.info(
            f"layer_index {self._ep_w.layer_num} global_rank {self._ep_w.global_rank_} expert_counter: {expert_counter}"
        )
        self._ep_w.routed_expert_counter_tensor.fill_(0)

        start_expert_id = self._ep_w.ep_n_routed_experts * self._ep_w.global_rank_
        no_redundancy_expert_ids = list(range(start_expert_id, start_expert_id + self._ep_w.ep_n_routed_experts))
        # Do not select the non-redundant experts already resident on this rank as redundant experts.
        expert_counter[no_redundancy_expert_ids] = 0

        self.redundancy_expert_ids = list(np.argsort(expert_counter)[-self.redundancy_expert_num :])
        logger.info(
            f"layer_index {self._ep_w.layer_num} global_rank {self._ep_w.global_rank_}"
            f" new select redundancy_expert_ids : {self.redundancy_expert_ids}"
        )

        # Prepare temporary staging buffers used while loading.
        self.experts_up_projs = [None] * self.redundancy_expert_num
        self.experts_gate_projs = [None] * self.redundancy_expert_num
        self.experts_up_proj_scales = [None] * self.redundancy_expert_num
        self.experts_gate_proj_scales = [None] * self.redundancy_expert_num
        self.w2_list = [None] * self.redundancy_expert_num
        self.w2_scale_list = [None] * self.redundancy_expert_num
        self.w1 = [None, None]  # weight, weight_scale
        self.w2 = [None, None]  # weight, weight_scale
        return

    def load_hf_weights(self, weights):
        # Load the weight parameters of the selected redundant experts.
        for i, redundant_expert_id in enumerate(self.redundancy_expert_ids):
            i_experts = redundant_expert_id
            w1_weight = f"{self._ep_w.weight_prefix}.{i_experts}.{self._ep_w.w1_weight_name}.weight"
            w2_weight = f"{self._ep_w.weight_prefix}.{i_experts}.{self._ep_w.w2_weight_name}.weight"
            w3_weight = f"{self._ep_w.weight_prefix}.{i_experts}.{self._ep_w.w3_weight_name}.weight"
            if w1_weight in weights:
                self.experts_gate_projs[i] = weights[w1_weight]
            if w3_weight in weights:
                self.experts_up_projs[i] = weights[w3_weight]
            if w2_weight in weights:
                self.w2_list[i] = weights[w2_weight]

        if self._ep_w.quantized_weight:
            self._load_weight_scale(weights)
        self._fuse()

    def _fuse(self):
        if self._ep_w.quantized_weight:
            self._fuse_weight_scale()
        with self._ep_w.lock:
            if (
                hasattr(self, "experts_up_projs")
                and None not in self.experts_up_projs
                and None not in self.experts_gate_projs
                and None not in self.w2_list
            ):
                w1_list = []
                for i_experts in range(self.redundancy_expert_num):
                    expert_gate_up_proj = torch.cat(
                        [self.experts_gate_projs[i_experts], self.experts_up_projs[i_experts]], dim=0
                    )
                    w1_list.append(expert_gate_up_proj)

                inter_shape, hidden_size = w1_list[0].shape[0], w1_list[0].shape[1]
                w1 = torch._utils._flatten_dense_tensors(w1_list).view(len(w1_list), inter_shape, hidden_size)
                inter_shape, hidden_size = self.w2_list[0].shape[0], self.w2_list[0].shape[1]
                w2 = torch._utils._flatten_dense_tensors(self.w2_list).view(len(self.w2_list), inter_shape, hidden_size)
                if not self._ep_w.quantized_weight and self._ep_w.quant_method is not None:
                    self.w1 = self._ep_w.quant_method.quantize(w1)
                    self.w2 = self._ep_w.quant_method.quantize(w2)
                else:
                    self.w1[0] = w1
                    self.w2[0] = w2

                delattr(self, "w2_list")
                delattr(self, "experts_up_projs")
                delattr(self, "experts_gate_projs")

    def _fuse_weight_scale(self):
        with self._ep_w.lock:
            if (
                hasattr(self, "experts_up_proj_scales")
                and None not in self.experts_up_proj_scales
                and None not in self.experts_gate_proj_scales
                and None not in self.w2_scale_list
            ):
                w1_scale_list = []
                for i_experts in range(self.redundancy_expert_num):
                    expert_gate_up_proj_scale = torch.cat(
                        [self.experts_gate_proj_scales[i_experts], self.experts_up_proj_scales[i_experts]], dim=0
                    )
                    w1_scale_list.append(expert_gate_up_proj_scale)

                inter_shape, hidden_size = w1_scale_list[0].shape[0], w1_scale_list[0].shape[1]
                w1_scale = torch._utils._flatten_dense_tensors(w1_scale_list).view(
                    len(w1_scale_list), inter_shape, hidden_size
                )
                inter_shape, hidden_size = self.w2_scale_list[0].shape[0], self.w2_scale_list[0].shape[1]
                w2_scale = torch._utils._flatten_dense_tensors(self.w2_scale_list).view(
                    len(self.w2_scale_list), inter_shape, hidden_size
                )
                self.w1[1] = w1_scale
                self.w2[1] = w2_scale
                delattr(self, "w2_scale_list")
                delattr(self, "experts_up_proj_scales")
                delattr(self, "experts_gate_proj_scales")

    def _load_weight_scale(self, weights: Dict[str, torch.Tensor]) -> None:
        # Load the scale parameters of the selected redundant experts.
        for i, redundant_expert_id in enumerate(self.redundancy_expert_ids):
            i_experts = redundant_expert_id
            w1_scale = (
                f"{self._ep_w.weight_prefix}.{i_experts}.{self._ep_w.w1_weight_name}.{self._ep_w.weight_scale_suffix}"
            )
            w2_scale = (
                f"{self._ep_w.weight_prefix}.{i_experts}.{self._ep_w.w2_weight_name}.{self._ep_w.weight_scale_suffix}"
            )
            w3_scale = (
                f"{self._ep_w.weight_prefix}.{i_experts}.{self._ep_w.w3_weight_name}.{self._ep_w.weight_scale_suffix}"
            )
            if w1_scale in weights:
                self.experts_gate_proj_scales[i] = weights[w1_scale]
            if w3_scale in weights:
                self.experts_up_proj_scales[i] = weights[w3_scale]
            if w2_scale in weights:
                self.w2_scale_list[i] = weights[w2_scale]

    def commit(self):
        # Copy the fused redundant-expert weights into the trailing expert slots
        # of the live EP weight tensors.
        for index, dest_tensor in enumerate(self._ep_w.w1):
            if dest_tensor is not None:
                assert isinstance(
                    dest_tensor, torch.Tensor
                ), f"dest_tensor should be a torch.Tensor, but got {type(dest_tensor)}"
                dest_tensor[-self.redundancy_expert_num :, :, :] = self.w1[index][:, :, :]

        for index, dest_tensor in enumerate(self._ep_w.w2):
            if dest_tensor is not None:
                assert isinstance(
                    dest_tensor, torch.Tensor
                ), f"dest_tensor should be a torch.Tensor, but got {type(dest_tensor)}"
                dest_tensor[-self.redundancy_expert_num :, :, :] = self.w2[index][:, :, :]

        self._ep_w.redundancy_expert_ids_tensor.copy_(
            torch.tensor(self.redundancy_expert_ids, dtype=torch.int64, device="cpu")
        )
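The selection rule in `prepare_redundancy_experts` is a plain top-k over the routing counters, after zeroing out the experts this rank already owns. A minimal sketch of that rule with a made-up counter array (the values below are illustrative, not from the commit):

```python
import numpy as np

# Hypothetical routing counts for 8 global experts; this rank owns experts 0-1.
expert_counter = np.array([50, 42, 9, 77, 3, 61, 12, 30])
expert_counter[[0, 1]] = 0  # never duplicate experts already resident on this rank
redundancy_expert_num = 2
ids = list(np.argsort(expert_counter)[-redundancy_expert_num:])
print(ids)  # [5, 3] -- the two hottest remaining experts become the redundant copies
```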

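Reading the methods together, the intended lifecycle appears to be: accumulate routing counts during serving, pick new redundant experts, stream their checkpoint weights in, then commit under the shared lock. A hypothetical driver, where `ep_weight` and `state_dict` are assumptions not defined in this commit:

```python
# Hypothetical usage sketch of FusedMoeWeightEPAutoRedundancy.
redundancy = FusedMoeWeightEPAutoRedundancy(ep_fused_moe_weight=ep_weight)
redundancy.prepare_redundancy_experts()  # pick the hottest non-local experts
redundancy.load_hf_weights(state_dict)   # stage and fuse their weights (and scales)
redundancy.commit()                      # overwrite the trailing redundant expert slots
```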