diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py
index af80a0944..1e6912830 100755
--- a/examples/llm_ptq/hf_ptq.py
+++ b/examples/llm_ptq/hf_ptq.py
@@ -317,17 +317,23 @@ def main(args):
             tokenizer.padding_side = "left"
 
         # We only quantize the language model for VLMs other than the type supported above.
-        language_model, parent_model = get_language_model_from_vl(model)
-        if language_model is not None:
+        language_model_lineage = get_language_model_from_vl(full_model)
+        if language_model_lineage is not None:
+            language_model = language_model_lineage.pop(-1)
+            ancestors = language_model_lineage
+            # Apply disabled quant to all modules that are not part of language_model so we can exclude them during
+            # HF export.
             disabled_quant_cfg = {
                 "quant_cfg": {"default": {"enable": False}},
                 "algorithm": "max",
             }
 
-            for name, child in parent_model.named_children():
-                # Apply disabled quant to all children except language_model so we can exclude them during HF export.
-                if name != "language_model":
-                    mtq.quantize(child, disabled_quant_cfg, forward_loop=None)
+            memo = set(ancestors) | {language_model}
+            for ancestor in ancestors:
+                for _, module in ancestor.named_children():
+                    if module not in memo:
+                        mtq.quantize(module, disabled_quant_cfg, forward_loop=None)
+                        memo.add(module)
 
             model = language_model
             model_type = get_model_type(model)
@@ -492,10 +498,10 @@ def main(args):
 
     # For VL models, update full_model to use the quantized language model
    if is_nemotron_vl_model:
-        _, parent_model = get_language_model_from_vl(full_model)
-        if parent_model is not None:
+        language_model_lineage = get_language_model_from_vl(full_model)
+        if language_model_lineage is not None:
             print("Updating full_model with quantized language_model...")
-            parent_model.language_model = model
+            language_model_lineage[-2].language_model = model
 
     if args.verbose:
         mtq.print_quant_summary(full_model)
diff --git a/modelopt/torch/export/model_utils.py b/modelopt/torch/export/model_utils.py
index 44f3c185c..fb5d9f8be 100755
--- a/modelopt/torch/export/model_utils.py
+++ b/modelopt/torch/export/model_utils.py
@@ -14,6 +14,8 @@
 # limitations under the License.
 """Utility functions for model type detection and classification."""
 
+import torch.nn as nn
+
 MODEL_NAME_TO_TYPE = {
     "GPT2": "gpt",
     "Mllama": "mllama",
@@ -111,8 +113,8 @@ def is_multimodal_model(model):
     )
 
 
-def get_language_model_from_vl(model):
-    """Extract the language model component from a Vision-Language Model (VLM).
+def get_language_model_from_vl(model) -> list[nn.Module] | None:
+    """Extract the language model lineage from a Vision-Language Model (VLM).
 
     This function handles the common patterns for accessing the language model
     component in various VLM architectures. It checks multiple possible locations where the
@@ -122,36 +124,20 @@
         model: The VLM model instance to extract the language model from
 
     Returns:
-        tuple: (language_model, parent_model) where:
-            - language_model: The extracted language model component, or None if not found
-            - parent_model: The parent model containing the language_model attribute
+        list: The lineage path from the top-level model down to the language model, or None if not found
 
     Examples:
         >>> # For LLaVA-style models
