
Commit 89c248f

refactor deepseekv2
1 parent 7c305e4 commit 89c248f

File tree

7 files changed: +273 -227 lines changed


lightllm/common/basemodel/layer_weights/meta_weights/__init__.py

Lines changed: 3 additions & 2 deletions
@@ -7,8 +7,9 @@
     COLMMWeight,
     MultiROWMMWeight,
     MultiROWMMWeightNoTP,
-    CustomMMWeight,
-    CustomBMMWeight,
+    MultiCOLMMWeight,
+    ROWBMMWeight,
+    COLBMMWeight,
 )
 from .norm_weight import NormWeight, GEMMANormWeight, TpNormWeight
 from .fused_moe_weight import FusedMoeWeight

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight.py

Lines changed: 25 additions & 27 deletions
@@ -2,15 +2,22 @@
 from .base_weight import BaseWeight
 from lightllm.utils.dist_utils import get_world_size, get_rank
 import threading
-from vllm.model_executor.layers.fused_moe import FusedMoE
-from vllm.model_executor.layers.fused_moe import fused_experts
+from lightllm.common.quantization import vLLMFP8w8a8QuantizationMethod
+
+try:
+    HAS_VLLM = True
+    from vllm.model_executor.layers.fused_moe import FusedMoE
+    from vllm.model_executor.layers.fused_moe import fused_experts
+except:
+    HAS_VLLM = False
 
 
 class FusedMoeWeight(BaseWeight):
     def __init__(
         self, gate_proj_name, down_proj_name, up_proj_name, weight_prefix, n_routed_experts, split_inter_size, data_type
     ):
         super().__init__()
+        assert HAS_VLLM, "vllm is not installed, you can't use FusedMoeWeight"
         self.w1_weight_name = gate_proj_name
         self.w2_weight_name = down_proj_name
         self.w3_weight_name = up_proj_name
@@ -26,9 +33,10 @@ def __init__(
         self.lock = threading.Lock()
 
     def set_quant_method(self, quant_method):
-        self.quant_method = quant_method
-        if self.quant_method is not None:
-            self.quant_method.is_moe = True
+        if isinstance(self.quant_method, vLLMFP8w8a8QuantizationMethod):
+            self.quant_method = quant_method
+            if self.quant_method is not None:
+                self.quant_method.is_moe = True
 
     def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_topk, topk_group, num_expert_group):
         topk_weights, topk_ids = FusedMoE.select_experts(
@@ -40,32 +48,22 @@ def experts(self, input_tensor, router_logits, top_k, renormalize, use_grouped_t
             topk_group=topk_group,
             num_expert_group=num_expert_group,
         )
-        if self.quant_method is not None:
-            fused_experts(
-                input_tensor,
-                w1=self.w1[0],
-                w2=self.w2[0],
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                inplace=False,
-                use_fp8_w8a8=True,
-                use_int8_w8a16=False,
-                w1_scale=self.w1[1],
-                w2_scale=self.w2[1],
-                a1_scale=None,
-                a2_scale=None,
-            )
-            return
+        w1, w1_scale = self.w1
+        w2, w2_scale = self.w2
+        use_fp8_w8a8 = self.quant_method is not None
         fused_experts(
             hidden_states=input_tensor,
-            w1=self.w1,
-            w2=self.w2,
+            w1=w1,
+            w2=w2,
             topk_weights=topk_weights,
             topk_ids=topk_ids,
             inplace=True,
+            use_fp8_w8a8=use_fp8_w8a8,
+            w1_scale=w1_scale,
+            w2_scale=w2_scale,
         )
 
-    def fuse(self):
+    def _fuse(self):
         with self.lock:
             if (
                 hasattr(self, "experts_up_projs")
@@ -91,8 +89,8 @@ def fuse(self):
                 self.w1 = self.quant_method.quantize(self.w1)
                 self.w2 = self.quant_method.quantize(self.w2)
             else:
-                self.w1 = self._cuda(self.w1)
-                self.w2 = self._cuda(self.w2)
+                self.w1 = [self._cuda(self.w1), None]
+                self.w2 = [self._cuda(self.w2), None]
             delattr(self, "w2_list")
             delattr(self, "experts_up_projs")
             delattr(self, "experts_gate_projs")
@@ -117,7 +115,7 @@ def load_hf_weights(self, weights):
                 :, self.split_inter_size * self.tp_rank_ : self.split_inter_size * (self.tp_rank_ + 1)
             ]
 
-        self.fuse()
+        self._fuse()
 
     def _cuda(self, cpu_tensor):
         if self.tp_rank_ is None:
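After this refactor, self.w1 and self.w2 always hold a two-element [tensor, scale] pair: _fuse() pads the unquantized path with None, so experts() unpacks both paths uniformly and drives a single fused_experts call through the use_fp8_w8a8 flag instead of duplicating it. A minimal sketch of that convention with a dummy tensor; pack_weight is a hypothetical stand-in for the tail of _fuse(), not code from this commit:

    import torch

    def pack_weight(weight, quant_method=None):
        # Quantized branch: quant_method.quantize is assumed to return a
        # [qweight, scale] pair, matching what experts() unpacks.
        if quant_method is not None:
            return quant_method.quantize(weight)
        # Unquantized branch: pad with None so unpacking stays uniform.
        return [weight.cuda(), None]

    w1 = pack_weight(torch.randn(8, 256, 128))  # no quant method configured
    weight, scale = w1                          # uniform unpacking, as in experts()
    use_fp8_w8a8 = scale is not None            # experts() derives this from quant_method; same result here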

lightllm/common/basemodel/layer_weights/meta_weights/mm_weight.py

Lines changed: 90 additions & 77 deletions
@@ -4,11 +4,9 @@
 
 
 class MMWeightTpl(BaseWeightTpl):
-    def __init__(self, data_type, split_n_embed):
+    def __init__(self, data_type):
         super().__init__()
         self.data_type_ = data_type
-        self.start = split_n_embed * self.tp_rank_
-        self.end = split_n_embed * (self.tp_rank_ + 1)
         self.quant_method = None
         self.weight = None
         self.bias = None
@@ -40,7 +38,9 @@ def _post_load_weights(self):
 
 class MMWeight(MMWeightTpl):
     def __init__(self, weight_name, data_type, split_n_embed, bias_name=None):
-        super().__init__(data_type, split_n_embed)
+        super().__init__(data_type)
+        self.start = split_n_embed * self.tp_rank_
+        self.end = split_n_embed * (self.tp_rank_ + 1)
         self.weight_name = weight_name
         self.bias_name = bias_name
 
@@ -72,7 +72,7 @@ def load_hf_weights(self, weights):
         return
 
 
-class ROWMMWeightNoTP(MMWeight):
+class ROWMMWeightNoTP(ROWMMWeight):
     def __init__(self, weight_name, data_type, split_n_embed, bias_name=None):
         super().__init__(weight_name, data_type, split_n_embed, bias_name)
         self.start = 0
@@ -98,13 +98,20 @@ def load_hf_weights(self, weights):
 
 
 class MultiMMWeight(MMWeightTpl):
-    def __init__(self, weight_names, data_type, split_n_embed, bias_names=None):
-        super().__init__(data_type, split_n_embed)
+    def __init__(self, weight_names, data_type, split_n_embeds, bias_names=[]):
+        super().__init__(data_type)
+        if isinstance(split_n_embeds, int):
+            self.split_n_embeds = [split_n_embeds] * len(weight_names)
+        else:
+            self.split_n_embeds = split_n_embeds
+
+        self.starts = [i * self.tp_rank_ for i in self.split_n_embeds]
+        self.ends = [i * (self.tp_rank_ + 1) for i in self.split_n_embeds]
         self.weight_names = weight_names
         self.bias_names = bias_names
         self.weights = [None] * len(self.weight_names)
         self.biases = [None] * len(self.bias_names)
-        self.has_bias = all(b is not None for b in self.bias_names)
+        self.has_bias = all(b is not None for b in self.bias_names) and len(bias_names) > 0
 
     def verify_load(self):
         load_ok = True
@@ -117,7 +124,7 @@ def verify_load(self):
 
 
 class MultiROWMMWeight(MultiMMWeight):
-    def __init__(self, weight_names, data_type, split_n_embed, bias_names=None):
+    def __init__(self, weight_names, data_type, split_n_embed, bias_names=[]):
         super().__init__(weight_names, data_type, split_n_embed, bias_names)
 
     def _fuse(self):
@@ -134,86 +141,48 @@ def load_hf_weights(self, weights):
         for i in range(len(self.weight_names)):
             if self.weight_names[i] in weights:
                 weight = weights[self.weight_names[i]].to(self.data_type_)
-                self.weights[i] = weight[self.start : self.end]
+                self.weights[i] = weight[self.starts[i] : self.ends[i]]
             if self.has_bias and self.bias_names[i] in weights:
                 bias = weights[self.bias_names[i]].to(self.data_type_)
-                self.biases[i] = bias[self.start : self.end]
+                self.biases[i] = bias[self.starts[i] : self.ends[i]]
         self._fuse()
         return
 
 
 class MultiROWMMWeightNoTP(MultiROWMMWeight):
-    def __init__(self, weight_names, data_type, split_n_embed, bias_names=None):
+    def __init__(self, weight_names, data_type, split_n_embed, bias_names=[]):
         super().__init__(weight_names, data_type, split_n_embed, bias_names)
-        self.start = 0
-        self.end = split_n_embed
+        self.starts = [0 for i in self.split_n_embeds]
+        self.ends = [i for i in self.split_n_embeds]
 
 
-class CustomMMWeight(ROWMMWeight):
-    def __init__(
-        self,
-        weight_name,
-        data_type,
-        split_n_embed,
-        bias_name=None,
-        wait_fuse=False,
-        disable_tp=False,
-        custom_load=None,
-        custom_fuse=None,
-    ):
-        super().__init__(weight_name, data_type, split_n_embed, bias_name, wait_fuse=wait_fuse, disable_tp=disable_tp)
-        self.custom_load = custom_load
-        self.custom_fuse = custom_fuse
-
-    def fuse(self, B, op=None):
-        if self.custom_fuse is None:
-            super().fuse(B, op)
-        else:
-            weight = self.custom_fuse(self, B)
-            self.post_load_weights(weight)
+class MultiCOLMMWeight(MultiROWMMWeight):
+    def __init__(self, weight_names, data_type, split_n_embed, bias_names=[]):
+        super().__init__(weight_names, data_type, split_n_embed, bias_names)
 
     def load_hf_weights(self, weights):
-        if self.custom_load is None:
-            super().load_hf_weights(weights)
-        else:
-            weight = None
-            if self.weight_name in weights:
-                weight = self.custom_load(self, self.pre_load_weights(weights[self.weight_name]))
-            if weight is None:
-                return
-            if self.wait_fuse:
-                self.weight = weight
-                return
-            self.post_load_weights(weight)
+        weight = None
+        for i in range(len(self.weight_names)):
+            if self.weight_names[i] in weights:
+                weight = weights[self.weight_names[i]].to(self.data_type_)
+                self.weights[i] = weight[:, self.starts[i] : self.ends[i]]
+            if self.has_bias and self.bias_names[i] in weights:
+                bias = weights[self.bias_names[i]].to(self.data_type_)
+                self.biases[i] = bias[:, self.starts[i] : self.ends[i]]
+        self._fuse()
         return
 
 
-class CustomBMMWeight(CustomMMWeight):
-    def __init__(
-        self,
-        weight_name,
-        data_type,
-        split_n_embed,
-        bias_name=None,
-        wait_fuse=False,
-        disable_tp=False,
-        custom_load=None,
-        custom_fuse=None,
-    ):
-        super().__init__(
-            weight_name,
-            data_type,
-            split_n_embed,
-            bias_name,
-            wait_fuse=wait_fuse,
-            disable_tp=disable_tp,
-            custom_load=custom_load,
-            custom_fuse=custom_fuse,
-        )
+class BMMWeightTpl(BaseWeightTpl):
+    def __init__(self, data_type):
+        super().__init__()
+        self.data_type_ = data_type
+        self.quant_method = None
+        self.weight = None
+        self.bias = None
 
     def set_quant_method(self, quant_method):
-        return
-        raise NotImplementedError("BMM does not currently support quantification")
+        self.quant_method = None
 
     def bmm(self, input_tensor, out=None, use_custom_tensor_mananger=True):
         if self.quant_method is not None:
@@ -230,8 +199,52 @@ def bmm(self, input_tensor, out=None, use_custom_tensor_mananger=True):
             return torch.bmm(input_tensor, self.weight, out=out)
         return torch.addbmm(self.bias, input_tensor, self.weight, out=out)
 
-    def post_load_weights(self, weight):
-        if self.quant_method is not None:
-            self.weight = self.quant_method.quantize(weight.cuda(self.tp_rank_))
-            return
-        self.weight = weight.cuda(self.tp_rank_)
+    def _post_load_weights(self):
+        self.weight = self.weight.cuda(self.tp_rank_)
+
+
+class BMMWeight(BMMWeightTpl):
+    def __init__(self, weight_name, data_type, split_n_embed, bias_name=None):
+        super().__init__(data_type)
+        self.start = split_n_embed * self.tp_rank_
+        self.end = split_n_embed * (self.tp_rank_ + 1)
+        self.weight_name = weight_name
+        self.bias_name = bias_name
+
+    def verify_load(self):
+        load_ok = True
+        # Verify weight. The weight must be not None.
+        load_ok = load_ok and self.weight is not None
+        # Verify bias. If bias_name is set, it must be not None.
+        if self.bias_name is not None:
+            load_ok = load_ok and self.bias is not None
+        return load_ok
+
+
+class ROWBMMWeight(BMMWeight):
+    load_hf_weights = ROWMMWeight.load_hf_weights
+
+    def __init__(
+        self,
+        weight_name,
+        data_type,
+        split_n_embed,
+        bias_name=None,
+    ):
+        super().__init__(weight_name, data_type, split_n_embed, bias_name)
+
+
+class COLBMMWeight(BMMWeight):
+    load_hf_weights = COLMMWeight.load_hf_weights
+
+    def __init__(
+        self,
+        weight_name,
+        data_type,
+        split_n_embed,
+        bias_name=None,
+    ):
+        super().__init__(weight_name, data_type, split_n_embed, bias_name)
+
+    def _post_load_weights(self):
+        self.weight = self.weight.transpose(0, 1).cuda(self.tp_rank_)
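The ROW/COL naming tracks which dimension of the checkpoint tensor each tensor-parallel rank keeps: ROWBMMWeight reuses ROWMMWeight.load_hf_weights (a slice along dim 0), while COLBMMWeight reuses COLMMWeight.load_hf_weights (a slice along dim 1) and then transposes in _post_load_weights so torch.bmm sees the shard in the same orientation. A minimal sketch of the two slicing conventions on a dummy 2-D tensor; the shapes and variable names are illustrative, not taken from the commit:

    import torch

    tp_rank, split_n_embed = 1, 64
    full = torch.randn(256, 256)    # stand-in for a checkpoint tensor

    start = split_n_embed * tp_rank
    end = split_n_embed * (tp_rank + 1)

    row_shard = full[start:end]     # ROW-style slice, as in ROWMMWeight.load_hf_weights
    col_shard = full[:, start:end]  # COL-style slice, as in COLMMWeight.load_hf_weights

    # COLBMMWeight._post_load_weights then transposes the column shard so the
    # batched matmul receives it in the same orientation as a row shard.
    assert col_shard.transpose(0, 1).shape == row_shard.shape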

lightllm/common/basemodel/layer_weights/meta_weights/norm_weight.py

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 
 
 class NormWeight(BaseWeightTpl):
-    def __init__(self, weight_name, data_type, bias_name):
+    def __init__(self, weight_name, data_type, bias_name=None):
         super().__init__()
         self.weight_name = weight_name
         self.bias_name = bias_name
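With bias_name now defaulting to None, bias-free norm layers (the common case for RMSNorm-style weights such as DeepSeek-V2's) can be declared without an explicit placeholder. A hypothetical instantiation; the weight name is illustrative:

    import torch
    from lightllm.common.basemodel.layer_weights.meta_weights import NormWeight

    # The third positional argument (bias_name) is no longer required.
    norm = NormWeight("model.layers.0.input_layernorm.weight", torch.float16)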
