
Commit 79d1d84

Authored by satreysa (Sumedha Atreysa) and kunal-vaishnavi
Adding q_norm, k_norm support for quantized models (microsoft#1483)
This PR adds support for q_norm and k_norm layers in quantized models within the OGA framework. Specifically, it introduces the following enhancements to **quantized_model.py**:

- Initializes `q_norm` and `k_norm` as `TensorModule` instances within the `QuantizedAttention` and `QuantizedDecoder` classes.
- Maps the corresponding weights and biases for `q_norm` and `k_norm` to the initialized tensor modules during model loading.

This enables accurate handling of models that include `q_norm` and `k_norm` as part of their quantized attention mechanisms, improving compatibility with newer quantized LLMs.

**Changes Made:**
- Added initialization of `q_norm` and `k_norm` as `TensorModule` instances in:
  - `QuantizedAttention` class
  - `QuantizedDecoder` class
- Mapped the corresponding weights and biases from the model to these tensor modules during model loading
- Ensured consistency with the existing quantized attention initialization flow

**Reviewer Notes:**
- Please verify:
  - The initialization logic aligns with the handling of other norm layers (e.g., `qkv_norm`)
  - No side effects are introduced for models that do not contain `q_norm` or `k_norm`
- Tested locally with a quantized model (Qwen3) containing `q_norm`/`k_norm`, but additional validation with other architectures is welcome

---------

Co-authored-by: Sumedha Atreysa <[email protected]>
Co-authored-by: kunal-vaishnavi <[email protected]>
1 parent 7192deb commit 79d1d84
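
For context, here is a minimal, self-contained sketch of the name-to-module mapping pattern this change extends. The `TensorModule` container and the checkpoint tensor names mirror quantized_model.py, but the `SelfAttention`/`DecoderLayer` classes and the `map_qk_norm` helper are illustrative stand-ins, not the actual loader:

```python
import re

class TensorModule:
    """Plain (unquantized) weight/bias holder, as used for norm layers."""
    def __init__(self):
        self.weight = None
        self.bias = None

class SelfAttention:
    def __init__(self):
        # q_norm/k_norm are plain TensorModules even in a quantized model;
        # only the projections (q_proj, k_proj, ...) carry quantized weights.
        self.q_norm = TensorModule()
        self.k_norm = TensorModule()

class DecoderLayer:
    def __init__(self):
        self.self_attn = SelfAttention()

def map_qk_norm(layer, name, tensor):
    """Hypothetical helper: route one checkpoint entry, e.g.
    'model.layers.0.self_attn.q_norm.weight', to its module slot,
    following the same regex style as quantized_model.py."""
    m = re.match(r"^model\.layers\.\d+\.self_attn\.(q_norm|k_norm)\.(weight|bias)$", name)
    if m is None:
        return False  # not a q_norm/k_norm tensor; other rules handle it
    norm_name, attr = m.groups()
    setattr(getattr(layer.self_attn, norm_name), attr, tensor)
    return True

layer = DecoderLayer()
map_qk_norm(layer, "model.layers.0.self_attn.q_norm.weight", [1.0, 1.0, 1.0])
assert layer.self_attn.q_norm.weight == [1.0, 1.0, 1.0]
```

Representing `q_norm`/`k_norm` as `TensorModule` rather than `QuantizedTensorModule` is consistent with the diff below, where norm layers keep plain weights while the projection layers hold quantized tensors (norm parameters are typically left unquantized).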

File tree

1 file changed: +15 −1 lines

src/python/py/models/quantized_model.py

Lines changed: 15 additions & 1 deletion
```diff
@@ -63,7 +63,8 @@ def __init__(self, bits, group_size):
         self.v_proj = QuantizedTensorModule(bits, group_size)
         self.o_proj = QuantizedTensorModule(bits, group_size)
         self.rotary_emb = TensorModule()
-
+        self.k_norm = TensorModule()
+        self.q_norm = TensorModule()
 
 class QuantizedMLP:
     def __init__(self, bits, group_size):
@@ -149,6 +150,7 @@ def __init__(self, quant_type, input_path, quant_attrs, q_size, kv_size, interme
 
             # Map weights and biases of norm, attention, and feed-forward network
             # Graph order is input_layernorm --> q_proj/k_proj/v_proj --> o_proj --> post_attention_layernorm --> gate_proj/up_proj --> down_proj
+            # If model uses q_norm and k_norm, graph order is input_layernorm --> q_norm/q_proj/k_norm/k_proj/v_proj --> o_proj --> post_attention_layernorm --> gate_proj/up_proj --> down_proj
             if bool(re.match(r"^model.layers\.\d+\.input_layernorm\.weight$", name)):
                 # model.layers.layer_id.input_layernorm.weight
                 module.input_layernorm.weight = tensor
@@ -177,6 +179,12 @@ def __init__(self, quant_type, input_path, quant_attrs, q_size, kv_size, interme
             elif bool(re.match(r"^model.layers\.\d+\.self_attn.q_proj\.bias$", name)):
                 # model.layers.layer_id.self_attn.q_proj.bias
                 module.self_attn.q_proj.bias = tensor
+            elif bool(re.match(r"^model\.layers\.\d+\.self_attn\.q_norm\.weight$", name)):
+                # model.layers.layer_id.self_attn.q_norm.weight
+                module.self_attn.q_norm.weight = tensor
+            elif bool(re.match(r"^model\.layers\.\d+\.self_attn\.q_norm\.bias$", name)):
+                # model.layers.layer_id.self_attn.q_norm.bias
+                module.self_attn.q_norm.bias = tensor
             elif bool(re.match(r"^model.layers\.\d+\.self_attn.k_proj\.q?weight$", name)):
                 # model.layers.layer_id.self_attn.k_proj.qweight
                 # model.layers.layer_id.self_attn.k_proj.weight
@@ -195,6 +203,12 @@ def __init__(self, quant_type, input_path, quant_attrs, q_size, kv_size, interme
             elif bool(re.match(r"^model.layers\.\d+\.self_attn.k_proj\.bias$", name)):
                 # model.layers.layer_id.self_attn.k_proj.bias
                 module.self_attn.k_proj.bias = tensor
+            elif bool(re.match(r"^model\.layers\.\d+\.self_attn\.k_norm\.weight$", name)):
+                # model.layers.layer_id.self_attn.k_norm.weight
+                module.self_attn.k_norm.weight = tensor
+            elif bool(re.match(r"^model\.layers\.\d+\.self_attn\.k_norm\.bias$", name)):
+                # model.layers.layer_id.self_attn.k_norm.bias
+                module.self_attn.k_norm.bias = tensor
             elif bool(re.match(r"^model.layers\.\d+\.self_attn.v_proj\.q?weight$", name)):
                 # model.layers.layer_id.self_attn.v_proj.qweight
                 # model.layers.layer_id.self_attn.v_proj.weight
```
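
The diff above only wires the new tensors through the loader. For intuition about what they do downstream: in architectures such as Qwen3, `q_norm`/`k_norm` are RMSNorm layers applied per attention head to the query/key states before rotary embeddings. A minimal NumPy sketch of that application follows; the shapes, weights, and `rms_norm` helper are illustrative assumptions, not code from this repository:

```python
import numpy as np

def rms_norm(x, weight, eps=1e-6):
    """RMSNorm over the last axis: x / rms(x) * weight."""
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return x / rms * weight

# Illustrative shapes: (batch, seq, num_heads, head_dim).
batch, seq, num_heads, head_dim = 1, 4, 2, 8
q = np.random.randn(batch, seq, num_heads, head_dim).astype(np.float32)
k = np.random.randn(batch, seq, num_heads, head_dim).astype(np.float32)

# q_norm/k_norm weights have shape (head_dim,) and broadcast per head,
# applied after the q/k projections and before rotary embeddings.
q_norm_weight = np.ones(head_dim, dtype=np.float32)
k_norm_weight = np.ones(head_dim, dtype=np.float32)

q = rms_norm(q, q_norm_weight)
k = rms_norm(k, k_norm_weight)
```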