-        >>> lang_model, parent = get_language_model_from_vl(vlm_model)
-        >>> if lang_model is not None:
-        ...     # Work with the language model component
-        ...     quantized_lang_model = quantize(lang_model)
-        ...     # Update the parent model
-        ...     parent.language_model = quantized_lang_model
+        >>> lineage = get_language_model_from_vl(vlm_model)
+        >>> # lineage[0] is vlm_model
+        >>> # lineage[-1] is the language model component
     """
-    # Pattern 1: Direct language_model attribute (e.g., LLaVA, some Nemotron models)
+    # Always prioritize model.model.language_model
+    if hasattr(model, "model") and hasattr(model.model, "language_model"):
+        return [model, model.model, model.model.language_model]
+
     if hasattr(model, "language_model"):
-        # Check if it's a property that might need special handling
-        if isinstance(type(model).__dict__.get("language_model"), property):
-            # Some models have language_model as a property that points to model.model.language_model
-            if hasattr(model, "model") and hasattr(model.model, "language_model"):
-                return model.model.language_model, model.model
-            else:
-                # Property exists but no nested structure found
-                return model.language_model, model
-        else:
-            # Direct attribute access
-            return model.language_model, model
-
-    # Pattern 2: Nested in model.model.language_model (e.g., some Gemma3, Qwen2.5-VL models)
-    elif hasattr(model, "model") and hasattr(model.model, "language_model"):
-        return model.model.language_model, model.model
+        return [model, model.language_model]
 
     # Pattern 3: No language_model found
-    return None, None
+    return None
diff --git a/modelopt/torch/export/quant_utils.py b/modelopt/torch/export/quant_utils.py
index 3e99a0e0a..aa4c3b7c3 100755
--- a/modelopt/torch/export/quant_utils.py
+++ b/modelopt/torch/export/quant_utils.py
@@ -17,6 +17,7 @@
 
 import logging
 from collections.abc import Generator
+from types import SimpleNamespace
 from typing import Any
 from warnings import warn
 
@@ -1089,16 +1090,19 @@
             layer_config_dict[name + ".quantization"] = quantization_format
             layer_config_dict[name + ".awq_block_size"] = block_size
 
+            not_enabled = SimpleNamespace(is_enabled=False)
+
             # Find kv cache quant format
             if (
-                hasattr(module, "k_bmm_quantizer")
-                or hasattr(module, "v_bmm_quantizer")
-                or (hasattr(module, "output_quantizer") and module.output_quantizer.is_enabled)
+                getattr(module, "k_bmm_quantizer", not_enabled).is_enabled
+                or getattr(module, "v_bmm_quantizer", not_enabled).is_enabled
+                or getattr(module, "output_quantizer", not_enabled).is_enabled
             ):
+                module_kv_quant = get_kv_cache_dtype(module)
                 if kv_cache_format == QUANTIZATION_NONE:
-                    kv_cache_format = get_kv_cache_dtype(module)
+                    kv_cache_format = module_kv_quant
                 else:
-                    assert kv_cache_format == get_kv_cache_dtype(module), (
+                    assert kv_cache_format == module_kv_quant, (
                         "Do not support mixed precision kv cache quantization"
                     )
 
diff --git a/modelopt/torch/export/unified_export_hf.py b/modelopt/torch/export/unified_export_hf.py
index 7b102f4e0..6283bd35c 100644
--- a/modelopt/torch/export/unified_export_hf.py
+++ b/modelopt/torch/export/unified_export_hf.py
@@ -155,11 +155,12 @@ def _output_hook(module, input, output):
             model(fake_input, decoder_input_ids=decoder_fake_input)
         elif is_vl_model and "nemotron" in model_type:
             # For Nemotron VL models, try to run optimization on just the language model part
-            language_model, _ = get_language_model_from_vl(model)
+            language_model_lineage = get_language_model_from_vl(model)
 
-            if language_model is not None:
+            if language_model_lineage is not None:
                 # Run optimization on just the language model with the same input format as regular LLMs
                 # Use the same fake_input tensor that regular LLMs use
+                language_model = language_model_lineage[-1]
                 print(
                     f"Running optimization on language model with fake_input shape: {fake_input.shape}"
                 )
@@ -474,7 +475,6 @@ def _export_hf_checkpoint(
         kv_cache_max_bound = cache_bound_mapping.get(kv_cache_format)
 
     # Track if any layers are quantized to properly set exclude_modules
-    has_quantized_layers = False
     fsdp_module_to_reshard = None
 
     for _, sub_module in model.named_modules():
@@ -489,7 +489,6 @@
                 fsdp_module_to_reshard = sub_module
 
         if get_quantization_format(sub_module) != QUANTIZATION_NONE:
-            has_quantized_layers = True
             if is_quantlinear(sub_module):
                 with fsdp2_aware_weight_update(model, sub_module, reshard=False):
                     _export_quantized_weight(sub_module, dtype)
@@ -523,10 +522,6 @@
             quantized_state_dict, kv_cache_max_bound, kv_cache_format
         )
 
-    # Check if any layers are quantized
-    if has_quantized_layers:
-        quant_config["quantization"].setdefault("exclude_modules", []).append("lm_head")
-
     return quantized_state_dict, quant_config