
Commit 18a0f08

solve format
1 parent 39a5383 commit 18a0f08

7 files changed: +111 -76 lines changed


lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight.py

Lines changed: 26 additions & 31 deletions
@@ -7,7 +7,9 @@
 
 
 class FusedMoeWeight(BaseWeight):
-    def __init__(self, gate_proj_name, down_proj_name, up_proj_name, weight_prefix, n_routed_experts, split_inter_size, data_type):
+    def __init__(
+        self, gate_proj_name, down_proj_name, up_proj_name, weight_prefix, n_routed_experts, split_inter_size, data_type
+    ):
         super().__init__()
         self.w1_weight_name = gate_proj_name
         self.w2_weight_name = down_proj_name
@@ -22,53 +24,47 @@ def __init__(self, gate_proj_name, down_proj_name, up_proj_name, weight_prefix,
         self.w2_list = [None] * self.n_routed_experts
         self.quant_method = None
         self.lock = threading.Lock()
-
+
     def set_quant_method(self, quant_method):
         self.quant_method = quant_method
         if self.quant_method is not None:
             self.quant_method.is_moe = True
 
-    def experts(
-        self,
-        input_tensor,
-        router_logits,
-        top_k,
-        renormalize,
-        use_grouped_topk,
-        topk_group,
-        num_expert_group
-    ):
+    def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group):
         topk_weights, topk_ids = FusedMoE.select_experts(
             hidden_states=input_tensor,
             router_logits=router_logits,
             use_grouped_topk=use_grouped_topk,
             top_k=top_k,
             renormalize=renormalize,
             topk_group=topk_group,
-            num_expert_group=num_expert_group
+            num_expert_group=num_expert_group,
         )
         if self.quant_method is not None:
-            fused_experts(input_tensor,
-                          w1=self.w1[0],
-                          w2=self.w2[0],
-                          topk_weights=topk_weights,
-                          topk_ids=topk_ids,
-                          inplace=False,
-                          use_fp8_w8a8=True,
-                          use_int8_w8a16=False,
-                          w1_scale=self.w1[1],
-                          w2_scale=self.w2[1],
-                          a1_scale=None,
-                          a2_scale=None)
+            fused_experts(
+                input_tensor,
+                w1=self.w1[0],
+                w2=self.w2[0],
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                inplace=False,
+                use_fp8_w8a8=True,
+                use_int8_w8a16=False,
+                w1_scale=self.w1[1],
+                w2_scale=self.w2[1],
+                a1_scale=None,
+                a2_scale=None,
+            )
             return
-        fused_experts(hidden_states=input_tensor,
+        fused_experts(
+            hidden_states=input_tensor,
             w1=self.w1,
             w2=self.w2,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
-            inplace=True
+            inplace=True,
         )
-
+
     def fuse(self):
         with self.lock:
             if (
@@ -120,15 +116,14 @@ def load_hf_weights(self, weights):
                 self.w2_list[i_experts] = weights[w2_weight][
                     :, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1)
                 ]
-
+
         self.fuse()
 
-
     def _cuda(self, cpu_tensor):
         if self.tp_rank_ is None:
             return cpu_tensor.contiguous().to(self.data_type_).cuda()
         else:
             return cpu_tensor.contiguous().to(self.data_type_).cuda(self.tp_rank_)
-
+
     def verify_load(self):
         return self.w1 is not None and self.w2 is not None
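Note: the reformatted FusedMoeWeight.__init__ above keeps its original parameter list. Below is a minimal, hypothetical instantiation sketch that mirrors the kwargs used in Deepseek2TransformerLayerWeight.init_moe() later in this commit; the layer index, expert count, shard sizes, and dtype are placeholders, not values from the repository.

    import torch
    from lightllm.common.basemodel.layer_weights.meta_weights import FusedMoeWeight

    # Placeholder shard parameters, for illustration only.
    moe_intermediate_size = 1408
    world_size = 8

    experts = FusedMoeWeight(
        gate_proj_name="gate_proj",
        down_proj_name="down_proj",
        up_proj_name="up_proj",
        weight_prefix="model.layers.0.mlp.experts",            # placeholder layer index
        n_routed_experts=64,                                   # placeholder expert count
        split_inter_size=moe_intermediate_size // world_size,  # per-rank split of the intermediate size
        data_type=torch.bfloat16,                              # placeholder dtype
    )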

lightllm/common/basemodel/layer_weights/transformer_layer_weight.py

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 
 # from lightllm.common.layers.mm import MM
 from .base_layer_weight import BaseLayerWeight
-from .meta_weights import MMWeight,FusedMoeWeight
+from .meta_weights import MMWeight, FusedMoeWeight
 from lightllm.utils.log_utils import init_logger
 
 logger = init_logger(__name__)

lightllm/common/quantization/vllm_quant.py

Lines changed: 1 addition & 2 deletions
@@ -64,10 +64,9 @@ def quantize(self, weight: torch.Tensor):
             return self.quantize_moe(weight)
         qweight, weight_scale = ops.scaled_fp8_quant(weight.cuda(), scale=None, use_per_token_if_dynamic=True)
         return qweight.transpose(0, 1), weight_scale
-
+
     def quantize_moe(self, weight):
         num_experts = weight.shape[0]
-        out_dim = weight.shape[1]
         qweights = []
         weight_scales = []
         qweights = torch.empty_like(weight, dtype=torch.float8_e4m3fn).cuda()

lightllm/models/baichuan7b/layer_weights/transformer_layer_weight.py

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@ class BaiChuan7bTransformerLayerWeight(LlamaTransformerLayerWeight):
     def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode=[], quant_cfg=None):
         super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode, quant_cfg)
         return
-
+
     def _init_config(self):
         self.network_config_["num_key_value_heads"] = self.network_config_["num_attention_heads"]
         self.n_embed = self.network_config_["hidden_size"]

lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py

Lines changed: 5 additions & 2 deletions
@@ -9,6 +9,7 @@
     context_attention_fwd,
     context_attention_fwd_no_prompt_cache,
 )
+
 from lightllm.models.deepseek2.triton_kernel.flash_decoding import token_decode_attention_flash_decoding
 from lightllm.models.deepseek2.layer_infer.fused_moe import fused_experts, grouped_topk
 from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
@@ -20,7 +21,9 @@
 
 
 class Deepseek2TransformerLayerInfer(LlamaTransformerLayerInfer):
-    def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[], disable_qk_absorb=False, disable_vo_absorb=False):
+    def __init__(
+        self, layer_num, tp_rank, world_size, network_config, mode=[], disable_qk_absorb=False, disable_vo_absorb=False
+    ):
         self.tp_k_head_num_ = 1
         self.tp_v_head_num_ = 1
         self.qk_nope_head_dim = network_config["qk_nope_head_dim"]
@@ -207,7 +210,7 @@ def _moe_ffn(
             renormalize=self.norm_topk_prob,
             use_grouped_topk=self.n_group,
             topk_group=self.topk_group,
-            num_expert_group=self.n_group
+            num_expert_group=self.n_group,
         )
 
         hidden_states.mul_(self.routed_scaling_factor)

lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py

Lines changed: 68 additions & 36 deletions
@@ -2,7 +2,14 @@
 import math
 import numpy as np
 from lightllm.common.basemodel import TransformerLayerWeight
-from lightllm.common.basemodel.layer_weights.meta_weights import ROWMMWeight, COLMMWeight, NormWeight, CustomMMWeight, FusedMoeWeight, CustomBMMWeight
+from lightllm.common.basemodel.layer_weights.meta_weights import (
+    ROWMMWeight,
+    COLMMWeight,
+    NormWeight,
+    CustomMMWeight,
+    FusedMoeWeight,
+    CustomBMMWeight,
+)
 from functools import partial
 
 
@@ -19,16 +26,32 @@ def fuse_q_kb(self, A, B):
     k_nope_proj_ = k_b_proj_.unsqueeze(0)
     k_nope_proj_ = k_nope_proj_.to(torch.float64)
 
-    return self._cuda(torch.matmul(q_nope_proj_, k_nope_proj_).view(-1, self.tp_q_head_num_ * self.kv_lora_rank).transpose(0, 1))
+    return self._cuda(
+        torch.matmul(q_nope_proj_, k_nope_proj_).view(-1, self.tp_q_head_num_ * self.kv_lora_rank).transpose(0, 1)
+    )
+
 
 def fuse_vb_o(self, A, B):
     v_b_proj_ = A.weight
-    o_weight_ = B.weight.transpose(0, 1).view(self.tp_q_head_num_, self.qk_nope_head_dim, -1).contiguous().to(self.data_type_).cpu()
-    return self._cuda(torch.matmul(v_b_proj_.to(torch.float64), o_weight_.to(torch.float64)).view(-1, self.network_config_["hidden_size"]))
+    o_weight_ = (
+        B.weight.transpose(0, 1)
+        .view(self.tp_q_head_num_, self.qk_nope_head_dim, -1)
+        .contiguous()
+        .to(self.data_type_)
+        .cpu()
+    )
+    return self._cuda(
+        torch.matmul(v_b_proj_.to(torch.float64), o_weight_.to(torch.float64)).view(
+            -1, self.network_config_["hidden_size"]
+        )
+    )
+
 
 def load_q_rope(self, A, q_weight_):
     q_split_n_embed_with_rope = A.split_n_embed
-    q_weight_ = q_weight_[q_split_n_embed_with_rope * self.tp_rank_ : q_split_n_embed_with_rope * (self.tp_rank_ + 1), :]
+    q_weight_ = q_weight_[
+        q_split_n_embed_with_rope * self.tp_rank_ : q_split_n_embed_with_rope * (self.tp_rank_ + 1), :
+    ]
     q_weight_ = q_weight_.transpose(0, 1).contiguous()
     q_nope_proj_, q_rope_proj_ = torch.split(
         q_weight_.view(-1, self.tp_q_head_num_, self.qk_nope_head_dim + self.qk_rope_head_dim),
@@ -37,6 +60,7 @@ def load_q_rope(self, A, q_weight_):
     )
     return self._cuda(q_rope_proj_.reshape(-1, self.qk_rope_head_dim * self.tp_q_head_num_).transpose(0, 1))
 
+
 def load_kb(self, A, kv_b_proj_):
     kv_b_proj_ = kv_b_proj_
     k_b_proj_ = kv_b_proj_.view(self.num_attention_heads, self.qk_nope_head_dim * 2, self.kv_lora_rank)[
@@ -47,22 +71,31 @@ def load_kb(self, A, kv_b_proj_):
         return k_b_proj_.contiguous().to(self.data_type_).cpu()
     return self._cuda(k_b_proj_)
 
+
 def load_vb(self, A, kv_b_proj_):
     kv_b_proj_ = kv_b_proj_
-    v_b_proj_ = kv_b_proj_.T.view(
-        self.kv_lora_rank,
-        self.num_attention_heads,
-        self.qk_nope_head_dim * 2,
-    )[:, :, self.qk_nope_head_dim :].transpose(0, 1)
-    v_b_proj_ = v_b_proj_[
-        self.tp_q_head_num_ * self.tp_rank_ : self.tp_q_head_num_ * (self.tp_rank_ + 1), :, :
-    ]
+    v_b_proj_ = kv_b_proj_.T.view(self.kv_lora_rank, self.num_attention_heads, self.qk_nope_head_dim * 2,)[
+        :, :, self.qk_nope_head_dim :
+    ].transpose(0, 1)
+    v_b_proj_ = v_b_proj_[self.tp_q_head_num_ * self.tp_rank_ : self.tp_q_head_num_ * (self.tp_rank_ + 1), :, :]
     if A.wait_fuse:
         return v_b_proj_.contiguous().to(self.data_type_).cpu()
     return self._cuda(v_b_proj_)
 
+
 class Deepseek2TransformerLayerWeight(TransformerLayerWeight):
-    def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode=[], quant_cfg=None, disable_qk_absorb=False, disable_vo_absorb=False):
+    def __init__(
+        self,
+        layer_num,
+        tp_rank,
+        world_size,
+        data_type,
+        network_config,
+        mode=[],
+        quant_cfg=None,
+        disable_qk_absorb=False,
+        disable_vo_absorb=False,
+    ):
         super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode, quant_cfg)
         self.is_moe = (
             self.network_config_["n_routed_experts"] is not None
@@ -86,9 +119,11 @@ def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mo
         self.fuse_pairs = {"q_b_proj_&k_b_proj_": "fuse_qk_weight_"}
         if not self.disable_vo_absorb:
             self.fuse_pairs["v_b_proj_&o_weight_"] = "fuse_vo_weight_"
-        self.fuse_pairs.update({
-            "gate_proj&up_proj": "gate_up_proj",
-        })
+        self.fuse_pairs.update(
+            {
+                "gate_proj&up_proj": "gate_up_proj",
+            }
+        )
 
         self.init_qkvo()
         if self.is_moe:
@@ -115,7 +150,10 @@ def init_qkvo(self):
             rope_weight_name = f"model.layers.{self.layer_num_}.self_attn.q_proj.weight"
         else:
             self.q_a_proj_ = ROWMMWeight(
-                f"model.layers.{self.layer_num_}.self_attn.q_a_proj.weight", self.data_type_, self.q_lora_rank, disable_tp=True
+                f"model.layers.{self.layer_num_}.self_attn.q_a_proj.weight",
+                self.data_type_,
+                self.q_lora_rank,
+                disable_tp=True,
             )
             self.q_b_proj_ = CustomMMWeight(
                 f"model.layers.{self.layer_num_}.self_attn.q_b_proj.weight",
@@ -126,10 +164,7 @@ def init_qkvo(self):
             )
             rope_weight_name = f"model.layers.{self.layer_num_}.self_attn.q_b_proj.weight"
         self.q_rope_proj_ = CustomMMWeight(
-            rope_weight_name,
-            self.data_type_,
-            q_split_n_embed_with_rope,
-            custom_load=partial(load_q_rope, self)
+            rope_weight_name, self.data_type_, q_split_n_embed_with_rope, custom_load=partial(load_q_rope, self)
         )
         self.kv_a_proj_with_mqa_ = ROWMMWeight(
             f"model.layers.{self.layer_num_}.self_attn.kv_a_proj_with_mqa.weight",
@@ -142,50 +177,47 @@ def init_qkvo(self):
             self.data_type_,
             None,
             wait_fuse=not self.disable_qk_absorb,
-            custom_load=partial(load_kb, self)
+            custom_load=partial(load_kb, self),
         )
         self.v_b_proj_ = CustomBMMWeight(
             f"model.layers.{self.layer_num_}.self_attn.kv_b_proj.weight",
             self.data_type_,
             None,
             wait_fuse=not self.disable_vo_absorb,
             custom_load=partial(load_vb, self),
-            custom_fuse=partial(fuse_vb_o, self)
+            custom_fuse=partial(fuse_vb_o, self),
        )
         self.o_weight_ = COLMMWeight(
-            f"model.layers.{self.layer_num_}.self_attn.o_proj.weight", self.data_type_, q_split_n_embed, wait_fuse=not self.disable_vo_absorb,
+            f"model.layers.{self.layer_num_}.self_attn.o_proj.weight",
+            self.data_type_,
+            q_split_n_embed,
+            wait_fuse=not self.disable_vo_absorb,
         )
 
     def _load_mlp(self, mlp_prefix, split_inter_size):
         self.gate_proj = ROWMMWeight(
             f"{mlp_prefix}.gate_proj.weight", self.data_type_, split_inter_size, wait_fuse=True
         )
-        self.up_proj = ROWMMWeight(
-            f"{mlp_prefix}.up_proj.weight", self.data_type_, split_inter_size, wait_fuse=True
-        )
-        self.down_proj = COLMMWeight(
-            f"{mlp_prefix}.down_proj.weight", self.data_type_, split_inter_size
-        )
+        self.up_proj = ROWMMWeight(f"{mlp_prefix}.up_proj.weight", self.data_type_, split_inter_size, wait_fuse=True)
+        self.down_proj = COLMMWeight(f"{mlp_prefix}.down_proj.weight", self.data_type_, split_inter_size)
 
     def init_moe(self):
         moe_intermediate_size = self.network_config_["moe_intermediate_size"]
         self.moe_gate = ROWMMWeight(
             f"model.layers.{self.layer_num_}.mlp.gate.weight", self.data_type_, moe_intermediate_size, disable_tp=True
         )
-        shared_intermediate_size = (
-            moe_intermediate_size * self.network_config_["n_shared_experts"]
-        )
+        shared_intermediate_size = moe_intermediate_size * self.network_config_["n_shared_experts"]
         shared_split_inter_size = shared_intermediate_size // self.world_size_
         self._load_mlp(f"model.layers.{self.layer_num_}.mlp.shared_experts", shared_split_inter_size)
-
+
         self.experts = FusedMoeWeight(
             gate_proj_name="gate_proj",
             down_proj_name="down_proj",
             up_proj_name="up_proj",
             weight_prefix=f"model.layers.{self.layer_num_}.mlp.experts",
             n_routed_experts=self.n_routed_experts,
             split_inter_size=moe_intermediate_size // self.world_size_,
-            data_type=self.data_type_
+            data_type=self.data_type_,
         )
 
     def init_ffn(self):

lightllm/models/deepseek2/model.py

Lines changed: 9 additions & 3 deletions
@@ -66,7 +66,7 @@ def _init_weights(self):
                 mode=self.mode,
                 quant_cfg=self.quant_cfg,
                 disable_qk_absorb=self.disable_qk_absorb,
-                disable_vo_absorb=self.disable_vo_absorb
+                disable_vo_absorb=self.disable_vo_absorb,
             )
             for i in range(self.config["n_layer"])
         ]
@@ -80,7 +80,7 @@ def _init_weights(self):
         self.pre_post_weight.verify_load()
         [weight.verify_load() for weight in self.trans_layers_weight]
         return
-
+
     def _init_infer_layer(self):
         self.pre_infer = self.pre_layer_infer_class(
             tp_rank=self.tp_rank_, world_size=self.world_size_, network_config=self.config, mode=self.mode
@@ -90,7 +90,13 @@ def _init_infer_layer(self):
         )
         self.layers_infer = [
             self.transformer_layer_infer_class(
-                i, tp_rank=self.tp_rank_, world_size=self.world_size_, network_config=self.config, mode=self.mode, disable_qk_absorb=self.disable_qk_absorb, disable_vo_absorb=self.disable_vo_absorb
+                i,
+                tp_rank=self.tp_rank_,
+                world_size=self.world_size_,
+                network_config=self.config,
+                mode=self.mode,
+                disable_qk_absorb=self.disable_qk_absorb,
+                disable_vo_absorb=self.disable_vo_absorb,
             )
             for i in range(self.config["n_layer"])
         ]
