From 9b10a73df34bd796918fd8382dc7098067281861 Mon Sep 17 00:00:00 2001
From: Shengliang Xu
Date: Tue, 4 Nov 2025 09:55:19 +0000
Subject: [PATCH] [OMNIML-2917] handle lm_head and other un-quantized modules
 correctly

This is change set 2 from working on OMNIML-2917.

Two correlated changes:

1. When we quantize only the language_model submodule, correctly disable
   quantization of all other modules; nothing needs to be hard-coded.

2. When we export a quantized model to the HF unified format, we used to
   hard-code the exclusion of "lm_head". With change set 1, where we use the
   full model for export config generation, lm_head is naturally excluded if
   it is not quantized. Therefore, remove the hard-coded lm_head entry from
   the exclusion list.

Signed-off-by: Shengliang Xu
---
 examples/llm_ptq/hf_ptq.py                 | 24 ++++++++-----
 modelopt/torch/export/model_utils.py       | 42 ++++++++--------------
 modelopt/torch/export/quant_utils.py       | 14 +++++---
 modelopt/torch/export/unified_export_hf.py | 11 ++----
 4 files changed, 41 insertions(+), 50 deletions(-)

diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py
index af80a0944..1e6912830 100755
--- a/examples/llm_ptq/hf_ptq.py
+++ b/examples/llm_ptq/hf_ptq.py
@@ -317,17 +317,23 @@ def main(args):
             tokenizer.padding_side = "left"
 
         # We only quantize the language model for VLMs other than the type supported above.
-        language_model, parent_model = get_language_model_from_vl(model)
-        if language_model is not None:
+        language_model_lineage = get_language_model_from_vl(full_model)
+        if language_model_lineage is not None:
+            language_model = language_model_lineage.pop(-1)
+            ancestors = language_model_lineage
+            # Apply disabled quant to all modules that are not part of language_model so we can exclude them during
+            # HF export.
             disabled_quant_cfg = {
                 "quant_cfg": {"default": {"enable": False}},
                 "algorithm": "max",
             }
 
-            for name, child in parent_model.named_children():
-                # Apply disabled quant to all children except language_model so we can exclude them during HF export.
-                if name != "language_model":
-                    mtq.quantize(child, disabled_quant_cfg, forward_loop=None)
+            memo = set(ancestors) | {language_model}
+            for ancestor in ancestors:
+                for _, module in ancestor.named_children():
+                    if module not in memo:
+                        mtq.quantize(module, disabled_quant_cfg, forward_loop=None)
+                        memo.add(module)
 
             model = language_model
             model_type = get_model_type(model)
@@ -492,10 +498,10 @@ def main(args):
 
     # For VL models, update full_model to use the quantized language model
    if is_nemotron_vl_model:
-        _, parent_model = get_language_model_from_vl(full_model)
-        if parent_model is not None:
+        language_model_lineage = get_language_model_from_vl(full_model)
+        if language_model_lineage is not None:
             print("Updating full_model with quantized language_model...")
-            parent_model.language_model = model
+            language_model_lineage[-2].language_model = model
 
     if args.verbose:
         mtq.print_quant_summary(full_model)
diff --git a/modelopt/torch/export/model_utils.py b/modelopt/torch/export/model_utils.py
index 44f3c185c..fb5d9f8be 100755
--- a/modelopt/torch/export/model_utils.py
+++ b/modelopt/torch/export/model_utils.py
@@ -14,6 +14,8 @@
 # limitations under the License.
 """Utility functions for model type detection and classification."""
 
+import torch.nn as nn
+
 MODEL_NAME_TO_TYPE = {
     "GPT2": "gpt",
     "Mllama": "mllama",
@@ -111,8 +113,8 @@ def is_multimodal_model(model):
     )
 
 
-def get_language_model_from_vl(model):
-    """Extract the language model component from a Vision-Language Model (VLM).
+def get_language_model_from_vl(model) -> list[nn.Module] | None:
+    """Extract the language model lineage from a Vision-Language Model (VLM).
 
     This function handles the common patterns for accessing the language model
     component in various VLM architectures. It checks multiple possible locations where the
@@ -122,36 +124,20 @@ def get_language_model_from_vl(model):
         model: The VLM model instance to extract the language model from
 
     Returns:
-        tuple: (language_model, parent_model) where:
-            - language_model: The extracted language model component, or None if not found
-            - parent_model: The parent model containing the language_model attribute
+        list: the module lineage from the top-level model down to the language model, or None if not found
 
     Examples:
         >>> # For LLaVA-style models
-        >>> lang_model, parent = get_language_model_from_vl(vlm_model)
-        >>> if lang_model is not None:
-        ...     # Work with the language model component
-        ...     quantized_lang_model = quantize(lang_model)
-        ...     # Update the parent model
-        ...     parent.language_model = quantized_lang_model
+        >>> lineage = get_language_model_from_vl(vlm_model)
+        >>> # lineage[0] is vlm_model
+        >>> # lineage[-1] is the language model component
     """
-    # Pattern 1: Direct language_model attribute (e.g., LLaVA, some Nemotron models)
+    # Always prioritize model.model.language_model
+    if hasattr(model, "model") and hasattr(model.model, "language_model"):
+        return [model, model.model, model.model.language_model]
+
     if hasattr(model, "language_model"):
-        # Check if it's a property that might need special handling
-        if isinstance(type(model).__dict__.get("language_model"), property):
-            # Some models have language_model as a property that points to model.model.language_model
-            if hasattr(model, "model") and hasattr(model.model, "language_model"):
-                return model.model.language_model, model.model
-            else:
-                # Property exists but no nested structure found
-                return model.language_model, model
-        else:
-            # Direct attribute access
-            return model.language_model, model
-
-    # Pattern 2: Nested in model.model.language_model (e.g., some Gemma3, Qwen2.5-VL models)
-    elif hasattr(model, "model") and hasattr(model.model, "language_model"):
-        return model.model.language_model, model.model
+        return [model, model.language_model]
 
     # Pattern 3: No language_model found
-    return None, None
+    return None
diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py
index f474ebea4..051860b95 100755
--- a/modelopt/torch/export/quant_utils.py
+++ b/modelopt/torch/export/quant_utils.py
@@ -17,6 +17,7 @@
 
 import logging
 from collections.abc import Generator
+from types import SimpleNamespace
 from typing import Any
 from warnings import warn
 
@@ -1086,16 +1087,19 @@ def get_quant_config(
             layer_config_dict[name + ".quantization"] = quantization_format
             layer_config_dict[name + ".awq_block_size"] = block_size
 
+        not_enabled = SimpleNamespace(is_enabled=False)
+
         # Find kv cache quant format
         if (
-            hasattr(module, "k_bmm_quantizer")
-            or hasattr(module, "v_bmm_quantizer")
-            or (hasattr(module, "output_quantizer") and module.output_quantizer.is_enabled)
+            getattr(module, "k_bmm_quantizer", not_enabled).is_enabled
+            or getattr(module, "v_bmm_quantizer", not_enabled).is_enabled
+            or getattr(module, "output_quantizer", not_enabled).is_enabled
         ):
+            module_kv_quant = get_kv_cache_dtype(module)
             if kv_cache_format == QUANTIZATION_NONE:
-                kv_cache_format = get_kv_cache_dtype(module)
+                kv_cache_format = module_kv_quant
             else:
-                assert kv_cache_format == get_kv_cache_dtype(module), (
+                assert kv_cache_format == module_kv_quant, (
                     "Do not support mixed precision kv cache quantization"
                 )
 
diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py
index 7b102f4e0..6283bd35c 100644
--- a/modelopt/torch/export/unified_export_hf.py
+++ b/modelopt/torch/export/unified_export_hf.py
@@ -155,11 +155,12 @@ def _output_hook(module, input, output):
         model(fake_input, decoder_input_ids=decoder_fake_input)
     elif is_vl_model and "nemotron" in model_type:
         # For Nemotron VL models, try to run optimization on just the language model part
-        language_model, _ = get_language_model_from_vl(model)
+        language_model_lineage = get_language_model_from_vl(model)
 
-        if language_model is not None:
+        if language_model_lineage is not None:
             # Run optimization on just the language model with the same input format as regular LLMs
             # Use the same fake_input tensor that regular LLMs use
+            language_model = language_model_lineage[-1]
             print(
                 f"Running optimization on language model with fake_input shape: {fake_input.shape}"
             )
@@ -474,7 +475,6 @@ def _export_hf_checkpoint(
     kv_cache_max_bound = cache_bound_mapping.get(kv_cache_format)
 
     # Track if any layers are quantized to properly set exclude_modules
-    has_quantized_layers = False
     fsdp_module_to_reshard = None
 
     for _, sub_module in model.named_modules():
@@ -489,7 +489,6 @@ def _export_hf_checkpoint(
             fsdp_module_to_reshard = sub_module
 
         if get_quantization_format(sub_module) != QUANTIZATION_NONE:
-            has_quantized_layers = True
             if is_quantlinear(sub_module):
                 with fsdp2_aware_weight_update(model, sub_module, reshard=False):
                     _export_quantized_weight(sub_module, dtype)
@@ -523,10 +522,6 @@ def _export_hf_checkpoint(
         quantized_state_dict, kv_cache_max_bound, kv_cache_format
     )
 
-    # Check if any layers are quantized
-    if has_quantized_layers:
-        quant_config["quantization"].setdefault("exclude_modules", []).append("lm_head")
-
     return quantized_state_dict, quant_config
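
Reviewer note: below is a minimal usage sketch of the lineage-based
get_language_model_from_vl() API introduced above; it is not part of the
patch. It assumes the patch is applied and torch is installed. DummyVLM and
DummyLM are hypothetical stand-ins for a real VLM checkpoint, not classes
from this repo.

    import torch.nn as nn

    from modelopt.torch.export.model_utils import get_language_model_from_vl


    class DummyLM(nn.Module):
        """Stand-in for the language_model submodule of a VLM."""

        def __init__(self):
            super().__init__()
            self.lm_head = nn.Linear(8, 8)


    class DummyVLM(nn.Module):
        """Stand-in for a VLM with a vision tower next to the language model."""

        def __init__(self):
            super().__init__()
            self.vision_model = nn.Linear(8, 8)
            self.language_model = DummyLM()


    vlm = DummyVLM()
    lineage = get_language_model_from_vl(vlm)  # [vlm, vlm.language_model]
    assert lineage is not None

    language_model = lineage[-1]  # the submodule hf_ptq.py quantizes
    ancestors = lineage[:-1]      # every module on the path above it

    # Collect the siblings that sit outside the language_model path; these are
    # the modules hf_ptq.py now marks with a disabled quant config so that the
    # HF export can exclude them from the quantization config.
    on_path = set(lineage)
    non_lm_modules = [
        child
        for ancestor in ancestors
        for _, child in ancestor.named_children()
        if child not in on_path
    ]
    print([type(m).__name__ for m in non_lm_modules])  # ['Linear'] -> the vision tower

Returning the whole lineage (rather than a (language_model, parent) tuple)
lets callers reach the language model via lineage[-1] and re-attach a
quantized replacement on its direct parent via lineage[-2], which is what
hf_ptq.py does after quantization.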