Commit 3ee963e

author sangchengmeng committed 1204
1 parent cd9c7ee commit 3ee963e

File tree

5 files changed: +63 -81 lines changed


lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py

Lines changed: 51 additions & 43 deletions
@@ -20,6 +20,8 @@ def create_tp_moe_wegiht_obj(
     network_config: Dict[str, Any],
     layer_num: int,
     quant_cfg: Quantcfg = None,
+    fused_gate_up: bool = False,
+    gate_up_proj_name: str = None,
 ) -> Union["FusedMoeWeightTP", "FusedAWQMARLINMoeWeightTP"]:
     quant_method = quant_cfg.get_quant_method(layer_num, "fused_moe")
     if quant_method is not None and quant_method.method_name == "awq_marlin":
@@ -36,6 +38,8 @@ def create_tp_moe_wegiht_obj(
             network_config=network_config,
             layer_num=layer_num,
             quant_cfg=quant_cfg,
+            fused_gate_up=fused_gate_up,
+            gate_up_proj_name=gate_up_proj_name,
         )
     else:
         return FusedMoeWeightTP(
@@ -51,6 +55,8 @@ def create_tp_moe_wegiht_obj(
             network_config=network_config,
             layer_num=layer_num,
             quant_cfg=quant_cfg,
+            fused_gate_up=fused_gate_up,
+            gate_up_proj_name=gate_up_proj_name,
         )


@@ -69,6 +75,8 @@ def __init__(
         network_config: Dict[str, Any],
         layer_num: int,
         quant_cfg: Quantcfg = None,
+        fused_gate_up: bool = False,
+        gate_up_proj_name: str = None,
     ) -> None:
         super().__init__()
         self.quant_method = quant_cfg.get_quant_method(layer_num, "fused_moe")
@@ -79,6 +87,8 @@ def __init__(
         self.w1_weight_name = gate_proj_name
         self.w2_weight_name = down_proj_name
         self.w3_weight_name = up_proj_name
+        self.fused_gate_up = fused_gate_up
+        self.gate_up_proj_name = gate_up_proj_name

         self.e_score_correction_bias_name = e_score_correction_bias_name
         self.weight_prefix = weight_prefix
@@ -181,8 +191,6 @@ def _fuse(self):

         inter_shape, hidden_size = self.w2_list[0].shape[0], self.w2_list[0].shape[1]
         w2 = torch._utils._flatten_dense_tensors(self.w2_list).view(len(self.w2_list), inter_shape, hidden_size)
-        if self.fused_gate_up:
-            w2 = w2.transpose(1, 2).contiguous()
         if not self.quantized_weight and self.quant_method is not None:
             qw1, qw1_scale, qw1_zero_point = self.quant_method.quantize(w1)
             qw2, qw2_scale, qw2_zero_point = self.quant_method.quantize(w2)
@@ -228,51 +236,49 @@ def _fuse_weight_scale(self):
             delattr(self, "experts_up_proj_scales")
             delattr(self, "experts_gate_proj_scales")

+    def fused_gate_up_weights_load(self, weights):
+        key_gate_up = f"{self.weight_prefix}.{self.gate_up_proj_name}"  # ...experts.gate_up_proj
+        key_down = f"{self.weight_prefix}.{self.w2_weight_name}"
+        if (key_gate_up not in weights) or (key_down not in weights):
+            return
+        gate_up = weights[key_gate_up]
+        down = weights[key_down]
+        _, _, I_double = gate_up.shape
+        I_single = I_double // 2
+        start = self.tp_rank_ * self.split_inter_size
+        end = (self.tp_rank_ + 1) * self.split_inter_size
+
+        for i_experts in range(self.n_routed_experts):
+            gate_up_2d = gate_up[i_experts]
+            self.experts_gate_projs[i_experts] = gate_up_2d[:, :I_single][:, start:end].t().contiguous()
+            self.experts_up_projs[i_experts] = gate_up_2d[:, I_single:][:, start:end].t().contiguous()
+            self.w2_list[i_experts] = down[i_experts].t()[:, start:end].contiguous()
+
+    def normal_weights_load(self, weights):
+        for i_experts in range(self.n_routed_experts):
+
+            w1_weight = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.weight"
+            w2_weight = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.weight"
+            w3_weight = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.weight"
+
+            start = self.tp_rank_ * self.split_inter_size
+            end = (self.tp_rank_ + 1) * self.split_inter_size
+
+            if w1_weight in weights:
+                self.experts_gate_projs[i_experts] = weights[w1_weight][start:end, :]
+            if w3_weight in weights:
+                self.experts_up_projs[i_experts] = weights[w3_weight][start:end, :]
+            if w2_weight in weights:
+                self.w2_list[i_experts] = weights[w2_weight][start:end]
+
     def load_hf_weights(self, weights):
         if self.e_score_correction_bias_name in weights:
             self.e_score_correction_bias = self._cuda(weights[self.e_score_correction_bias_name])
-        self.fused_gate_up = self.w3_weight_name is None  # gate_up: [E,H,2I]  down: [E,I,H]
-        key_gateup_3d = f"{self.weight_prefix}.{self.w1_weight_name}"  # ...experts.gate_up_proj
-        key_down_3d = f"{self.weight_prefix}.{self.w2_weight_name}"
-
-        if self.fused_gate_up and (key_gateup_3d in weights) and (key_down_3d in weights):
-            gate_up_3d = weights[key_gateup_3d]
-            down_3d = weights[key_down_3d]
-            assert gate_up_3d.dim() == 3 and down_3d.dim() == 3
-
-            E_ckpt, H_, twoE = gate_up_3d.shape
-            assert E_ckpt == self.n_routed_experts, f"experts mismatch: ckpt {E_ckpt} vs cfg {self.n_routed_experts}"
-            Eint_total = twoE // 2
-            start, end = self.tp_rank_ * self.split_inter_size, (self.tp_rank_ + 1) * self.split_inter_size
-            assert end <= Eint_total, "TP split exceeds total expert-intermediate size"
-
-            for i in range(self.n_routed_experts):
-                gu2d = gate_up_3d[i]
-                gate2d = gu2d[:, :Eint_total][:, start:end].t().contiguous()
-                up2d = gu2d[:, Eint_total:][:, start:end].t().contiguous()
-                self.experts_gate_projs[i] = gate2d
-                self.experts_up_projs[i] = up2d
-
-                self.w2_list[i] = down_3d[i][start:end, :].contiguous()
+        if self.fused_gate_up:  # gate_up: [E,H,2I]  down: [E,I,H]
+            self.fused_gate_up_weights_load(weights)
         else:
-            for i_experts in range(self.n_routed_experts):
-                w1_weight = f"{self.weight_prefix}.{i_experts}.{self.w1_weight_name}.weight"
-                w2_weight = f"{self.weight_prefix}.{i_experts}.{self.w2_weight_name}.weight"
-                w3_weight = f"{self.weight_prefix}.{i_experts}.{self.w3_weight_name}.weight"
-
-                if w1_weight in weights:
-                    self.experts_gate_projs[i_experts] = weights[w1_weight][
-                        self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), :
-                    ]
-                if w3_weight in weights:
-                    self.experts_up_projs[i_experts] = weights[w3_weight][
-                        self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1), :
-                    ]
-
-                if w2_weight in weights:
-                    self.w2_list[i_experts] = weights[w2_weight][
-                        :, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1)
-                    ]
+            self.normal_weights_load(weights)
+
         if self.quant_method is not None:
             if self.fused_gate_up:
                 raise ValueError("qwen3_vl_moe not support quant now")
@@ -342,6 +348,8 @@ def __init__(
         network_config: Dict[str, Any],
         layer_num: int,
         quant_cfg: Quantcfg = None,
+        fused_gate_up: bool = False,
+        gate_up_proj_name: str = None,
     ) -> None:
        super().__init__()
         self.quant_method = quant_cfg.get_quant_method(layer_num, "fused_moe")
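
For orientation, here is a minimal standalone sketch (not part of the commit) of the per-rank split that the new fused_gate_up_weights_load performs: the fused checkpoint tensor gate_up has shape [E, H, 2I], its first I columns are the gate projection, the last I are the up projection, and each tensor-parallel rank keeps only its split_inter_size slice of I. All sizes below are made up for illustration.

    import torch

    E, H, I = 4, 8, 6          # experts, hidden size, per-expert intermediate size (illustrative)
    tp_rank, tp_world = 1, 2   # this rank owns the second half of the intermediate dim
    split_inter_size = I // tp_world

    gate_up = torch.randn(E, H, 2 * I)   # fused checkpoint tensor, [E, H, 2I]
    down = torch.randn(E, I, H)          # down projection, [E, I, H]

    start = tp_rank * split_inter_size
    end = (tp_rank + 1) * split_inter_size

    experts_gate_projs, experts_up_projs, w2_list = [], [], []
    for e in range(E):
        gu = gate_up[e]                                                       # [H, 2I]
        experts_gate_projs.append(gu[:, :I][:, start:end].t().contiguous())   # [I_tp, H]
        experts_up_projs.append(gu[:, I:][:, start:end].t().contiguous())     # [I_tp, H]
        w2_list.append(down[e].t()[:, start:end].contiguous())                # [H, I_tp]

    print(experts_gate_projs[0].shape, experts_up_projs[0].shape, w2_list[0].shape)
    # torch.Size([3, 8]) torch.Size([3, 8]) torch.Size([8, 3])
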

lightllm/models/qwen3_vl/layer_weights/pre_and_post_layer_weight.py

Lines changed: 0 additions & 1 deletion
@@ -1,4 +1,3 @@
-import torch
 import numpy as np
 from lightllm.common.basemodel import PreAndPostLayerWeight

lightllm/models/qwen3_vl/layer_weights/transformers_layer_weight.py

Lines changed: 0 additions & 14 deletions
@@ -1,18 +1,4 @@
-import os
-import torch
-import math
-import numpy as np
 from lightllm.models.qwen3.layer_weights.transformer_layer_weight import Qwen3TransformerLayerWeight
-from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight
-from lightllm.common.basemodel.layer_weights.meta_weights import (
-    ROWMMWeight,
-    MultiROWMMWeight,
-    COLMMWeight,
-    NormWeight,
-    FusedMoeWeightTP,
-    FusedMoeWeightEP,
-    ROWBMMWeight,
-)


 class Qwen3VLTransformerLayerWeight(Qwen3TransformerLayerWeight):  # TODO: decide later whether this needs changing

lightllm/models/qwen3_vl_moe/layer_infer/transformer_layer_infer.py

Lines changed: 1 addition & 4 deletions
@@ -35,10 +35,7 @@ def _get_qkv(
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         input = input.view(-1, self.embed_dim_)
         q = layer_weight.q_proj.mm(input)
-        cache_kv = self._pre_cache_kv(infer_state=infer_state, layer_weight=layer_weight)
-        cache_kv = layer_weight.kv_proj.mm(
-            input, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_)
-        ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
+        cache_kv = layer_weight.kv_proj.mm(input).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
         rmsnorm_forward(
             q.view(-1, self.head_dim_),
             weight=layer_weight.q_norm_weight_.weight,
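
In the _get_qkv change above, the removed path wrote the KV projection into a buffer obtained from _pre_cache_kv via the out= argument; the new single line lets the matmul allocate its own output and reshapes it to [tokens, k_heads + v_heads, head_dim]. A minimal shape check under made-up dimensions, with kv_proj.mm stood in for by a plain matmul:

    import torch

    embed_dim, tp_k_head_num, tp_v_head_num, head_dim = 16, 2, 2, 8  # illustrative sizes
    tokens = 5

    x = torch.randn(tokens, embed_dim)
    kv_proj_weight = torch.randn(embed_dim, (tp_k_head_num + tp_v_head_num) * head_dim)

    # analogue of: layer_weight.kv_proj.mm(input).view(-1, k_heads + v_heads, head_dim)
    cache_kv = (x @ kv_proj_weight).view(-1, tp_k_head_num + tp_v_head_num, head_dim)
    print(cache_kv.shape)  # torch.Size([5, 4, 8])
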

lightllm/models/qwen3_vl_moe/layer_weights/transformers_layer_weight.py

Lines changed: 11 additions & 19 deletions
@@ -1,18 +1,6 @@
 import os
-import torch
-import math
-import numpy as np
-from lightllm.models.qwen3.layer_weights.transformer_layer_weight import Qwen3TransformerLayerWeight
 from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight
-from lightllm.common.basemodel.layer_weights.meta_weights import (
-    ROWMMWeight,
-    MultiROWMMWeight,
-    COLMMWeight,
-    NormWeight,
-    FusedMoeWeightTP,
-    FusedMoeWeightEP,
-    ROWBMMWeight,
-)
+from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, FusedMoeWeightEP, create_tp_moe_wegiht_obj


 class Qwen3VLMOETransformerLayerWeight(Qwen3MOETransformerLayerWeight):
@@ -39,7 +27,7 @@ def _init_weight_names(self):
     def _init_moe(self):
         moe_intermediate_size = self.network_config_["moe_intermediate_size"]
         self.moe_gate = ROWMMWeight(
-            weight_name=f"model.language_model.layers.{self.layer_num_}.mlp.gate.weight",
+            weight_names=f"model.language_model.layers.{self.layer_num_}.mlp.gate.weight",
             data_type=self.data_type_,
             layer_num=self.layer_num_,
             name="moe_gate",
@@ -50,10 +38,10 @@ def _init_moe(self):
         assert moe_mode in ["EP", "TP"]

         if moe_mode == "TP":
-            self.experts = FusedMoeWeightTP(
-                gate_proj_name="gate_up_proj",
+            self.experts = create_tp_moe_wegiht_obj(
+                gate_proj_name="gate_proj",
                 down_proj_name="down_proj",
-                up_proj_name=None,
+                up_proj_name="up_proj",
                 e_score_correction_bias_name="",
                 weight_prefix=f"model.language_model.layers.{self.layer_num_}.mlp.experts",
                 n_routed_experts=self.n_routed_experts,
@@ -63,19 +51,23 @@ def _init_moe(self):
                 layer_num=self.layer_num_,
                 quant_cfg=self.quant_cfg,
                 num_fused_shared_experts=0,
+                fused_gate_up=True,
+                gate_up_proj_name="gate_up_proj",
             )
         elif moe_mode == "EP":
             self.experts = FusedMoeWeightEP(
-                gate_proj_name="gate_up_proj",
+                gate_proj_name="gate_proj",
                 down_proj_name="down_proj",
-                up_proj_name=None,
+                up_proj_name="up_proj",
                 e_score_correction_bias_name="",
                 weight_prefix=f"model.language_model.layers.{self.layer_num_}.mlp.experts",
                 n_routed_experts=self.n_routed_experts,
                 data_type=self.data_type_,
                 network_config=self.network_config_,
                 layer_num=self.layer_num_,
                 quant_cfg=self.quant_cfg,
+                fused_gate_up=True,
+                gate_up_proj_name="gate_up_proj",
             )
         else:
             raise ValueError(f"Unsupported moe mode: {moe_mode}")
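
Given the names passed at this call site, here is a small illustration (not part of the commit) of the two checkpoint-key layouts the loader now distinguishes: with fused_gate_up=True it looks up one 3D tensor per projection under the expert prefix, while the per-expert path builds one key per expert and projection. The layer index and prefix below are illustrative.

    weight_prefix = "model.language_model.layers.0.mlp.experts"  # layer 0, for illustration

    # keys read by fused_gate_up_weights_load (gate_up: [E, H, 2I], down: [E, I, H])
    fused_keys = [f"{weight_prefix}.gate_up_proj", f"{weight_prefix}.down_proj"]

    # keys read by normal_weights_load for expert 0
    normal_keys = [
        f"{weight_prefix}.0.gate_proj.weight",
        f"{weight_prefix}.0.up_proj.weight",
        f"{weight_prefix}.0.down_proj.weight",
    ]

    print(fused_keys)
    print(normal_keys)
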
