Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def _decompress_kv(self, kv, infer_state: Deepseek2InferStateInfo, layer_weight:
)

# CC
compressed_kv = compressed_kv.view(-1, layer_weight.kv_lora_rank)
compressed_kv = compressed_kv.view(-1, layer_weight.kv_lora_rank).contiguous()
k_nope = self.alloc_tensor(
[compressed_kv.shape[0], self.tp_q_head_num_, self.qk_nope_head_dim],
dtype=compressed_kv.dtype,
Expand All @@ -163,10 +163,8 @@ def _decompress_kv(self, kv, infer_state: Deepseek2InferStateInfo, layer_weight:
k_nope.shape,
dtype=compressed_kv.dtype,
)
wk = layer_weight.k_b_proj_.weight.view(-1, layer_weight.kv_lora_rank).T
wv = layer_weight.v_b_proj_.weight.transpose(0, 1).reshape(layer_weight.kv_lora_rank, -1)
torch.mm(compressed_kv, wk, out=k_nope.reshape(compressed_kv.shape[0], -1))
torch.mm(compressed_kv, wv, out=v.reshape(compressed_kv.shape[0], -1))
layer_weight.cc_k_b_proj_.mm(compressed_kv, out=k_nope.reshape(compressed_kv.shape[0], -1))
layer_weight.cc_v_b_proj_.mm(compressed_kv, out=v.reshape(compressed_kv.shape[0], -1))
return k_nope, k_rope, v

def _context_attention_kernel_with_CC(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def fuse_vb_o(self, layer_weight):
class Deepseek2TransformerLayerWeight(TransformerLayerWeight):
def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode=None, quant_cfg=None):
    """Per-layer weight container for DeepSeek-2 transformer layers.

    Reads two feature flags from the environment before delegating to the
    base class constructor:

    * ``ENABLE_DP``        — data-parallel mode ("ON"/"TRUE"/"1" enable it).
    * ``ENABLE_CC_METHOD`` — the CC (compressed-cache) attention path.

    :param mode: optional list of mode flags; defaults to an empty list.
        (Uses ``None`` as the default to avoid the shared-mutable-default
        pitfall — a literal ``[]`` default would be one list object reused
        across every instantiation.)
    """
    self.enable_dp = os.getenv("ENABLE_DP", "0").upper() in ["ON", "TRUE", "1"]
    self.enable_cc_method = os.getenv("ENABLE_CC_METHOD", "False").upper() in ["ON", "TRUE", "1"]
    # NOTE: the default "False" uppercases to "FALSE", which is not in the
    # truthy list, so the CC method is off unless explicitly enabled.
    mode = [] if mode is None else mode
    super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode, quant_cfg)
    return

Expand Down Expand Up @@ -132,34 +133,44 @@ def _load_kb(self, kv_b_proj_):
k_b_proj_ = kv_b_proj_.view(self.num_attention_heads, self.qk_nope_head_dim * 2, self.kv_lora_rank)[
:, : self.qk_nope_head_dim, :
]
return k_b_proj_.contiguous().to(self.data_type_)
return k_b_proj_.contiguous().to(kv_b_proj_.dtype)

