
Commit 270ce08

enable cc for weight
1 parent c27afec · commit 270ce08

File tree

1 file changed (+38, -33 lines)

lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py

Lines changed: 38 additions & 33 deletions
@@ -66,6 +66,7 @@ def fuse_vb_o(self, layer_weight):
 class Deepseek2TransformerLayerWeight(TransformerLayerWeight):
     def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode=[], quant_cfg=None):
         self.enable_dp = os.getenv("ENABLE_DP", "0").upper() in ["ON", "TRUE", "1"]
+        self.enable_cc_method = os.getenv("ENABLE_CC_METHOD", "False").upper() in ["ON", "TRUE", "1"]
         super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode, quant_cfg)
         return
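
The new flag follows the same parsing convention as the existing ENABLE_DP switch: the variable is upper-cased and checked against "ON", "TRUE", and "1", so the default "False" (upper-cased to "FALSE") leaves the CC path disabled unless the variable is set explicitly. A minimal standalone sketch of that behavior; env_flag is a hypothetical helper, not lightllm API:

import os

# Hypothetical helper mirroring the os.getenv(...).upper() in [...] idiom
# used in __init__: any of "ON", "TRUE", "1" (case-insensitive) enables it.
def env_flag(name: str, default: str = "False") -> bool:
    return os.getenv(name, default).upper() in ["ON", "TRUE", "1"]

os.environ["ENABLE_CC_METHOD"] = "on"
assert env_flag("ENABLE_CC_METHOD") is True

del os.environ["ENABLE_CC_METHOD"]
assert env_flag("ENABLE_CC_METHOD") is False  # "False" -> "FALSE", not in list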

@@ -159,13 +160,17 @@ def load_hf_weights(self, weights):
         kv_b_proj_ = weights[f"model.layers.{self.layer_num_}.self_attn.kv_b_proj.weight"]
         weights[f"model.layers.{self.layer_num_}.self_attn.k_b_proj.weight"] = self._load_kb(kv_b_proj_)
         weights[f"model.layers.{self.layer_num_}.self_attn.v_b_proj.weight"] = self._load_vb(kv_b_proj_)
-
-        weights[f"model.layers.{self.layer_num_}.self_attn.cc_k_b_proj.weight"] = self._load_kb(kv_b_proj_).reshape(
-            -1, self.kv_lora_rank
-        )
-        weights[f"model.layers.{self.layer_num_}.self_attn.cc_v_b_proj.weight"] = (
-            self._load_vb(kv_b_proj_).transpose(0, 1).reshape(self.kv_lora_rank, -1).transpose(0, 1).contiguous()
-        )
+        if self.enable_cc_method:
+            weights[f"model.layers.{self.layer_num_}.self_attn.cc_k_b_proj.weight"] = self._load_kb(
+                kv_b_proj_
+            ).reshape(-1, self.kv_lora_rank)
+            weights[f"model.layers.{self.layer_num_}.self_attn.cc_v_b_proj.weight"] = (
+                self._load_vb(kv_b_proj_)
+                .transpose(0, 1)
+                .reshape(self.kv_lora_rank, -1)
+                .transpose(0, 1)
+                .contiguous()
+            )
         if (
             self.quant_cfg.quantized_weight
             and f"model.layers.{self.layer_num_}.self_attn.kv_b_proj." + self.weight_scale_suffix in weights
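
Behind the flag, the decoupled k_b/v_b projections get extra cc_* copies laid out as plain 2-D matrices: cc_k_b_proj just flattens the leading head dimension with reshape(-1, kv_lora_rank), while cc_v_b_proj moves kv_lora_rank to the last axis and flattens (heads, v_head_dim) into a single output dimension. A small sketch of the latter with dummy sizes; the 3-D (heads, kv_lora_rank, v_head_dim) layout of the _load_vb result is an assumption made for illustration, not something this diff states:

import torch

# Dummy sizes; the 3-D (heads, kv_lora_rank, v_head_dim) layout of the
# _load_vb result is assumed here for illustration only.
heads, kv_lora_rank, v_head_dim = 4, 512, 64
v_b = torch.randn(heads, kv_lora_rank, v_head_dim)

cc_v_b = (
    v_b.transpose(0, 1)         # (kv_lora_rank, heads, v_head_dim)
    .reshape(kv_lora_rank, -1)  # (kv_lora_rank, heads * v_head_dim)
    .transpose(0, 1)            # (heads * v_head_dim, kv_lora_rank)
    .contiguous()               # materialize the strided view for fast matmul
)
assert cc_v_b.shape == (heads * v_head_dim, kv_lora_rank)
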
@@ -183,17 +188,17 @@ def load_hf_weights(self, weights):
             weights[
                 f"model.layers.{self.layer_num_}.self_attn.v_b_proj." + self.weight_scale_suffix
             ] = self._load_vb_scale(kv_b_proj_scale_, block_size)
-
-            weights[
-                f"model.layers.{self.layer_num_}.self_attn.cc_k_b_proj." + self.weight_scale_suffix
-            ] = self._load_kb_scale(kv_b_proj_scale_, block_size).reshape(-1, self.kv_lora_rank // block_size)
-            weights[f"model.layers.{self.layer_num_}.self_attn.cc_v_b_proj." + self.weight_scale_suffix] = (
-                self._load_vb_scale(kv_b_proj_scale_, block_size)
-                .transpose(0, 1)
-                .reshape(self.kv_lora_rank // block_size, -1)
-                .transpose(0, 1)
-                .contiguous()
-            )
+            if self.enable_cc_method:
+                weights[
+                    f"model.layers.{self.layer_num_}.self_attn.cc_k_b_proj." + self.weight_scale_suffix
+                ] = self._load_kb_scale(kv_b_proj_scale_, block_size).reshape(-1, self.kv_lora_rank // block_size)
+                weights[f"model.layers.{self.layer_num_}.self_attn.cc_v_b_proj." + self.weight_scale_suffix] = (
+                    self._load_vb_scale(kv_b_proj_scale_, block_size)
+                    .transpose(0, 1)
+                    .reshape(self.kv_lora_rank // block_size, -1)
+                    .transpose(0, 1)
+                    .contiguous()
+                )

         return super().load_hf_weights(weights)
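
When the checkpoint carries block-quantized weights, the same flag guards the matching reshuffle of the per-block scales, with kv_lora_rank replaced by kv_lora_rank // block_size since there is one scale per weight block. A sketch under an assumed 3-D scale layout (dummy sizes, inferred for illustration and not confirmed by the diff):

import torch

# Assumed scale layout: the weight grid at 1/block_size resolution,
# i.e. one scale value per block_size x block_size weight tile.
heads, kv_lora_rank, v_head_dim, block_size = 4, 512, 128, 128
rank_blocks = kv_lora_rank // block_size
dim_blocks = v_head_dim // block_size
v_b_scale = torch.randn(heads, rank_blocks, dim_blocks)

# Same transpose/reshape/transpose chain as the full-precision weight,
# applied on the coarser per-block grid.
cc_v_b_scale = (
    v_b_scale.transpose(0, 1)   # (rank_blocks, heads, dim_blocks)
    .reshape(rank_blocks, -1)   # (rank_blocks, heads * dim_blocks)
    .transpose(0, 1)            # (heads * dim_blocks, rank_blocks)
    .contiguous()
)
assert cc_v_b_scale.shape == (heads * dim_blocks, rank_blocks)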

@@ -253,21 +258,21 @@ def _init_qkvo(self):
             weight_scale_suffix=self.weight_scale_suffix,
             act_scale_suffix=self.act_scale_suffix,
         )
-
-        self.cc_k_b_proj_ = ROWMMWeight(
-            f"model.layers.{self.layer_num_}.self_attn.cc_k_b_proj.weight",
-            self.data_type_,
-            split_n_embed=self.tp_q_head_num_ * self.qk_nope_head_dim,
-            weight_scale_suffix=self.weight_scale_suffix,
-            act_scale_suffix=self.act_scale_suffix,
-        )
-        self.cc_v_b_proj_ = ROWMMWeight(
-            f"model.layers.{self.layer_num_}.self_attn.cc_v_b_proj.weight",
-            self.data_type_,
-            split_n_embed=self.tp_q_head_num_ * self.qk_nope_head_dim,
-            weight_scale_suffix=self.weight_scale_suffix,
-            act_scale_suffix=self.act_scale_suffix,
-        )
+        if self.enable_cc_method:
+            self.cc_k_b_proj_ = ROWMMWeight(
+                f"model.layers.{self.layer_num_}.self_attn.cc_k_b_proj.weight",
+                self.data_type_,
+                split_n_embed=self.tp_q_head_num_ * self.qk_nope_head_dim,
+                weight_scale_suffix=self.weight_scale_suffix,
+                act_scale_suffix=self.act_scale_suffix,
+            )
+            self.cc_v_b_proj_ = ROWMMWeight(
+                f"model.layers.{self.layer_num_}.self_attn.cc_v_b_proj.weight",
+                self.data_type_,
+                split_n_embed=self.tp_q_head_num_ * self.qk_nope_head_dim,
+                weight_scale_suffix=self.weight_scale_suffix,
+                act_scale_suffix=self.act_scale_suffix,
+            )

         self.o_weight_ = COLMMWeight(
             f"model.layers.{self.layer_num_}.self_attn.o_proj.weight",
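
Registering the cc_k_b_proj_/cc_v_b_proj_ ROWMMWeight objects only behind the flag keeps them (and the memory behind them) out of the default path; together with the gated load_hf_weights branches above, the cc_* tensors simply never exist unless ENABLE_CC_METHOD is set. A toy sketch of that registration pattern; Weight is a hypothetical stand-in, not lightllm's ROWMMWeight signature:

# Toy stand-in; lightllm's ROWMMWeight takes dtype, split sizes, and
# quantization suffixes in addition to the checkpoint key.
class Weight:
    def __init__(self, name: str):
        self.name = name

class Layer:
    def __init__(self, enable_cc_method: bool):
        self.k_b_proj_ = Weight("k_b_proj")  # always registered
        if enable_cc_method:  # cc_* weights exist only when the flag is on
            self.cc_k_b_proj_ = Weight("cc_k_b_proj")
            self.cc_v_b_proj_ = Weight("cc_v_b_proj")

assert not hasattr(Layer(False), "cc_k_b_proj_")
assert hasattr(Layer(True), "cc_k_b_proj_")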
