examples/llm_ptq/hf_ptq.py: 15 additions & 9 deletions
@@ -317,17 +317,23 @@ def main(args):
         tokenizer.padding_side = "left"

         # We only quantize the language model for VLMs other than the type supported above.
-        language_model, parent_model = get_language_model_from_vl(model)
-        if language_model is not None:
+        language_model_lineage = get_language_model_from_vl(full_model)
+        if language_model_lineage is not None:
+            language_model = language_model_lineage.pop(-1)
+            ancestors = language_model_lineage
+            # Apply disabled quant to all modules that are not part of language_model so we can exclude them during
+            # HF export.
             disabled_quant_cfg = {
                 "quant_cfg": {"default": {"enable": False}},
                 "algorithm": "max",
             }

-            for name, child in parent_model.named_children():
-                # Apply disabled quant to all children except language_model so we can exclude them during HF export.
-                if name != "language_model":
-                    mtq.quantize(child, disabled_quant_cfg, forward_loop=None)
+            memo = set(ancestors) | {language_model}
+            for ancestor in ancestors:
+                for _, module in ancestor.named_children():
+                    if module not in memo:
+                        mtq.quantize(module, disabled_quant_cfg, forward_loop=None)
+                        memo.add(module)

             model = language_model
             model_type = get_model_type(model)
@@ -492,10 +498,10 @@ def main(args):

     # For VL models, update full_model to use the quantized language model
     if is_nemotron_vl_model:
-        _, parent_model = get_language_model_from_vl(full_model)
-        if parent_model is not None:
+        language_model_lineage = get_language_model_from_vl(full_model)
+        if language_model_lineage is not None:
             print("Updating full_model with quantized language_model...")
-            parent_model.language_model = model
+            language_model_lineage[-2].language_model = model

     if args.verbose:
         mtq.print_quant_summary(full_model)
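For reference, the following standalone sketch (not part of the diff) shows the lineage-based exclusion pattern introduced above: every ancestor returned by get_language_model_from_vl is walked, and each child module that is neither on the lineage nor the language model itself receives the disabled quantization config. The toy classes and disable_quant_stub are hypothetical stand-ins; in hf_ptq.py the real call is mtq.quantize(module, disabled_quant_cfg, forward_loop=None).

    # Standalone sketch of the ancestor walk (toy modules, stubbed quantize call).
    import torch.nn as nn

    def disable_quant_stub(module: nn.Module) -> None:
        # Stand-in for mtq.quantize(module, disabled_quant_cfg, forward_loop=None).
        print(f"disabling quantization for {module.__class__.__name__}")

    class ToyLanguageModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.lm_head = nn.Linear(8, 8)

    class ToyInnerModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.vision_tower = nn.Linear(8, 8)       # sibling of the language model: excluded
            self.language_model = ToyLanguageModel()  # stays quantizable

    class ToyVLM(nn.Module):
        def __init__(self):
            super().__init__()
            self.model = ToyInnerModel()
            self.multi_modal_projector = nn.Linear(8, 8)  # sibling on the lineage path: excluded

    vlm = ToyVLM()
    # What get_language_model_from_vl(vlm) would return for this layout:
    lineage = [vlm, vlm.model, vlm.model.language_model]

    language_model = lineage.pop(-1)
    ancestors = lineage

    memo = set(ancestors) | {language_model}
    for ancestor in ancestors:
        for _, module in ancestor.named_children():
            if module not in memo:
                disable_quant_stub(module)  # vision_tower and multi_modal_projector land here
                memo.add(module)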
modelopt/torch/export/model_utils.py: 14 additions & 28 deletions
@@ -14,6 +14,8 @@
 # limitations under the License.
 """Utility functions for model type detection and classification."""

+import torch.nn as nn
+
 MODEL_NAME_TO_TYPE = {
     "GPT2": "gpt",
     "Mllama": "mllama",
@@ -111,8 +113,8 @@ def is_multimodal_model(model):
     )


-def get_language_model_from_vl(model):
-    """Extract the language model component from a Vision-Language Model (VLM).
+def get_language_model_from_vl(model) -> list[nn.Module] | None:
+    """Extract the language model lineage from a Vision-Language Model (VLM).

     This function handles the common patterns for accessing the language model component
     in various VLM architectures. It checks multiple possible locations where the
@@ -122,36 +124,20 @@ def get_language_model_from_vl(model):
         model: The VLM model instance to extract the language model from

     Returns:
-        tuple: (language_model, parent_model) where:
-            - language_model: The extracted language model component, or None if not found
-            - parent_model: The parent model containing the language_model attribute
+        list: module lineage from the root model down to the language model (last element), or None if not found

     Examples:
-        >>> # For LLaVA-style models
-        >>> lang_model, parent = get_language_model_from_vl(vlm_model)
-        >>> if lang_model is not None:
-        ...     # Work with the language model component
-        ...     quantized_lang_model = quantize(lang_model)
-        ...     # Update the parent model
-        ...     parent.language_model = quantized_lang_model
+        >>> lineage = get_language_model_from_vl(vlm_model)
+        >>> # lineage[0] is vlm_model
+        >>> # lineage[1] is vlm_model.language_model
     """
     # Pattern 1: Direct language_model attribute (e.g., LLaVA, some Nemotron models)
+    # Always prioritize model.model.language_model.
+    if hasattr(model, "model") and hasattr(model.model, "language_model"):
+        return [model, model.model, model.model.language_model]
+
     if hasattr(model, "language_model"):
-        # Check if it's a property that might need special handling
-        if isinstance(type(model).__dict__.get("language_model"), property):
-            # Some models have language_model as a property that points to model.model.language_model
-            if hasattr(model, "model") and hasattr(model.model, "language_model"):
-                return model.model.language_model, model.model
-            else:
-                # Property exists but no nested structure found
-                return model.language_model, model
-        else:
-            # Direct attribute access
-            return model.language_model, model
-
-    # Pattern 2: Nested in model.model.language_model (e.g., some Gemma3, Qwen2.5-VL models)
-    elif hasattr(model, "model") and hasattr(model.model, "language_model"):
-        return model.model.language_model, model.model
+        return [model, model.language_model]

     # Pattern 3: No language_model found
-    return None, None
+    return None
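A minimal usage sketch of the new return contract follows (toy classes, made up for illustration; the local copy of get_language_model_from_vl simply mirrors the new implementation above so the snippet runs on its own).

    import torch.nn as nn

    def get_language_model_from_vl(model) -> list[nn.Module] | None:
        # Mirror of the new implementation above.
        if hasattr(model, "model") and hasattr(model.model, "language_model"):
            return [model, model.model, model.model.language_model]
        if hasattr(model, "language_model"):
            return [model, model.language_model]
        return None

    class Lang(nn.Module):
        pass

    class Inner(nn.Module):
        def __init__(self):
            super().__init__()
            self.language_model = Lang()

    class NestedVLM(nn.Module):  # model.model.language_model layout
        def __init__(self):
            super().__init__()
            self.model = Inner()

    class FlatVLM(nn.Module):  # model.language_model layout
        def __init__(self):
            super().__init__()
            self.language_model = Lang()

    nested = NestedVLM()
    lineage = get_language_model_from_vl(nested)
    assert lineage[-1] is nested.model.language_model  # the language model itself
    assert lineage[-2] is nested.model                 # its direct parent

    flat = FlatVLM()
    assert get_language_model_from_vl(flat) == [flat, flat.language_model]
    assert get_language_model_from_vl(nn.Linear(2, 2)) is None

The lineage[-2] element is exactly what hf_ptq.py relies on when it writes the quantized language model back via language_model_lineage[-2].language_model = model.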
modelopt/torch/export/quant_utils.py: 9 additions & 5 deletions
@@ -17,6 +17,7 @@

 import logging
 from collections.abc import Generator
+from types import SimpleNamespace
 from typing import Any
 from warnings import warn

@@ -1086,16 +1087,19 @@ def get_quant_config(
             layer_config_dict[name + ".quantization"] = quantization_format
             layer_config_dict[name + ".awq_block_size"] = block_size

+        not_enabled = SimpleNamespace(is_enabled=False)
+
         # Find kv cache quant format
         if (
-            hasattr(module, "k_bmm_quantizer")
-            or hasattr(module, "v_bmm_quantizer")
-            or (hasattr(module, "output_quantizer") and module.output_quantizer.is_enabled)
+            getattr(module, "k_bmm_quantizer", not_enabled).is_enabled
+            or getattr(module, "v_bmm_quantizer", not_enabled).is_enabled
+            or getattr(module, "output_quantizer", not_enabled).is_enabled
         ):
+            module_kv_quant = get_kv_cache_dtype(module)
             if kv_cache_format == QUANTIZATION_NONE:
-                kv_cache_format = get_kv_cache_dtype(module)
+                kv_cache_format = module_kv_quant
             else:
-                assert kv_cache_format == get_kv_cache_dtype(module), (
+                assert kv_cache_format == module_kv_quant, (
                     "Do not support mixed precision kv cache quantization"
                 )

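The SimpleNamespace sentinel makes "attribute missing" and "quantizer present but disabled" take the same code path, which is what lets the new condition drop the old mixed hasattr/is_enabled checks. A small self-contained sketch of the idiom (the Fake* classes are illustrative, not ModelOpt types):

    from types import SimpleNamespace

    # Default object whose is_enabled mimics a disabled quantizer.
    not_enabled = SimpleNamespace(is_enabled=False)

    class FakeQuantizer:
        def __init__(self, enabled: bool):
            self.is_enabled = enabled

    class AttentionWithKVQuant:
        def __init__(self):
            self.k_bmm_quantizer = FakeQuantizer(enabled=True)
            self.v_bmm_quantizer = FakeQuantizer(enabled=False)

    class PlainLinear:
        pass  # no quantizer attributes at all

    def has_kv_or_output_quant(module) -> bool:
        # Missing attributes fall back to the sentinel, so they read as disabled.
        return (
            getattr(module, "k_bmm_quantizer", not_enabled).is_enabled
            or getattr(module, "v_bmm_quantizer", not_enabled).is_enabled
            or getattr(module, "output_quantizer", not_enabled).is_enabled
        )

    print(has_kv_or_output_quant(AttentionWithKVQuant()))  # True: k_bmm_quantizer is enabled
    print(has_kv_or_output_quant(PlainLinear()))           # False: nothing present, nothing enabled

Note the behavior change in the diff: the old condition passed for any module that merely had a k/v bmm quantizer attribute, even a disabled one, while the new condition requires at least one of the three quantizers to actually be enabled.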
modelopt/torch/export/unified_export_hf.py: 3 additions & 8 deletions
@@ -155,11 +155,12 @@ def _output_hook(module, input, output):
         model(fake_input, decoder_input_ids=decoder_fake_input)
     elif is_vl_model and "nemotron" in model_type:
         # For Nemotron VL models, try to run optimization on just the language model part
-        language_model, _ = get_language_model_from_vl(model)
+        language_model_lineage = get_language_model_from_vl(model)

-        if language_model is not None:
+        if language_model_lineage is not None:
             # Run optimization on just the language model with the same input format as regular LLMs
             # Use the same fake_input tensor that regular LLMs use
+            language_model = language_model_lineage[-1]
             print(
                 f"Running optimization on language model with fake_input shape: {fake_input.shape}"
             )
@@ -474,7 +475,6 @@ def _export_hf_checkpoint(
     kv_cache_max_bound = cache_bound_mapping.get(kv_cache_format)

     # Track if any layers are quantized to properly set exclude_modules
-    has_quantized_layers = False
     fsdp_module_to_reshard = None

     for _, sub_module in model.named_modules():
@@ -489,7 +489,6 @@
             fsdp_module_to_reshard = sub_module

         if get_quantization_format(sub_module) != QUANTIZATION_NONE:
-            has_quantized_layers = True
             if is_quantlinear(sub_module):
                 with fsdp2_aware_weight_update(model, sub_module, reshard=False):
                     _export_quantized_weight(sub_module, dtype)
@@ -523,10 +522,6 @@ def _export_hf_checkpoint(
             quantized_state_dict, kv_cache_max_bound, kv_cache_format
         )

-    # Check if any layers are quantized
-    if has_quantized_layers:
-        quant_config["quantization"].setdefault("exclude_modules", []).append("lm_head")
-
     return quantized_state_dict, quant_config

Review thread on the removed block:

Collaborator:
Just double checking: has this been handled by the quant config so we don't need to hardcode it?

Contributor Author:
Yes. If lm_head is not quantized, it will end up in exclude_modules naturally. In addition, models that have the language_model as a submodule may have lm_head under language_model rather than under the root model. In those cases, exclude_modules will contain xxx.language_model.lm_head instead of the hardcoded "lm_head", which was not correct to begin with.


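To illustrate the naming point in the reply above: exclusion is keyed by full dotted module names, so a non-quantized lm_head that lives under a language_model submodule surfaces as language_model.lm_head rather than lm_head. The sketch below is a simplified stand-in with toy modules and a fake quantization check, not ModelOpt's actual export logic:

    import torch.nn as nn

    class ToyLM(nn.Module):
        def __init__(self):
            super().__init__()
            self.layers = nn.Linear(8, 8)   # pretend this layer is quantized
            self.lm_head = nn.Linear(8, 8)  # pretend this one is not

    class ToyVLM(nn.Module):
        def __init__(self):
            super().__init__()
            self.vision_tower = nn.Linear(8, 8)  # not quantized
            self.language_model = ToyLM()

    def fake_is_quantized(name: str) -> bool:
        # Stand-in for get_quantization_format(sub_module) != QUANTIZATION_NONE.
        return name.endswith("layers")

    def collect_exclude_modules(model: nn.Module) -> list[str]:
        # Leaf modules whose (fake) quantization format is "none" get excluded by full name.
        return [
            name
            for name, module in model.named_modules()
            if name
            and len(list(module.children())) == 0
            and not fake_is_quantized(name)
        ]

    print(collect_exclude_modules(ToyLM()))   # ['lm_head']
    print(collect_exclude_modules(ToyVLM()))  # ['vision_tower', 'language_model.lm_head']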