def _load_kb_scale(self, kv_b_proj_, block_size):
    """Slice the k_b projection's quantization scales out of the fused kv_b scale tensor.

    The fused scale tensor covers both the k_b (nope) and v_b halves along
    the second axis; this keeps only the first ``qk_nope_head_dim // block_size``
    rows per head.

    :param kv_b_proj_: fused kv_b per-block weight-scale tensor.
    :param block_size: quantization block size the scales are laid out in.
    :returns: contiguous k_b scale tensor, kept in the scale tensor's own dtype.
    """
    # Defect fixed: the original contained two consecutive return statements
    # (a flattened diff artifact); the dead second one — casting to the
    # source tensor's dtype instead of self.data_type_ — is the intended,
    # merged behavior, so it is kept and the duplicate removed.
    k_b_proj_scale_ = kv_b_proj_.view(
        self.num_attention_heads, self.qk_nope_head_dim * 2 // block_size, self.kv_lora_rank // block_size
    )[:, : self.qk_nope_head_dim // block_size, :]
    # Preserve the scales' own precision rather than casting to the model
    # weight dtype — scales are typically stored in a wider float type.
    return k_b_proj_scale_.contiguous().to(kv_b_proj_.dtype)

def _load_vb(self, kv_b_proj_):
    """Extract the v_b projection from the fused kv_b weight tensor.

    Views the transposed fused weight as
    ``(kv_lora_rank, num_attention_heads, qk_nope_head_dim * 2)`` and keeps
    the second half of the last axis (the v_b columns), then swaps back to
    head-major layout.

    :param kv_b_proj_: fused kv_b projection weight.
    :returns: contiguous v_b weight, kept in the source tensor's dtype.
    """
    # Defect fixed: the flattened diff left two back-to-back returns here;
    # the dead second one (keeping kv_b_proj_.dtype rather than casting to
    # self.data_type_) is the merged final state and is the one retained.
    v_b_proj_ = kv_b_proj_.T.view(self.kv_lora_rank, self.num_attention_heads, self.qk_nope_head_dim * 2,)[
        :, :, self.qk_nope_head_dim :
    ].transpose(0, 1)
    return v_b_proj_.contiguous().to(kv_b_proj_.dtype)

def _load_vb_scale(self, kv_b_proj_scale_, block_size):
    """Slice the v_b projection's quantization scales out of the fused kv_b scale tensor.

    Mirrors ``_load_vb``: views the transposed fused scale tensor per block
    and keeps the second (v_b) half of the last axis, then swaps back to
    head-major layout.

    :param kv_b_proj_scale_: fused kv_b per-block weight-scale tensor.
    :param block_size: quantization block size the scales are laid out in.
    :returns: contiguous v_b scale tensor, kept in the scale tensor's own dtype.
    """
    # Defect fixed: duplicated return from the flattened diff; the intended,
    # merged behavior keeps the scale tensor's own dtype instead of casting
    # to self.data_type_, so the dead duplicate is removed.
    v_b_proj_scale_ = kv_b_proj_scale_.T.view(
        self.kv_lora_rank // block_size,
        self.num_attention_heads,
        self.qk_nope_head_dim * 2 // block_size,
    )[:, :, self.qk_nope_head_dim // block_size :].transpose(0, 1)
    return v_b_proj_scale_.contiguous().to(kv_b_proj_scale_.dtype)

def load_hf_weights(self, weights):
if f"model.layers.{self.layer_num_}.self_attn.kv_b_proj.weight" in weights:
kv_b_proj_ = weights[f"model.layers.{self.layer_num_}.self_attn.kv_b_proj.weight"]
weights[f"model.layers.{self.layer_num_}.self_attn.k_b_proj.weight"] = self._load_kb(kv_b_proj_)
weights[f"model.layers.{self.layer_num_}.self_attn.v_b_proj.weight"] = self._load_vb(kv_b_proj_)

if self.enable_cc_method:
weights[f"model.layers.{self.layer_num_}.self_attn.cc_k_b_proj.weight"] = self._load_kb(
kv_b_proj_
).reshape(-1, self.kv_lora_rank)
weights[f"model.layers.{self.layer_num_}.self_attn.cc_v_b_proj.weight"] = (
self._load_vb(kv_b_proj_)
.transpose(0, 1)
.reshape(self.kv_lora_rank, -1)
.transpose(0, 1)
.contiguous()
)
if (
self.quant_cfg.quantized_weight
and f"model.layers.{self.layer_num_}.self_attn.kv_b_proj." + self.weight_scale_suffix in weights
Expand All @@ -177,6 +188,17 @@ def load_hf_weights(self, weights):
weights[
f"model.layers.{self.layer_num_}.self_attn.v_b_proj." + self.weight_scale_suffix
] = self._load_vb_scale(kv_b_proj_scale_, block_size)
if self.enable_cc_method:
weights[
f"model.layers.{self.layer_num_}.self_attn.cc_k_b_proj." + self.weight_scale_suffix
] = self._load_kb_scale(kv_b_proj_scale_, block_size).reshape(-1, self.kv_lora_rank // block_size)
weights[f"model.layers.{self.layer_num_}.self_attn.cc_v_b_proj." + self.weight_scale_suffix] = (
self._load_vb_scale(kv_b_proj_scale_, block_size)
.transpose(0, 1)
.reshape(self.kv_lora_rank // block_size, -1)
.transpose(0, 1)
.contiguous()
)

return super().load_hf_weights(weights)

Expand Down Expand Up @@ -236,6 +258,21 @@ def _init_qkvo(self):
weight_scale_suffix=self.weight_scale_suffix,
act_scale_suffix=self.act_scale_suffix,
)
if self.enable_cc_method:
self.cc_k_b_proj_ = ROWMMWeight(
f"model.layers.{self.layer_num_}.self_attn.cc_k_b_proj.weight",
self.data_type_,
split_n_embed=self.tp_q_head_num_ * self.qk_nope_head_dim,
weight_scale_suffix=self.weight_scale_suffix,
act_scale_suffix=self.act_scale_suffix,
)
self.cc_v_b_proj_ = ROWMMWeight(
f"model.layers.{self.layer_num_}.self_attn.cc_v_b_proj.weight",
self.data_type_,
split_n_embed=self.tp_q_head_num_ * self.qk_nope_head_dim,
weight_scale_suffix=self.weight_scale_suffix,
act_scale_suffix=self.act_scale_suffix,
)

self.o_weight_ = COLMMWeight(
f"model.layers.{self.layer_num_}.self_attn.o_proj.weight",
Expand Down