
Commit a113bea

[OMNIML-2917] handle lm_head and other un-quantized modules correctly (#504)
## What does this PR do?

**Type of change:** Bug fix.

**Overview:** This is change set 2 from working on OMNIML-2917. Two correlated changes:

1. When we quantize only the language_model submodule, correctly disable quantization of all other modules; nothing needs to be hard-coded.
2. When we export a quantized model to the HF unified format, we previously hard-coded the exclusion of "lm_head". With change set 1, where the full model is used for export config generation, lm_head is naturally excluded whenever it is not quantized, so the hard-coded "lm_head" entry is removed from the exclusion list.

## Testing

Correctly exported Llama 3.1 70B, Qwen3 VL MoE, Nemotron Super, Llama4 Scout, NVIDIA-Nemotron-Nano-12B-v2-VL-BF16.

---------

Signed-off-by: Shengliang Xu <[email protected]>
1 parent 479f729 commit a113bea

File tree

4 files changed: +41 -50 lines changed


examples/llm_ptq/hf_ptq.py

Lines changed: 15 additions & 9 deletions

@@ -318,17 +318,23 @@ def main(args):
     tokenizer.padding_side = "left"

     # We only quantize the language model for VLMs other than the type supported above.
-    language_model, parent_model = get_language_model_from_vl(model)
-    if language_model is not None:
+    language_model_lineage = get_language_model_from_vl(full_model)
+    if language_model_lineage is not None:
+        language_model = language_model_lineage.pop(-1)
+        ancestors = language_model_lineage
+        # Apply disabled quant to all modules that are not part of language_model so we can exclude them during
+        # HF export.
         disabled_quant_cfg = {
             "quant_cfg": {"default": {"enable": False}},
             "algorithm": "max",
         }

-        for name, child in parent_model.named_children():
-            # Apply disabled quant to all children except language_model so we can exclude them during HF export.
-            if name != "language_model":
-                mtq.quantize(child, disabled_quant_cfg, forward_loop=None)
+        memo = set(ancestors) | {language_model}
+        for ancestor in ancestors:
+            for _, module in ancestor.named_children():
+                if module not in memo:
+                    mtq.quantize(module, disabled_quant_cfg, forward_loop=None)
+                memo.add(module)

         model = language_model
         model_type = get_model_type(model)
@@ -493,10 +499,10 @@ def main(args):

     # For VL models, update full_model to use the quantized language model
     if is_nemotron_vl_model:
-        _, parent_model = get_language_model_from_vl(full_model)
-        if parent_model is not None:
+        language_model_lineage = get_language_model_from_vl(full_model)
+        if language_model_lineage is not None:
             print("Updating full_model with quantized language_model...")
-            parent_model.language_model = model
+            language_model_lineage[-2].language_model = model

     if args.verbose:
         mtq.print_quant_summary(full_model)
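To make the new loop concrete, here is a minimal, self-contained sketch of the ancestor walk on a toy VLM. `ToyVLM` and its submodule names are illustrative assumptions rather than types from the repo, and `mtq.quantize` is replaced by collecting names so the snippet runs on its own.

```python
import torch.nn as nn

# Toy stand-in for a VLM wrapper; class and attribute names are illustrative only.
class ToyVLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = nn.Module()
        self.model.language_model = nn.Linear(4, 4)
        self.model.vision_tower = nn.Linear(4, 4)
        self.lm_head = nn.Linear(4, 4)

vlm = ToyVLM()
# What get_language_model_from_vl would return for this shape of model.
lineage = [vlm, vlm.model, vlm.model.language_model]
language_model = lineage.pop(-1)
ancestors = lineage

# Walk each ancestor's direct children; anything off the lineage path (and not the
# language model itself) is what hf_ptq.py now quantizes with the disabled config.
memo = set(ancestors) | {language_model}
to_disable = []
for name, module in [(n, m) for a in ancestors for n, m in a.named_children()]:
    if module not in memo:
        to_disable.append(name)  # stands in for mtq.quantize(module, disabled_quant_cfg, ...)
    memo.add(module)

print(to_disable)  # ['lm_head', 'vision_tower']: these stay un-quantized and get excluded at export
```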

modelopt/torch/export/model_utils.py

Lines changed: 14 additions & 28 deletions

@@ -14,6 +14,8 @@
 # limitations under the License.
 """Utility functions for model type detection and classification."""

+import torch.nn as nn
+
 MODEL_NAME_TO_TYPE = {
     "GPT2": "gpt",
     "Mllama": "mllama",
@@ -111,8 +113,8 @@ def is_multimodal_model(model):
     )


-def get_language_model_from_vl(model):
-    """Extract the language model component from a Vision-Language Model (VLM).
+def get_language_model_from_vl(model) -> list[nn.Module] | None:
+    """Extract the language model lineage from a Vision-Language Model (VLM).

     This function handles the common patterns for accessing the language model component
     in various VLM architectures. It checks multiple possible locations where the
@@ -122,36 +124,20 @@ def get_language_model_from_vl(model):
         model: The VLM model instance to extract the language model from

     Returns:
-        tuple: (language_model, parent_model) where:
-            - language_model: The extracted language model component, or None if not found
-            - parent_model: The parent model containing the language_model attribute
+        list: the lineage path towards the language model

     Examples:
         >>> # For LLaVA-style models
-        >>> lang_model, parent = get_language_model_from_vl(vlm_model)
-        >>> if lang_model is not None:
-        ...     # Work with the language model component
-        ...     quantized_lang_model = quantize(lang_model)
-        ...     # Update the parent model
-        ...     parent.language_model = quantized_lang_model
+        >>> lineage = get_language_model_from_vl(vlm_model)
+        >>> # lineage[0] is vlm_model
+        >>> # lineage[1] is vlm_model.language_model
     """
-    # Pattern 1: Direct language_model attribute (e.g., LLaVA, some Nemotron models)
+    # Always prioritize model.model.language_model
+    if hasattr(model, "model") and hasattr(model.model, "language_model"):
+        return [model, model.model, model.model.language_model]
+
     if hasattr(model, "language_model"):
-        # Check if it's a property that might need special handling
-        if isinstance(type(model).__dict__.get("language_model"), property):
-            # Some models have language_model as a property that points to model.model.language_model
-            if hasattr(model, "model") and hasattr(model.model, "language_model"):
-                return model.model.language_model, model.model
-            else:
-                # Property exists but no nested structure found
-                return model.language_model, model
-        else:
-            # Direct attribute access
-            return model.language_model, model
-
-    # Pattern 2: Nested in model.model.language_model (e.g., some Gemma3, Qwen2.5-VL models)
-    elif hasattr(model, "model") and hasattr(model.model, "language_model"):
-        return model.model.language_model, model.model
+        return [model, model.language_model]

     # Pattern 3: No language_model found
-    return None, None
+    return None
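As a quick check of the new return contract, a minimal sketch with toy wrappers. The classes below are invented for illustration; only the module path and function name come from this file.

```python
import torch.nn as nn

from modelopt.torch.export.model_utils import get_language_model_from_vl

class Inner(nn.Module):
    def __init__(self):
        super().__init__()
        self.language_model = nn.Linear(8, 8)

class NestedVLM(nn.Module):  # exercises the model.model.language_model branch
    def __init__(self):
        super().__init__()
        self.model = Inner()

class FlatVLM(nn.Module):  # exercises the model.language_model branch
    def __init__(self):
        super().__init__()
        self.language_model = nn.Linear(8, 8)

nested, flat = NestedVLM(), FlatVLM()
assert get_language_model_from_vl(nested) == [nested, nested.model, nested.model.language_model]
assert get_language_model_from_vl(flat) == [flat, flat.language_model]
assert get_language_model_from_vl(nn.Linear(8, 8)) is None  # no language_model found
```

Callers then take `lineage[-1]` as the language model and `lineage[-2]` as its immediate parent, which is how hf_ptq.py re-attaches the quantized submodule above.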

modelopt/torch/export/quant_utils.py

Lines changed: 9 additions & 5 deletions

@@ -17,6 +17,7 @@

 import logging
 from collections.abc import Generator
+from types import SimpleNamespace
 from typing import Any
 from warnings import warn

@@ -1121,16 +1122,19 @@ def get_quant_config(
             layer_config_dict[name + ".quantization"] = quantization_format
             layer_config_dict[name + ".awq_block_size"] = block_size

+        not_enabled = SimpleNamespace(is_enabled=False)
+
         # Find kv cache quant format
         if (
-            hasattr(module, "k_bmm_quantizer")
-            or hasattr(module, "v_bmm_quantizer")
-            or (hasattr(module, "output_quantizer") and module.output_quantizer.is_enabled)
+            getattr(module, "k_bmm_quantizer", not_enabled).is_enabled
+            or getattr(module, "v_bmm_quantizer", not_enabled).is_enabled
+            or getattr(module, "output_quantizer", not_enabled).is_enabled
        ):
+            module_kv_quant = get_kv_cache_dtype(module)
             if kv_cache_format == QUANTIZATION_NONE:
-                kv_cache_format = get_kv_cache_dtype(module)
+                kv_cache_format = module_kv_quant
             else:
-                assert kv_cache_format == get_kv_cache_dtype(module), (
+                assert kv_cache_format == module_kv_quant, (
                     "Do not support mixed precision kv cache quantization"
                 )
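The getattr default changes the semantics from "the quantizer attribute exists" to "the quantizer exists and is enabled". A minimal sketch of the pattern, using SimpleNamespace stand-ins rather than real quantizer modules:

```python
from types import SimpleNamespace

not_enabled = SimpleNamespace(is_enabled=False)

def has_kv_quant(module):
    # Mirrors the updated condition: an attribute that is missing, or present but
    # disabled, no longer counts as KV-cache quantization.
    return (
        getattr(module, "k_bmm_quantizer", not_enabled).is_enabled
        or getattr(module, "v_bmm_quantizer", not_enabled).is_enabled
        or getattr(module, "output_quantizer", not_enabled).is_enabled
    )

disabled = SimpleNamespace(k_bmm_quantizer=SimpleNamespace(is_enabled=False))
enabled = SimpleNamespace(v_bmm_quantizer=SimpleNamespace(is_enabled=True))
plain = SimpleNamespace()

print(has_kv_quant(disabled))  # False (the old hasattr check would have returned True)
print(has_kv_quant(enabled))   # True
print(has_kv_quant(plain))     # False
```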

modelopt/torch/export/unified_export_hf.py

Lines changed: 3 additions & 8 deletions

@@ -155,11 +155,12 @@ def _output_hook(module, input, output):
         model(fake_input, decoder_input_ids=decoder_fake_input)
     elif is_vl_model and "nemotron" in model_type:
         # For Nemotron VL models, try to run optimization on just the language model part
-        language_model, _ = get_language_model_from_vl(model)
+        language_model_lineage = get_language_model_from_vl(model)

-        if language_model is not None:
+        if language_model_lineage is not None:
             # Run optimization on just the language model with the same input format as regular LLMs
             # Use the same fake_input tensor that regular LLMs use
+            language_model = language_model_lineage[-1]
             print(
                 f"Running optimization on language model with fake_input shape: {fake_input.shape}"
             )
@@ -472,7 +473,6 @@ def _export_hf_checkpoint(
         kv_cache_max_bound = cache_bound_mapping.get(kv_cache_format)

     # Track if any layers are quantized to properly set exclude_modules
-    has_quantized_layers = False
     fsdp_module_to_reshard = None

     for _, sub_module in model.named_modules():
@@ -491,7 +491,6 @@
             continue

         if get_quantization_format(sub_module) != QUANTIZATION_NONE:
-            has_quantized_layers = True
             if is_quantlinear(sub_module):
                 with fsdp2_aware_weight_update(model, sub_module, reshard=False):
                     _export_quantized_weight(sub_module, dtype)
@@ -525,10 +524,6 @@
         quantized_state_dict, kv_cache_max_bound, kv_cache_format, is_modelopt_qlora
     )

-    # Check if any layers are quantized
-    if has_quantized_layers:
-        quant_config["quantization"].setdefault("exclude_modules", []).append("lm_head")
-
     return quantized_state_dict, quant_config
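With the hard-coded append removed, the exclusion of lm_head has to come from the quantization state itself: because hf_ptq.py now applies the disabled config to everything outside the language model, un-quantized modules show up as such during export and land in exclude_modules without special cases. A rough sketch of that idea (not the repo's _export_hf_checkpoint logic; `toy_quantization_format` and `ToyLM` are invented stand-ins):

```python
import torch.nn as nn

QUANTIZATION_NONE = None  # assumed sentinel for "module is not quantized"

def toy_quantization_format(module):
    # Invented stand-in for get_quantization_format: modules flagged with
    # `_quantized = True` pretend to carry an fp8 format.
    return "fp8" if getattr(module, "_quantized", False) else QUANTIZATION_NONE

class ToyLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Linear(4, 4)
        self.layers._quantized = True   # quantized decoder stack
        self.lm_head = nn.Linear(4, 4)  # left un-quantized, as after this PR

def build_exclude_modules(model):
    # Derive the exclusion list from per-module quantization state instead of
    # appending "lm_head" unconditionally.
    return [
        name
        for name, sub in model.named_modules()
        if name and toy_quantization_format(sub) == QUANTIZATION_NONE
    ]

print(build_exclude_modules(ToyLM()))  # ['lm_head']
```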
