
Commit b60ac4f

fp8 scale repeat for qwen3 (#879)
Co-authored-by: baishihao <[email protected]>
Co-authored-by: wangzaijun <[email protected]>
1 parent: c46810b

2 files changed, 18 insertions(+), 2 deletions(-)

lightllm/models/qwen3_moe/layer_weights/transformer_layer_weight.py

Lines changed: 8 additions & 2 deletions
@@ -60,15 +60,21 @@ def _repeat_weight(self, name, weights):
         if name in weights:
             weights[name] = (
                 weights[name]
-                .reshape(self.network_config_["num_key_value_heads"], self.head_dim, -1)
+                .reshape(self.network_config_["num_key_value_heads"], -1, weights[name].shape[1])
                 .unsqueeze(1)
                 .repeat(repeat_params)
-                .reshape(self.network_config_["num_key_value_heads"] * self.head_dim * repeat_size, -1)
+                .reshape(-1, weights[name].shape[1])
             )
 
     def load_hf_weights(self, weights):
         self._repeat_weight(self._k_weight_name, weights)
         self._repeat_weight(self._v_weight_name, weights)
+        kv_b_quant_method = self.quant_cfg.get_quant_method(self.layer_num_, "kv_b_proj")
+        if self.quant_cfg.quantized_weight:
+            _k_scale_weight_name = self._k_weight_name.replace("weight", kv_b_quant_method.weight_scale_suffix)
+            self._repeat_weight(_k_scale_weight_name, weights)
+            _v_scale_weight_name = self._v_weight_name.replace("weight", kv_b_quant_method.weight_scale_suffix)
+            self._repeat_weight(_v_scale_weight_name, weights)
         return super().load_hf_weights(weights)
 
     def _init_weight(self):
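
The reshape change in this hunk is what makes the commit work: fp8 block-wise weight scales have far fewer rows and columns than the weight itself (one scale per quantization block), so a reshape that hardcodes self.head_dim only fits the unquantized weight. Reshaping to (num_key_value_heads, -1, shape[1]) infers the per-head row count from the tensor, so the same _repeat_weight call now serves both the k/v weights and their scales. A minimal standalone sketch of the idea, with illustrative head counts, repeat factor, and an assumed 128x128 fp8 block size (not the lightllm code itself):

import torch

num_kv_heads = 4   # kv heads in the checkpoint (illustrative)
repeat_size = 2    # each kv head repeated this many times (illustrative)
head_dim = 128
hidden = 1024
block = 128        # assumed fp8 block-quantization granularity

def repeat_rows(t: torch.Tensor) -> torch.Tensor:
    # Infer rows-per-head with -1: head_dim rows for the weight,
    # head_dim // block rows for its block-wise scale tensor.
    return (
        t.reshape(num_kv_heads, -1, t.shape[1])
        .unsqueeze(1)
        .repeat(1, repeat_size, 1, 1)
        .reshape(-1, t.shape[1])
    )

k_weight = torch.randn(num_kv_heads * head_dim, hidden)
k_scale = torch.randn(num_kv_heads * head_dim // block, hidden // block)
print(repeat_rows(k_weight).shape)  # torch.Size([1024, 1024])
print(repeat_rows(k_scale).shape)   # torch.Size([8, 8])

The scale tensor names are derived the same way the diff does it: replacing "weight" in the k/v weight name with the quant method's weight_scale_suffix, so whatever suffix the deployed fp8 scheme stores its scales under is repeated alongside its weight.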

lightllm/models/qwen3_moe/model.py

Lines changed: 10 additions & 0 deletions
@@ -26,6 +26,16 @@ def _verify_params(self):
         assert self.config["num_attention_heads"] % self.tp_world_size_ == 0
         return
 
+    def _init_some_value(self):
+        # Dealing with head_dim_ != n_embed // num_attention_heads scenarios, such as mistral 13B
+        head_dim_ = self.config["n_embed"] // self.config["num_attention_heads"]
+        self.head_dim_ = self.config.get("head_dim", head_dim_)
+        self.tp_k_head_num_ = max(self.config["num_key_value_heads"] // self.tp_world_size_, 1)
+        self.tp_v_head_num_ = self.tp_k_head_num_
+        self.layers_num = self.config["n_layer"]
+        self.vocab_size = self.config["vocab_size"]
+        return
+
     def _init_mem_manager(self):
         head_dim_ = self.config["hidden_size"] // self.config["num_attention_heads"]
         head_dim_ = self.config.get("head_dim", head_dim_)
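
As the in-diff comment notes, this override exists for models whose config sets an explicit head_dim that differs from hidden size divided by attention-head count, and the max(..., 1) clamp keeps at least one kv head per rank when the tensor-parallel world size exceeds num_key_value_heads. A minimal sketch of the derivation with plain dicts and hypothetical Qwen3-like numbers (the real code reads a parsed config object, and lightllm's key names may differ):

def derive_attention_shapes(config: dict, tp_world_size: int) -> dict:
    # Fall back to hidden_size // num_attention_heads only when the
    # checkpoint does not declare head_dim explicitly.
    default_head_dim = config["hidden_size"] // config["num_attention_heads"]
    return {
        "head_dim": config.get("head_dim", default_head_dim),
        # Clamp to 1: with more TP ranks than kv heads, heads are replicated.
        "tp_k_head_num": max(config["num_key_value_heads"] // tp_world_size, 1),
    }

cfg = {"hidden_size": 4096, "num_attention_heads": 64,
       "num_key_value_heads": 8, "head_dim": 128}
print(derive_attention_shapes(cfg, tp_world_size=16))
# {'head_dim': 128, 'tp_k_head_num': 1}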
