diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py
index 885a12582..a5892fdf0 100755
--- a/modelopt/torch/export/quant_utils.py
+++ b/modelopt/torch/export/quant_utils.py
@@ -478,7 +478,7 @@ def _get_quantization_from_layer(layer, quantizer_attr_names: QuantizerAttrNames
     if input_quantizer is not None and hasattr(input_quantizer, "_pre_quant_scale"):
         return QUANTIZATION_NVFP4_AWQ
-    if getattr(layer, "fused_with_layernorm", False):
+    if getattr(layer, "fused_with_prequant", False):
         return QUANTIZATION_NVFP4_AWQ
     assert input_quantizer is not None, (
         f"input_quantizer is None for {quantizer_attr_names}"
@@ -923,18 +923,149 @@ def all_items_same(item_list):
     return all(x == item_list[0] for x in item_list)
 
 
+# Format: (list of target modules, tuple of (linear_to_fuse_into, linear_from_with_scale))
+PQS_FUSE_MODULE_MAPPING = [
+    # Attention: fuse o_proj's pre_quant_scale into v_proj's output dimension
+    # Mathematical equivalence:
+    # Before: o_proj_out = [attn @ (v_proj_in @ v_proj.W^T)^T * scale] @ o_proj.W^T
+    # After:  o_proj_out = [attn @ (v_proj_in @ (v_proj.W * scale)^T)^T] @ o_proj.W^T
+    (["LlamaAttention", "Qwen3Attention", "Qwen3MoeAttention"], ("v_proj", "o_proj")),
+    # MLP: fuse down_proj's pre_quant_scale into up_proj's output dimension
+    # Mathematical equivalence:
+    # Before: down_proj_out = {[act_fn(self.gate_proj(x)) * up_proj(x)] * scale} @ down_proj.W^T
+    # After:  down_proj_out = {act_fn(self.gate_proj(x)) * [up_proj(x) * scale]} @ down_proj.W^T
+    (["LlamaMLP", "Qwen3MLP", "Qwen3MoeMLP"], ("up_proj", "down_proj")),
+]
+
+
+def fuse_prequant_to_linear(model: torch.nn.Module, fuse_grouped_heads=False):
+    """Fuse pre_quant_scale into the preceding linear weights where possible.
+
+    For example, we can fuse the pre_quant_scale of o_proj into the output dimension of v_proj, such that
+    the results are mathematically equivalent to the following::
+
+        out_proj.input = (attn_weights @ v_proj.output)
+        out_proj.output = (out_proj.input * pre_quant_scale) * out_proj.weight
+                        = [attn_weights @ (v_proj.output * pre_quant_scale)] * out_proj.weight
+
+    For GQA/MQA models, where the v_proj output dimension is smaller than the o_proj input dimension,
+    the pre_quant_scale is first averaged across the repeated head groups, and o_proj's pre_quant_scale
+    and weight are rescaled accordingly to maintain mathematical equivalence.
+
+    Args:
+        model: The model to fuse pre_quant_scale into.
+        fuse_grouped_heads: If True, fuse the pre_quant_scale even when its dimension does not match
+            the output dimension of the linear it is fused into. This is required for GQA/MQA models
+            but may lead to an accuracy drop, since the scale is averaged across the repeated head groups.
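+
+    Example (illustrative sketch; assumes an AWQ-calibrated Hugging Face model)::
+
+        import modelopt.torch.quantization as mtq
+
+        # `model` and `forward_loop` are placeholders for the user's model and calibration loop
+        mtq.quantize(model, mtq.NVFP4_AWQ_LITE_CFG, forward_loop)
+        fuse_prequant_to_linear(model, fuse_grouped_heads=True)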
+    """
+    # Fuse pre_quant_scale to the linear weights
+    for _, module in model.named_modules():
+        for module_map in PQS_FUSE_MODULE_MAPPING:
+            target_module_list = module_map[0]
+            linear_pair = module_map[1]
+            if any(module_name in type(module).__name__ for module_name in target_module_list):
+                linear_fuse_into = module.get_submodule(linear_pair[0])
+                linear_pqs_from = module.get_submodule(linear_pair[1])
+                if hasattr(linear_pqs_from, "input_quantizer") and hasattr(
+                    linear_pqs_from.input_quantizer, "_pre_quant_scale"
+                ):
+                    pre_quant_scale = linear_pqs_from.input_quantizer._pre_quant_scale
+
+                    # For GQA/MQA models, average the pre_quant_scale over the repeated head groups
+                    if pre_quant_scale.numel() != linear_fuse_into.weight.shape[-2]:
+                        if (
+                            not fuse_grouped_heads
+                            or "attention" not in type(module).__name__.lower()
+                        ):
+                            warn(
+                                f"Skipping pattern fuse prequant for {type(module).__name__}: "
+                                f"pqs dim {pre_quant_scale.numel()} != out_ch dim {linear_fuse_into.weight.shape[-2]}"
+                            )
+                            continue
+                        config = module.config
+                        num_kv_heads = config.num_key_value_heads
+                        kv_head_dim = linear_fuse_into.weight.shape[0] // num_kv_heads
+                        n_rep = pre_quant_scale.numel() // num_kv_heads // kv_head_dim
+
+                        # Reshape to (num_kv_heads, n_rep, kv_head_dim) and average over the n_rep repeats
+                        averaged_scale = pre_quant_scale.view(
+                            num_kv_heads, n_rep, kv_head_dim
+                        ).mean(dim=1)
+
+                        # To update o_proj, repeat the averaged scale back to the original shape
+                        repeated_scale = (
+                            averaged_scale.unsqueeze(1)
+                            .expand(num_kv_heads, n_rep, kv_head_dim)
+                            .reshape(-1)
+                        )
+
+                        def _update_pre_quant_scale(module, new_pre_quant_scale):
+                            old_pre_quant_scale = module.input_quantizer._pre_quant_scale
+                            module.weight = nn.Parameter(
+                                module.weight
+                                * old_pre_quant_scale.to(
+                                    dtype=module.weight.dtype, device=module.weight.device
+                                )
+                                / new_pre_quant_scale.to(
+                                    dtype=module.weight.dtype, device=module.weight.device
+                                )
+                            )
+                            module.input_quantizer.pre_quant_scale = new_pre_quant_scale
+
+                            # Redo weight amax calibration
+                            module.weight_quantizer.reset_amax()
+                            enable_stats_collection(module.weight_quantizer)
+                            module.weight_quantizer(module.weight)
+                            finish_stats_collection(module.weight_quantizer)
+
+                        # Update o_proj's pre_quant_scale
+                        _update_pre_quant_scale(linear_pqs_from, repeated_scale)
+
+                        # Use averaged scale (flattened) for v_proj fusion
+                        pre_quant_scale = averaged_scale.reshape(-1)
+
+                    # Fuse the pre_quant_scale into the weight
+                    linear_fuse_into.weight = torch.nn.Parameter(
+                        linear_fuse_into.weight * pre_quant_scale.view(-1, 1)
+                    )
+                    if hasattr(linear_fuse_into, "bias") and linear_fuse_into.bias is not None:
+                        linear_fuse_into.bias = torch.nn.Parameter(
+                            linear_fuse_into.bias * pre_quant_scale
+                        )
+
+                    delattr(linear_pqs_from.input_quantizer, "_pre_quant_scale")
+                    setattr(linear_pqs_from, "fused_with_prequant", True)
+
+
 def fuse_prequant_layernorm(
     layernorm_module: torch.nn.Module,
     modules: list[torch.Tensor],
 ):
-    """Scales layernorm weights with avg_pre_quant_scale of the modules list and sets pre_quant_scales to be deleted."""
+    """Scales layernorm weights with avg_pre_quant_scale of the modules list and deletes the modules' pre_quant_scales.
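+
+    See ``fuse_prequant_to_linear`` for the analogous fusion of a pre_quant_scale into a preceding linear layer.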
+
+    original:
+    layernorm_output = (normalization(input) * weight) + bias
+    layernorm_output_scaled = layernorm_output * pre_quant_scale
+
+    fused:
+    fused_weight = weight * avg_pre_quant_scale
+    fused_bias = bias * avg_pre_quant_scale
+    layernorm_output_scaled = (normalization(input) * fused_weight) + fused_bias
+    """
     layernorm_module.weight = torch.nn.Parameter(
         layernorm_module.weight * getattr(modules[0].input_quantizer, "_pre_quant_scale")
     )
+    if getattr(layernorm_module, "bias", None) is not None:
+        layernorm_module.bias = torch.nn.Parameter(
+            layernorm_module.bias * getattr(modules[0].input_quantizer, "_pre_quant_scale")
+        )
     # Pre_quant_scales of modules must not be exported, since they have been fused with layernorm
     for module in modules:
         delattr(module.input_quantizer, "_pre_quant_scale")
-        setattr(module, "fused_with_layernorm", True)
+        setattr(module, "fused_with_prequant", True)
 
 
 def preprocess_linear_fusion(modules: list[torch.nn.Module], resmooth_only=False):
diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py
index f966ffac6..ff7e2dc72 100644
--- a/modelopt/torch/export/unified_export_hf.py
+++ b/modelopt/torch/export/unified_export_hf.py
@@ -57,6 +57,7 @@
 from .plugins import export_spec_ckpt_config, export_spec_ckpt_state_dict, spec_opt_only
 from .quant_utils import (
     fuse_prequant_layernorm,
+    fuse_prequant_to_linear,
     get_activation_scaling_factor,
     get_quant_config,
     get_quantization_format,
@@ -106,6 +107,10 @@ def _output_hook(module, input, output):
     fused_linears = {}
     module_names = set()
 
+    # Fuse pre_quant_scale into the linear weights if possible
+    if "NVFP4_AWQ" in quantization_format:
+        fuse_prequant_to_linear(model)
+
     for name, module in model.named_modules():
         module_names.add(name)
 
diff --git a/tests/gpu/torch/export/test_quant_utils.py b/tests/gpu/torch/export/test_quant_utils.py
new file mode 100644
index 000000000..16b4f524c
--- /dev/null
+++ b/tests/gpu/torch/export/test_quant_utils.py
@@ -0,0 +1,193 @@
+# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
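+
+"""Tests for pre_quant_scale fusion (fuse_prequant_to_linear) in modelopt.torch.export.quant_utils."""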
+
+import pytest
+import torch
+
+pytest.importorskip("transformers")
+
+from transformers import LlamaConfig, LlamaForCausalLM
+
+import modelopt.torch.quantization as mtq
+from modelopt.torch.export.quant_utils import fuse_prequant_to_linear
+
+
+def get_tiny_llama(attention_heads=4, key_value_heads=4):
+    """Create a tiny Llama model for testing."""
+    config = LlamaConfig(
+        hidden_size=64,
+        intermediate_size=128,
+        num_hidden_layers=2,
+        num_attention_heads=attention_heads,
+        num_key_value_heads=key_value_heads,
+        max_position_embeddings=128,
+        vocab_size=256,
+    )
+    return LlamaForCausalLM(config)
+
+
+@pytest.mark.parametrize(
+    "quant_config",
+    [
+        mtq.INT4_AWQ_CFG,
+        mtq.NVFP4_AWQ_LITE_CFG,
+    ],
+)
+@pytest.mark.parametrize(
+    "attention_kv_heads_pair",
+    [
+        (4, 4),  # MHA
+        (4, 2),  # GQA
+        (4, 1),  # MQA
+    ],
+)
+def test_pattern_fuse_prequant(quant_config, attention_kv_heads_pair):
+    """Test pattern_fuse_prequant on modules from a tiny Llama model."""
+    model = get_tiny_llama(attention_kv_heads_pair[0], attention_kv_heads_pair[1]).to("cuda")
+
+    # Quantize the model
+    dummy_input = torch.randint(0, 256, (1, 16), device="cuda")
+    mtq.quantize(model, quant_config, lambda m: m(dummy_input))
+
+    # Run forward pass before fusion
+    model.eval()
+    with torch.no_grad():
+        output_before_fuse = model(dummy_input)
+
+    target_module_name_list = [
+        "model.layers.0.self_attn.o_proj",
+        "model.layers.0.mlp.down_proj",
+        "model.layers.1.self_attn.o_proj",
+        "model.layers.1.mlp.down_proj",
+    ]
+
+    # Apply fusion
+    fuse_prequant_to_linear(model, fuse_grouped_heads=True)
+
+    # Check that pre_quant_scale was removed and the fused_with_prequant flag was set
+    for target_module_name in target_module_name_list:
+        target_module = model.get_submodule(target_module_name)
+
+        # Verify pre_quant_scale was removed
+        assert not hasattr(target_module.input_quantizer, "_pre_quant_scale"), (
+            f"{target_module_name}: pre_quant_scale should be removed after fusion"
+        )
+
+        # Verify fused_with_prequant flag was set
+        assert (
+            hasattr(target_module, "fused_with_prequant") and target_module.fused_with_prequant
+        ), f"{target_module_name}: fused_with_prequant flag should be set"
+
+    # Verify output is close to the original output
+    with torch.no_grad():
+        output_after_fuse = model(dummy_input)
+    # There will be some small difference due to quantization errors after fusing pre_quant_scale into the weights
+    assert torch.allclose(
+        output_before_fuse.logits, output_after_fuse.logits, rtol=1e-1, atol=5e-1
+    ), "Output should be close before and after fusion"
+
+
+@pytest.mark.parametrize(
+    "quant_config",
+    [
+        mtq.INT4_AWQ_CFG,
+        mtq.NVFP4_AWQ_LITE_CFG,
+    ],
+)
+def test_pattern_fuse_prequant_moe(quant_config):
+    """Test pattern_fuse_prequant on Qwen3 MoE sparse MLP."""
+    pytest.importorskip("transformers")
+    from transformers import Qwen3MoeConfig, Qwen3MoeForCausalLM
+
+    # Create a tiny Qwen3MoE model for testing
+    config = Qwen3MoeConfig(
+        hidden_size=128,
+        intermediate_size=256,
+        moe_intermediate_size=256,
+        num_hidden_layers=2,
+        num_attention_heads=4,
+        num_key_value_heads=4,
+        num_experts=4,
+        num_experts_per_tok=2,
+        max_position_embeddings=128,
+        vocab_size=256,
+        shared_expert_intermediate_size=256,
+    )
+    model = Qwen3MoeForCausalLM(config).to("cuda")
+
+    # Quantize the model
+    dummy_input = torch.randint(0, 256, (1, 16), device="cuda")
+    mtq.quantize(model, quant_config, lambda m: m(dummy_input))
+
+    # Collect MoE expert modules to verify (down_proj should be fused)
+    moe_down_proj_modules = []
+    moe_gate_proj_modules = []
+    moe_up_proj_modules = []
+    for name, module in model.named_modules():
+        if "mlp" in name and "experts" in name:
+            if "gate_proj" in name and not any(x in name for x in ["weight", "quantizer"]):
+                moe_gate_proj_modules.append((name, module))
+            elif "down_proj" in name and not any(x in name for x in ["weight", "quantizer"]):
+                moe_down_proj_modules.append((name, module))
+            elif "up_proj" in name and not any(x in name for x in ["weight", "quantizer"]):
+                moe_up_proj_modules.append((name, module))
+
+    # Verify experts have pre_quant_scale before fusion
+    for name, module in moe_gate_proj_modules:
+        if hasattr(module, "input_quantizer"):
+            assert hasattr(module.input_quantizer, "_pre_quant_scale"), (
+                f"{name}: gate_proj should have pre_quant_scale before fusion"
+            )
+
+    for name, module in moe_up_proj_modules:
+        if hasattr(module, "input_quantizer"):
+            assert hasattr(module.input_quantizer, "_pre_quant_scale"), (
+                f"{name}: up_proj should have pre_quant_scale before fusion"
+            )
+
+    for name, module in moe_down_proj_modules:
+        if hasattr(module, "input_quantizer"):
+            assert hasattr(module.input_quantizer, "_pre_quant_scale"), (
+                f"{name}: down_proj should have pre_quant_scale before fusion"
+            )
+
+    # Run forward pass before fusion
+    model.eval()
+    with torch.no_grad():
+        output_before_fuse = model(dummy_input)
+
+    # Apply fusion (fuse_grouped_heads is only needed for GQA/MQA attention, not for MLP)
+    fuse_prequant_to_linear(model)
+
+    # Check that down_proj's pre_quant_scale was removed and fused into up_proj
+    for name, module in moe_down_proj_modules:
+        if hasattr(module, "input_quantizer"):
+            # Verify pre_quant_scale was removed from down_proj
+            assert not hasattr(module.input_quantizer, "_pre_quant_scale"), (
+                f"{name}: down_proj pre_quant_scale should be removed after fusion"
+            )
+            # Verify fused_with_prequant flag was set
+            assert hasattr(module, "fused_with_prequant") and module.fused_with_prequant, (
+                f"{name}: down_proj should have fused_with_prequant flag set"
+            )
+
+    # Verify output is close to the original output
+    with torch.no_grad():
+        output_after_fuse = model(dummy_input)
+
+    # There will be some difference due to quantization errors after pre_quant_scale fusion
+    assert torch.allclose(
+        output_before_fuse.logits, output_after_fuse.logits, rtol=1e-1, atol=5e-1
+    ), "Output should be close before and after Qwen3 MoE fusion"