Commit e2f02af

fix blockwise fp8 for cc

1 parent 72e3c75 commit e2f02af

File tree

1 file changed (+6, -4)


lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py

Lines changed: 6 additions & 4 deletions
@@ -132,27 +132,27 @@ def _load_kb(self, kv_b_proj_):
         k_b_proj_ = kv_b_proj_.view(self.num_attention_heads, self.qk_nope_head_dim * 2, self.kv_lora_rank)[
             :, : self.qk_nope_head_dim, :
         ]
-        return k_b_proj_.contiguous().to(self.data_type_)
+        return k_b_proj_.contiguous().to(kv_b_proj_.dtype)

     def _load_kb_scale(self, kv_b_proj_, block_size):
         k_b_proj_scale_ = kv_b_proj_.view(
             self.num_attention_heads, self.qk_nope_head_dim * 2 // block_size, self.kv_lora_rank // block_size
         )[:, : self.qk_nope_head_dim // block_size, :]
-        return k_b_proj_scale_.contiguous().to(self.data_type_)
+        return k_b_proj_scale_.contiguous().to(kv_b_proj_.dtype)

     def _load_vb(self, kv_b_proj_):
         v_b_proj_ = kv_b_proj_.T.view(self.kv_lora_rank, self.num_attention_heads, self.qk_nope_head_dim * 2,)[
             :, :, self.qk_nope_head_dim :
         ].transpose(0, 1)
-        return v_b_proj_.contiguous().to(self.data_type_)
+        return v_b_proj_.contiguous().to(kv_b_proj_.dtype)

     def _load_vb_scale(self, kv_b_proj_scale_, block_size):
         v_b_proj_scale_ = kv_b_proj_scale_.T.view(
             self.kv_lora_rank // block_size,
             self.num_attention_heads,
             self.qk_nope_head_dim * 2 // block_size,
         )[:, :, self.qk_nope_head_dim // block_size :].transpose(0, 1)
-        return v_b_proj_scale_.contiguous().to(self.data_type_)
+        return v_b_proj_scale_.contiguous().to(kv_b_proj_scale_.dtype)

     def load_hf_weights(self, weights):
         if f"model.layers.{self.layer_num_}.self_attn.kv_b_proj.weight" in weights:
@@ -166,6 +166,8 @@ def load_hf_weights(self, weights):
             weights[f"model.layers.{self.layer_num_}.self_attn.cc_v_b_proj.weight"] = (
                 self._load_vb(kv_b_proj_).transpose(0, 1).reshape(self.kv_lora_rank, -1).transpose(0, 1).contiguous()
             )
+            # print( weights[f"model.layers.{self.layer_num_}.self_attn.cc_v_b_proj.weight"].dtype)
+            # print( weights[f"model.layers.{self.layer_num_}.self_attn.v_b_proj.weight"].dtype)

         if (
             self.quant_cfg.quantized_weight
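
The substantive change is in the four _load_* helpers: the returned tensors are now cast to the dtype of the tensor being sliced rather than to self.data_type_, so a blockwise-fp8 checkpoint keeps its fp8 weight payload and fp32 per-block scales instead of being silently converted to the model's activation dtype. Below is a minimal sketch of the difference, assuming a PyTorch build with torch.float8_e4m3fn support; the shapes and the model_dtype variable are illustrative, not taken from the repository:

    # Illustrative sketch only; not part of the commit.
    import torch

    model_dtype = torch.bfloat16  # stand-in for self.data_type_

    # A blockwise-fp8 checkpoint stores the weight in fp8 and per-block scales in fp32.
    kv_b_proj_ = torch.randn(256, 512).to(torch.float8_e4m3fn)  # hypothetical shape
    kv_b_proj_scale_ = torch.rand(2, 4)                         # hypothetical fp32 scales

    # Old behavior: cast everything to the model dtype. This upcasts the fp8
    # weight and downcasts the fp32 scales, so the blockwise-fp8 kernels no
    # longer receive the dtypes they expect.
    old_w = kv_b_proj_.contiguous().to(model_dtype)
    old_s = kv_b_proj_scale_.contiguous().to(model_dtype)

    # New behavior: cast to the tensor's own dtype, a no-op that preserves
    # the fp8 weight and the fp32 scales.
    new_w = kv_b_proj_.contiguous().to(kv_b_proj_.dtype)
    new_s = kv_b_proj_scale_.contiguous().to(kv_b_proj_scale_.dtype)

    print(old_w.dtype, old_s.dtype)  # torch.bfloat16 torch.bfloat16
    print(new_w.dtype, new_s.dtype)  # torch.float8_e4m3fn torch.float32

For unquantized checkpoints, kv_b_proj_ presumably already arrives in the model dtype, in which case the no-op cast leaves that path unchanged; the two commented-out print lines added in the second hunk look like leftover dtype debugging for exactly this distinction.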
