
Commit ca7d2f9

deepseekv3 cc mode fixed (#719)
Co-authored-by: shihaobai <[email protected]>
1 parent 743ddc3 commit ca7d2f9

2 files changed: +45 additions, -10 deletions

lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py

Lines changed: 3 additions & 5 deletions
@@ -154,7 +154,7 @@ def _decompress_kv(self, kv, infer_state: Deepseek2InferStateInfo, layer_weight:
         )
 
         # CC
-        compressed_kv = compressed_kv.view(-1, layer_weight.kv_lora_rank)
+        compressed_kv = compressed_kv.view(-1, layer_weight.kv_lora_rank).contiguous()
         k_nope = self.alloc_tensor(
             [compressed_kv.shape[0], self.tp_q_head_num_, self.qk_nope_head_dim],
             dtype=compressed_kv.dtype,
@@ -163,10 +163,8 @@ def _decompress_kv(self, kv, infer_state: Deepseek2InferStateInfo, layer_weight:
             k_nope.shape,
             dtype=compressed_kv.dtype,
         )
-        wk = layer_weight.k_b_proj_.weight.view(-1, layer_weight.kv_lora_rank).T
-        wv = layer_weight.v_b_proj_.weight.transpose(0, 1).reshape(layer_weight.kv_lora_rank, -1)
-        torch.mm(compressed_kv, wk, out=k_nope.reshape(compressed_kv.shape[0], -1))
-        torch.mm(compressed_kv, wv, out=v.reshape(compressed_kv.shape[0], -1))
+        layer_weight.cc_k_b_proj_.mm(compressed_kv, out=k_nope.reshape(compressed_kv.shape[0], -1))
+        layer_weight.cc_v_b_proj_.mm(compressed_kv, out=v.reshape(compressed_kv.shape[0], -1))
         return k_nope, k_rope, v
 
     def _context_attention_kernel_with_CC(
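
Note: the removed lines built wk/wv views of k_b_proj_/v_b_proj_ on the fly and ran two torch.mm calls; the new lines route the same two matmuls through the prefused cc_k_b_proj_/cc_v_b_proj_ weights and make compressed_kv contiguous first. A minimal sketch (not the repo's code) of what those matmuls compute; the head count and dims below are illustrative assumptions, not the real DeepSeek config:

    import torch

    tokens, kv_lora_rank = 4, 512
    tp_q_head_num, qk_nope_head_dim = 16, 128

    compressed_kv = torch.randn(tokens, kv_lora_rank)
    # cc_k_b_proj / cc_v_b_proj hold the per-head up-projections flattened so one mm
    # restores all heads at once: (kv_lora_rank, heads * head_dim) after transposition.
    wk = torch.randn(kv_lora_rank, tp_q_head_num * qk_nope_head_dim)
    wv = torch.randn(kv_lora_rank, tp_q_head_num * qk_nope_head_dim)

    k_nope = (compressed_kv @ wk).view(tokens, tp_q_head_num, qk_nope_head_dim)
    v = (compressed_kv @ wv).view(tokens, tp_q_head_num, qk_nope_head_dim)
    print(k_nope.shape, v.shape)  # [4, 16, 128] each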

lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py

Lines changed: 42 additions & 5 deletions
@@ -66,6 +66,7 @@ def fuse_vb_o(self, layer_weight):
 class Deepseek2TransformerLayerWeight(TransformerLayerWeight):
     def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode=[], quant_cfg=None):
         self.enable_dp = os.getenv("ENABLE_DP", "0").upper() in ["ON", "TRUE", "1"]
+        self.enable_cc_method = os.getenv("ENABLE_CC_METHOD", "False").upper() in ["ON", "TRUE", "1"]
         super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode, quant_cfg)
         return
 
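Note: the new enable_cc_method flag mirrors enable_dp; only the case-insensitive values "ON", "TRUE", or "1" in the ENABLE_CC_METHOD environment variable turn CC mode on, and anything else (including unset) leaves it off. A quick sketch of that parsing:

    import os

    os.environ["ENABLE_CC_METHOD"] = "on"
    enable_cc_method = os.getenv("ENABLE_CC_METHOD", "False").upper() in ["ON", "TRUE", "1"]
    print(enable_cc_method)  # True
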
@@ -132,34 +133,44 @@ def _load_kb(self, kv_b_proj_):
         k_b_proj_ = kv_b_proj_.view(self.num_attention_heads, self.qk_nope_head_dim * 2, self.kv_lora_rank)[
             :, : self.qk_nope_head_dim, :
         ]
-        return k_b_proj_.contiguous().to(self.data_type_)
+        return k_b_proj_.contiguous().to(kv_b_proj_.dtype)
 
     def _load_kb_scale(self, kv_b_proj_, block_size):
         k_b_proj_scale_ = kv_b_proj_.view(
             self.num_attention_heads, self.qk_nope_head_dim * 2 // block_size, self.kv_lora_rank // block_size
         )[:, : self.qk_nope_head_dim // block_size, :]
-        return k_b_proj_scale_.contiguous().to(self.data_type_)
+        return k_b_proj_scale_.contiguous().to(kv_b_proj_.dtype)
 
     def _load_vb(self, kv_b_proj_):
         v_b_proj_ = kv_b_proj_.T.view(self.kv_lora_rank, self.num_attention_heads, self.qk_nope_head_dim * 2,)[
             :, :, self.qk_nope_head_dim :
         ].transpose(0, 1)
-        return v_b_proj_.contiguous().to(self.data_type_)
+        return v_b_proj_.contiguous().to(kv_b_proj_.dtype)
 
     def _load_vb_scale(self, kv_b_proj_scale_, block_size):
         v_b_proj_scale_ = kv_b_proj_scale_.T.view(
             self.kv_lora_rank // block_size,
             self.num_attention_heads,
             self.qk_nope_head_dim * 2 // block_size,
         )[:, :, self.qk_nope_head_dim // block_size :].transpose(0, 1)
-        return v_b_proj_scale_.contiguous().to(self.data_type_)
+        return v_b_proj_scale_.contiguous().to(kv_b_proj_scale_.dtype)
 
     def load_hf_weights(self, weights):
         if f"model.layers.{self.layer_num_}.self_attn.kv_b_proj.weight" in weights:
             kv_b_proj_ = weights[f"model.layers.{self.layer_num_}.self_attn.kv_b_proj.weight"]
             weights[f"model.layers.{self.layer_num_}.self_attn.k_b_proj.weight"] = self._load_kb(kv_b_proj_)
             weights[f"model.layers.{self.layer_num_}.self_attn.v_b_proj.weight"] = self._load_vb(kv_b_proj_)
-
+            if self.enable_cc_method:
+                weights[f"model.layers.{self.layer_num_}.self_attn.cc_k_b_proj.weight"] = self._load_kb(
+                    kv_b_proj_
+                ).reshape(-1, self.kv_lora_rank)
+                weights[f"model.layers.{self.layer_num_}.self_attn.cc_v_b_proj.weight"] = (
+                    self._load_vb(kv_b_proj_)
+                    .transpose(0, 1)
+                    .reshape(self.kv_lora_rank, -1)
+                    .transpose(0, 1)
+                    .contiguous()
+                )
         if (
             self.quant_cfg.quantized_weight
             and f"model.layers.{self.layer_num_}.self_attn.kv_b_proj." + self.weight_scale_suffix in weights
@@ -177,6 +188,17 @@ def load_hf_weights(self, weights):
             weights[
                 f"model.layers.{self.layer_num_}.self_attn.v_b_proj." + self.weight_scale_suffix
             ] = self._load_vb_scale(kv_b_proj_scale_, block_size)
+            if self.enable_cc_method:
+                weights[
+                    f"model.layers.{self.layer_num_}.self_attn.cc_k_b_proj." + self.weight_scale_suffix
+                ] = self._load_kb_scale(kv_b_proj_scale_, block_size).reshape(-1, self.kv_lora_rank // block_size)
+                weights[f"model.layers.{self.layer_num_}.self_attn.cc_v_b_proj." + self.weight_scale_suffix] = (
+                    self._load_vb_scale(kv_b_proj_scale_, block_size)
+                    .transpose(0, 1)
+                    .reshape(self.kv_lora_rank // block_size, -1)
+                    .transpose(0, 1)
+                    .contiguous()
+                )
 
         return super().load_hf_weights(weights)
 
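Note: the quantized path mirrors the unquantized one, flattening the per-head block scales the same way; a quick shape check (sizes and block_size below are illustrative assumptions):

    import torch

    heads, head_dim, kv_lora_rank, block_size = 2, 4, 8, 2

    k_scale = torch.randn(heads, head_dim // block_size, kv_lora_rank // block_size)
    v_scale = torch.randn(heads, kv_lora_rank // block_size, head_dim // block_size)

    cc_k_scale = k_scale.reshape(-1, kv_lora_rank // block_size)
    cc_v_scale = (
        v_scale.transpose(0, 1)
        .reshape(kv_lora_rank // block_size, -1)
        .transpose(0, 1)
        .contiguous()
    )
    print(cc_k_scale.shape, cc_v_scale.shape)  # (4, 4) and (4, 4)
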
@@ -236,6 +258,21 @@ def _init_qkvo(self):
             weight_scale_suffix=self.weight_scale_suffix,
             act_scale_suffix=self.act_scale_suffix,
         )
+        if self.enable_cc_method:
+            self.cc_k_b_proj_ = ROWMMWeight(
+                f"model.layers.{self.layer_num_}.self_attn.cc_k_b_proj.weight",
+                self.data_type_,
+                split_n_embed=self.tp_q_head_num_ * self.qk_nope_head_dim,
+                weight_scale_suffix=self.weight_scale_suffix,
+                act_scale_suffix=self.act_scale_suffix,
+            )
+            self.cc_v_b_proj_ = ROWMMWeight(
+                f"model.layers.{self.layer_num_}.self_attn.cc_v_b_proj.weight",
+                self.data_type_,
+                split_n_embed=self.tp_q_head_num_ * self.qk_nope_head_dim,
+                weight_scale_suffix=self.weight_scale_suffix,
+                act_scale_suffix=self.act_scale_suffix,
+            )
 
         self.o_weight_ = COLMMWeight(
             f"model.layers.{self.layer_num_}.self_attn.o_proj.weight",
