[OMNIML-2932] Fusing pre_quant_scale for NVFP4 AWQ #421
@@ -478,7 +478,7 @@ def _get_quantization_from_layer(layer, quantizer_attr_names: QuantizerAttrNames
     if input_quantizer is not None and hasattr(input_quantizer, "_pre_quant_scale"):
         return QUANTIZATION_NVFP4_AWQ
-    if getattr(layer, "fused_with_layernorm", False):
+    if getattr(layer, "fused_with_prequant", False):
         return QUANTIZATION_NVFP4_AWQ
     assert input_quantizer is not None, (
         f"input_quantizer is None for {quantizer_attr_names}"

@@ -923,18 +923,149 @@ def all_items_same(item_list):
    return all(x == item_list[0] for x in item_list)


# Format: (list of target modules, tuple of (linear_to_fuse_into, linear_from_with_scale))
PQS_FUSE_MODULE_MAPPING = [
    # Attention: fuse o_proj's pre_quant_scale into v_proj's output dimension.
    # Mathematical equivalence:
    #   Before: o_proj_out = [attn @ (v_proj_in @ v_proj.W^T)^T * scale] @ o_proj.W^T
    #   After:  o_proj_out = [attn @ (v_proj_in @ (v_proj.W * scale)^T)^T] @ o_proj.W^T
    (["LlamaAttention", "Qwen3Attention", "Qwen3MoeAttention"], ("v_proj", "o_proj")),
    # MLP: fuse down_proj's pre_quant_scale into up_proj's output dimension.
    # Mathematical equivalence:
    #   Before: down_proj_out = {[act_fn(gate_proj(x)) * up_proj(x)] * scale} @ down_proj.W^T
    #   After:  down_proj_out = {act_fn(gate_proj(x)) * [up_proj(x) * scale]} @ down_proj.W^T
    (["LlamaMLP", "Qwen3MLP", "Qwen3MoeMLP"], ("up_proj", "down_proj")),
]


def fuse_prequant_to_linear(model: torch.nn.Module, fuse_grouped_heads=False):
    """Fuse pre_quant_scale into the preceding linear weights where possible.

    For example, we can fuse the pre_quant_scale of o_proj into the output dimension of v_proj, so that
    the results are mathematically equivalent to the following::

        out_proj.input = (attn_weights @ v_proj.output)
        out_proj.output = (out_proj.input * pre_quant_scale) * out_proj.weight
                        = attn_weights @ (v_proj.output * pre_quant_scale) * out_proj.weight

    For GQA/MQA models, where the v_proj output dimension is smaller than the o_proj input dimension,
    the pre_quant_scale is averaged across the repeated head groups and o_proj's pre_quant_scale is
    updated accordingly to maintain mathematical equivalence.

    Args:
        model: The model to fuse pre_quant_scale into.
        fuse_grouped_heads: If True, fuse the pre_quant_scale even when its dimension does not match
            the output dimension of the linear it is fused into. This is useful for GQA/MQA models
            but may lead to an accuracy drop.
    """
    # Fuse pre_quant_scale into the linear weights
    for _, module in model.named_modules():
        for module_map in PQS_FUSE_MODULE_MAPPING:
            target_module_list = module_map[0]
            linear_pair = module_map[1]
            if any(module_name in type(module).__name__ for module_name in target_module_list):
                linear_fuse_into = module.get_submodule(linear_pair[0])
                linear_pqs_from = module.get_submodule(linear_pair[1])
                if hasattr(linear_pqs_from, "input_quantizer") and hasattr(
                    linear_pqs_from.input_quantizer, "_pre_quant_scale"
                ):
                    pre_quant_scale = linear_pqs_from.input_quantizer._pre_quant_scale

                    # For GQA/MQA models, average the pre_quant_scale over the shared head groups
                    if pre_quant_scale.numel() != linear_fuse_into.weight.shape[-2]:
                        if (
                            not fuse_grouped_heads
                            or "attention" not in type(module).__name__.lower()
                        ):
                            warn(
                                f"Skipping pattern fuse prequant for {type(module).__name__}: "
                                f"pqs dim {pre_quant_scale.numel()} != out_ch dim {linear_fuse_into.weight.shape[-2]}"
                            )
                            continue
                        config = module.config
                        num_kv_heads = config.num_key_value_heads
                        kv_head_dim = linear_fuse_into.weight.shape[0] // num_kv_heads
                        n_rep = pre_quant_scale.numel() // num_kv_heads // kv_head_dim

                        # Reshape to (num_kv_heads, n_rep, kv_head_dim) and average over the repeated heads
                        averaged_scale = pre_quant_scale.view(
                            num_kv_heads, n_rep, kv_head_dim
                        ).mean(dim=1)

                        # To update o_proj, repeat the averaged scale back to the original shape
                        repeated_scale = (
                            averaged_scale.unsqueeze(1)
                            .expand(num_kv_heads, n_rep, kv_head_dim)
                            .reshape(-1)
                        )

                        def _update_pre_quant_scale(module, new_pre_quant_scale):
Review comment: can we merge duplicated code with line 1090?
                            old_pre_quant_scale = module.input_quantizer._pre_quant_scale
                            module.weight = nn.Parameter(
                                module.weight
                                * old_pre_quant_scale.to(
Review comment: do we want to cast to fp32 for this manipulation?
                                    dtype=module.weight.dtype, device=module.weight.device
                                )
                                / new_pre_quant_scale.to(
                                    dtype=module.weight.dtype, device=module.weight.device
                                )
                            )
                            module.input_quantizer.pre_quant_scale = new_pre_quant_scale

                            # Redo weight quantizer statistics collection after rescaling
                            module.weight_quantizer.reset_amax()
                            enable_stats_collection(module.weight_quantizer)
                            module.weight_quantizer(module.weight)
                            finish_stats_collection(module.weight_quantizer)

                        # Update o_proj's pre_quant_scale
                        _update_pre_quant_scale(linear_pqs_from, repeated_scale)

                        # Use the averaged scale (flattened) for v_proj fusion
                        pre_quant_scale = averaged_scale.reshape(-1)

                    # Fuse the pre_quant_scale into the weight
                    linear_fuse_into.weight = torch.nn.Parameter(
                        linear_fuse_into.weight * pre_quant_scale.view(-1, 1)
                    )
                    if hasattr(linear_fuse_into, "bias") and linear_fuse_into.bias is not None:
                        linear_fuse_into.bias = torch.nn.Parameter(
                            linear_fuse_into.bias * pre_quant_scale
                        )

                    delattr(linear_pqs_from.input_quantizer, "_pre_quant_scale")
                    setattr(linear_pqs_from, "fused_with_prequant", True)
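Editor's note: typical usage mirrors the unit tests added in this PR — quantize with an AWQ config first, then fuse before export. A minimal sketch (assumes a CUDA device and the `transformers`/`modelopt` packages; the tiny config is just an example):

```python
import torch
from transformers import LlamaConfig, LlamaForCausalLM

import modelopt.torch.quantization as mtq
from modelopt.torch.export.quant_utils import fuse_prequant_to_linear

# Tiny GQA Llama (4 attention heads sharing 2 KV heads), mirroring the tests below
config = LlamaConfig(
    hidden_size=64, intermediate_size=128, num_hidden_layers=2,
    num_attention_heads=4, num_key_value_heads=2,
    max_position_embeddings=128, vocab_size=256,
)
model = LlamaForCausalLM(config).to("cuda")

# AWQ calibration attaches per-channel pre_quant_scales to the input quantizers
dummy_input = torch.randint(0, 256, (1, 16), device="cuda")
mtq.quantize(model, mtq.NVFP4_AWQ_LITE_CFG, lambda m: m(dummy_input))

# Fold o_proj/down_proj pre_quant_scales into v_proj/up_proj weights;
# fuse_grouped_heads=True also averages scales across repeated KV head groups (may cost accuracy)
fuse_prequant_to_linear(model, fuse_grouped_heads=True)
```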

def fuse_prequant_layernorm(
    layernorm_module: torch.nn.Module,
    modules: list[torch.Tensor],
):
    """Scales layernorm weights with avg_pre_quant_scale of the modules list and deletes their pre_quant_scales.

    original:
        layernorm_output = (normalization(input) * weight) + bias
        layernorm_output_scaled = layernorm_output * pre_quant_scale

    fused:
        fused_weight = weight * avg_pre_quant_scale
        fused_bias = bias * avg_pre_quant_scale
        layernorm_output_scaled = (normalization(input) * fused_weight) + fused_bias
    """
    layernorm_module.weight = torch.nn.Parameter(
        layernorm_module.weight * getattr(modules[0].input_quantizer, "_pre_quant_scale")
    )
    if hasattr(layernorm_module, "bias"):
        layernorm_module.bias = torch.nn.Parameter(
            layernorm_module.bias * getattr(modules[0].input_quantizer, "_pre_quant_scale")
        )
    # Pre_quant_scales of modules must not be exported, since they have been fused with the layernorm
    for module in modules:
        delattr(module.input_quantizer, "_pre_quant_scale")
-        setattr(module, "fused_with_layernorm", True)
+        setattr(module, "fused_with_prequant", True)


def preprocess_linear_fusion(modules: list[torch.nn.Module], resmooth_only=False):
New test file:
@@ -0,0 +1,193 @@
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import pytest
import torch

pytest.importorskip("transformers")

from transformers import LlamaConfig, LlamaForCausalLM

import modelopt.torch.quantization as mtq
from modelopt.torch.export.quant_utils import fuse_prequant_to_linear


def get_tiny_llama(attention_heads=4, key_value_heads=4):
    """Create a tiny Llama model for testing."""
    config = LlamaConfig(
        hidden_size=64,
        intermediate_size=128,
        num_hidden_layers=2,
        num_attention_heads=attention_heads,
        num_key_value_heads=key_value_heads,
        max_position_embeddings=128,
        vocab_size=256,
    )
    return LlamaForCausalLM(config)


@pytest.mark.parametrize(
    "quant_config",
    [
        mtq.INT4_AWQ_CFG,
        mtq.NVFP4_AWQ_LITE_CFG,
    ],
)
@pytest.mark.parametrize(
    "attention_kv_heads_pair",
    [
        (4, 4),  # MHA
        (4, 2),  # GQA
        (4, 1),  # MQA
    ],
)
def test_pattern_fuse_prequant(quant_config, attention_kv_heads_pair):
    """Test pattern_fuse_prequant on modules from a tiny Llama model."""
    model = get_tiny_llama(attention_kv_heads_pair[0], attention_kv_heads_pair[1]).to("cuda")

    # Quantize the model
    dummy_input = torch.randint(0, 256, (1, 16), device="cuda")
    mtq.quantize(model, quant_config, lambda m: m(dummy_input))

    # Run forward pass before fusion
    model.eval()
    with torch.no_grad():
        output_before_fuse = model(dummy_input)

    target_module_name_list = [
        "model.layers.0.self_attn.o_proj",
        "model.layers.0.mlp.down_proj",
        "model.layers.1.self_attn.o_proj",
        "model.layers.1.mlp.down_proj",
    ]

    # Apply fusion
    fuse_prequant_to_linear(model, fuse_grouped_heads=True)

    # Check that pre_quant_scale was removed and the fused_with_prequant flag was set
    for target_module_name in target_module_name_list:
        target_module = model.get_submodule(target_module_name)

        # Verify pre_quant_scale was removed
        assert not hasattr(target_module.input_quantizer, "_pre_quant_scale"), (
            f"{target_module_name}: pre_quant_scale should be removed after fusion"
        )

        # Verify fused_with_prequant flag was set
        assert (
            hasattr(target_module, "fused_with_prequant") and target_module.fused_with_prequant
        ), f"{target_module_name}: fused_with_prequant flag should be set"

    # Verify output is close to the original output
    with torch.no_grad():
        output_after_fuse = model(dummy_input)
    # There will be some small difference due to quantization errors after fusing pre_quant_scale into the weights
    assert torch.allclose(
        output_before_fuse.logits, output_after_fuse.logits, rtol=1e-1, atol=5e-1
    ), "Output should be close before and after fusion"


@pytest.mark.parametrize(
    "quant_config",
    [
        mtq.INT4_AWQ_CFG,
        mtq.NVFP4_AWQ_LITE_CFG,
    ],
)
def test_pattern_fuse_prequant_moe(quant_config):
    """Test pattern_fuse_prequant on Qwen3 MoE sparse MLP."""
    pytest.importorskip("transformers")
    from transformers import Qwen3MoeConfig, Qwen3MoeForCausalLM

    # Create a tiny Qwen3MoE model for testing
    config = Qwen3MoeConfig(
        hidden_size=128,
        intermediate_size=256,
        moe_intermediate_size=256,
        num_hidden_layers=2,
        num_attention_heads=4,
        num_key_value_heads=4,
        num_experts=4,
        num_experts_per_tok=2,
        max_position_embeddings=128,
        vocab_size=256,
        shared_expert_intermediate_size=256,
    )
    model = Qwen3MoeForCausalLM(config).to("cuda")

    # Quantize the model
    dummy_input = torch.randint(0, 256, (1, 16), device="cuda")
    mtq.quantize(model, quant_config, lambda m: m(dummy_input))

    # Collect MoE expert modules to verify (down_proj should be fused)
    moe_down_proj_modules = []
    moe_gate_proj_modules = []
    moe_up_proj_modules = []
    for name, module in model.named_modules():
        if "mlp" in name and "experts" in name:
            if "gate_proj" in name and not any(x in name for x in ["weight", "quantizer"]):
                moe_gate_proj_modules.append((name, module))
            elif "down_proj" in name and not any(x in name for x in ["weight", "quantizer"]):
                moe_down_proj_modules.append((name, module))
            elif "up_proj" in name and not any(x in name for x in ["weight", "quantizer"]):
                moe_up_proj_modules.append((name, module))

    # Verify experts have pre_quant_scale before fusion
    for name, module in moe_gate_proj_modules:
        if hasattr(module, "input_quantizer"):
            assert hasattr(module.input_quantizer, "_pre_quant_scale"), (
                f"{name}: gate_proj should have pre_quant_scale before fusion"
            )

    for name, module in moe_up_proj_modules:
        if hasattr(module, "input_quantizer"):
            assert hasattr(module.input_quantizer, "_pre_quant_scale"), (
                f"{name}: up_proj should have pre_quant_scale before fusion"
            )

    for name, module in moe_down_proj_modules:
        if hasattr(module, "input_quantizer"):
            assert hasattr(module.input_quantizer, "_pre_quant_scale"), (
                f"{name}: down_proj should have pre_quant_scale before fusion"
            )

    # Run forward pass before fusion
    model.eval()
    with torch.no_grad():
        output_before_fuse = model(dummy_input)

    # Apply fusion (fuse_grouped_heads is only needed for GQA/MQA attention, not for MLP)
    fuse_prequant_to_linear(model)

    # Check that down_proj's pre_quant_scale was removed and fused into up_proj
    for name, module in moe_down_proj_modules:
        if hasattr(module, "input_quantizer"):
            # Verify pre_quant_scale was removed from down_proj
            assert not hasattr(module.input_quantizer, "_pre_quant_scale"), (
                f"{name}: down_proj pre_quant_scale should be removed after fusion"
            )
            # Verify fused_with_prequant flag was set
            assert hasattr(module, "fused_with_prequant") and module.fused_with_prequant, (
                f"{name}: down_proj should have fused_with_prequant flag set"
            )

    # Verify output is close to the original output
    with torch.no_grad():
        output_after_fuse = model(dummy_input)

    # There will be some difference due to quantization errors after pre_quant_scale fusion
    assert torch.allclose(
        output_before_fuse.logits, output_after_fuse.logits, rtol=1e-1, atol=5e-1
    ), "Output should be similar before and after Qwen3 MoE fusion"
Review comment: what's n_rep here?
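Editor's note (a sketch, not an authoritative reply): in the fusion code above, `n_rep` is recovered as `pre_quant_scale.numel() // num_kv_heads // kv_head_dim`, i.e. the number of query heads that share each KV head (`num_attention_heads // num_key_value_heads` for GQA/MQA). The reshape assumes the o_proj input channels are laid out KV-head-major; with toy numbers the averaging round trip looks like this:

```python
import torch

num_kv_heads, n_rep, kv_head_dim = 2, 2, 4   # e.g. 4 query heads sharing 2 KV heads
# o_proj's pre_quant_scale: one value per o_proj input channel (= per query-head channel)
pre_quant_scale = torch.arange(num_kv_heads * n_rep * kv_head_dim, dtype=torch.float32)

# Average over the n_rep query heads that share each KV head -> matches v_proj's output channels
averaged_scale = pre_quant_scale.view(num_kv_heads, n_rep, kv_head_dim).mean(dim=1)

# Repeat back so the updated o_proj pre_quant_scale keeps its original length
repeated_scale = (
    averaged_scale.unsqueeze(1).expand(num_kv_heads, n_rep, kv_head_dim).reshape(-1)
)

print(averaged_scale.shape)   # torch.Size([2, 4])  -> num_kv_heads * kv_head_dim = 8 values
print(repeated_scale.shape)   # torch.Size([16])    -> back to o_proj's input width
```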