Commit 446e135

add a util function to extract language model from VLM, update changelog
Signed-off-by: Zhiyu Cheng <[email protected]>
1 parent 60a698a

5 files changed: +64 -29 lines


CHANGELOG.rst

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ Model Optimizer Changelog (Linux)
 - Allow specifying ``calib_seq`` in ``examples/llm_ptq`` to set the maximum sequence length for calibration.
 - Add support for MCore MoE PTQ/QAT/QAD.
 - Add support for multi-node PTQ and export with FSDP2 in ``examples/llm_ptq/multinode_ptq.py``. See `examples/llm_ptq/README.md <https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/examples/llm_ptq#multi-node-post-training-quantization-with-fsdp2>`_ for more details.
+- Add support for Nemotron Nano VL v1 & v2 models in FP8/NVFP4 PTQ workflow.

 **Documentation**


examples/llm_ptq/hf_ptq.py

Lines changed: 9 additions & 14 deletions
@@ -50,7 +50,7 @@
     export_tensorrt_llm_checkpoint,
     get_model_type,
 )
-from modelopt.torch.export.model_utils import is_multimodal_model
+from modelopt.torch.export.model_utils import get_language_model_from_vl, is_multimodal_model
 from modelopt.torch.quantization.config import need_calibration
 from modelopt.torch.quantization.plugins.accelerate import init_quantized_weights
 from modelopt.torch.quantization.utils import is_quantized
@@ -317,15 +317,8 @@ def main(args):
         tokenizer.padding_side = "left"

     # We only quantize the language model for VLMs other than the type supported above.
-    if hasattr(model, "language_model"):
-        parent_model = model  # llama4 case
-        if isinstance(type(model).__dict__.get("language_model"), property):
-            assert hasattr(model, "model") and hasattr(model.model, "language_model"), (
-                "Expected language_model in model.model, but attribute not found. "
-                "This may indicate an unsupported model structure."
-            )
-            parent_model = model.model  # gemma3, qwen2.5 VL case
-
+    language_model, parent_model = get_language_model_from_vl(model)
+    if language_model is not None:
         disabled_quant_cfg = {
             "quant_cfg": {"default": {"enable": False}},
             "algorithm": "max",
@@ -336,7 +329,7 @@ def main(args):
            if name != "language_model":
                mtq.quantize(child, disabled_quant_cfg, forward_loop=None)

-        model = model.language_model
+        model = language_model
        model_type = get_model_type(model)

        if model_type == "phi4mm":
@@ -538,9 +531,11 @@ def main(args):
        model = quantize_model(model, quant_cfg, args, calib_dataloader, calibration_only)

        # For VL models, update full_model to use the quantized language model
-        if is_nemotron_vl and hasattr(full_model, "language_model"):
-            print("Updating full_model with quantized language_model...")
-            full_model.language_model = model
+        if is_nemotron_vl:
+            _, parent_model = get_language_model_from_vl(full_model)
+            if parent_model is not None:
+                print("Updating full_model with quantized language_model...")
+                parent_model.language_model = model

    if args.verbose:
        mtq.print_quant_summary(model)
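
Taken together, these hunks collapse the old attribute probing in hf_ptq.py into a single helper call. The sketch below condenses the resulting flow; it is not the script's real argument plumbing — `quantize_vl_language_model` and `calib_loop` are names invented for illustration, while `get_language_model_from_vl`, `mtq.quantize`, and the disabled-quantization config mirror the diff:

import modelopt.torch.quantization as mtq
from modelopt.torch.export.model_utils import get_language_model_from_vl

def quantize_vl_language_model(model, quant_cfg, calib_loop):
    # Locate the language model and its parent, regardless of VLM layout.
    language_model, parent_model = get_language_model_from_vl(model)
    if language_model is None:
        return model  # no separable language model; handle elsewhere

    # Explicitly disable quantization on every sibling component
    # (vision tower, multimodal projector, ...), as the script does.
    disabled_quant_cfg = {
        "quant_cfg": {"default": {"enable": False}},
        "algorithm": "max",
    }
    for name, child in parent_model.named_children():
        if name != "language_model":
            mtq.quantize(child, disabled_quant_cfg, forward_loop=None)

    # Quantize only the language model, then write it back so the
    # full model reflects the quantized component.
    mtq.quantize(language_model, quant_cfg, forward_loop=calib_loop)
    parent_model.language_model = language_model
    return model

Writing the quantized module back through `parent_model` (rather than `full_model.language_model = ...`) is the point of returning the parent from the helper: for property-based wrappers, assignment on the wrapper would fail or be silently shadowed.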

examples/llm_ptq/vlm_utils.py

Lines changed: 2 additions & 0 deletions
@@ -231,3 +231,5 @@ def run_text_only_generation(model, tokenizer, question, generation_config, mode
     except Exception as e:
         print(f"Text-only generation failed: {e}")
         return None
+
+

modelopt/torch/export/model_utils.py

Lines changed: 47 additions & 1 deletion
@@ -60,7 +60,7 @@
     {MODEL_NAME_TO_TYPE=}
 """

-__all__ = ["get_model_type", "is_multimodal_model"]
+__all__ = ["get_model_type", "is_multimodal_model", "get_language_model_from_vl"]


 def get_model_type(model):
@@ -109,3 +109,49 @@ def is_multimodal_model(model):
             hasattr(config, "embd_layer") and hasattr(config.embd_layer, "image_embd_layer")
         )  # Image embedding layers
     )
+
+
+def get_language_model_from_vl(model):
+    """Extract the language model component from a Vision-Language Model (VLM).
+
+    This function handles the common patterns for accessing the language model component
+    in various VLM architectures. It checks multiple possible locations where the
+    language model might be stored.
+
+    Args:
+        model: The VLM model instance to extract the language model from
+
+    Returns:
+        tuple: (language_model, parent_model) where:
+            - language_model: The extracted language model component, or None if not found
+            - parent_model: The parent model containing the language_model attribute
+
+    Examples:
+        >>> # For LLaVA-style models
+        >>> lang_model, parent = get_language_model_from_vl(vlm_model)
+        >>> if lang_model is not None:
+        ...     # Work with the language model component
+        ...     quantized_lang_model = quantize(lang_model)
+        ...     # Update the parent model
+        ...     parent.language_model = quantized_lang_model
+    """
+    # Pattern 1: Direct language_model attribute (e.g., LLaVA, some Nemotron models)
+    if hasattr(model, "language_model"):
+        # Check if it's a property that might need special handling
+        if isinstance(type(model).__dict__.get("language_model"), property):
+            # Some models have language_model as a property that points to model.model.language_model
+            if hasattr(model, "model") and hasattr(model.model, "language_model"):
+                return model.model.language_model, model.model
+            else:
+                # Property exists but no nested structure found
+                return model.language_model, model
+        else:
+            # Direct attribute access
+            return model.language_model, model

+    # Pattern 2: Nested in model.model.language_model (e.g., some Gemma3, Qwen2.5-VL models)
+    elif hasattr(model, "model") and hasattr(model.model, "language_model"):
+        return model.model.language_model, model.model
+
+    # Pattern 3: No language_model found
+    return None, None
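
The property branch is the subtle part: some wrappers expose `language_model` as a read-only property that forwards to `model.language_model`, so any assignment must target the nested parent. A toy illustration of the two layouts the helper distinguishes — the class names here are invented for this sketch, not taken from the repo:

import torch.nn as nn
from modelopt.torch.export.model_utils import get_language_model_from_vl

class _Inner(nn.Module):
    def __init__(self):
        super().__init__()
        self.language_model = nn.Linear(8, 8)

class DirectAttr(nn.Module):
    # Pattern 1: language_model is a plain child module (llama4-style).
    def __init__(self):
        super().__init__()
        self.language_model = nn.Linear(8, 8)

class PropertyWrapper(nn.Module):
    # Property case (gemma3 / qwen2.5-VL-style): language_model forwards
    # to model.language_model, so the assignable parent is the nested module.
    def __init__(self):
        super().__init__()
        self.model = _Inner()

    @property
    def language_model(self):
        return self.model.language_model

lm, parent = get_language_model_from_vl(DirectAttr())
assert lm is parent.language_model  # parent is the wrapper itself

lm, parent = get_language_model_from_vl(PropertyWrapper())
assert isinstance(parent, _Inner)  # assignments go to the nested parent

Note that `type(model).__dict__.get("language_model")` inspects the class, not the instance: plain submodules live in the instance's `_modules`, so only a genuine class-level property triggers the nested lookup.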

modelopt/torch/export/unified_export_hf.py

Lines changed: 5 additions & 14 deletions
@@ -43,6 +43,7 @@
     is_quantlinear,
     set_expert_quantizer_amax,
 )
+from .model_utils import get_language_model_from_vl, is_multimodal_model
 from .model_config import (
     KV_CACHE_FP8,
     KV_CACHE_NVFP4,
@@ -136,11 +137,7 @@ def _output_hook(module, input, output):
         decoder_fake_input = fake_input

     # Check if this is a VL model that needs special input handling
-    is_vl_model = (
-        hasattr(model.config, "vision_config")
-        or hasattr(model, "vision_model")
-        or "nemotron" in getattr(model, "name_or_path", "").lower()
-    )
+    is_vl_model = is_multimodal_model(model)

     if model_type.startswith("whisper"):
         # For Whisper models, we need to pass a fake input with the specific sequence length
@@ -159,16 +156,10 @@ def _output_hook(module, input, output):
        model(fake_input, decoder_input_ids=decoder_fake_input)
    elif is_vl_model:
        # For VL models, try to run optimization on just the language model part
-        language_model = None
-        if hasattr(model, "language_model"):
-            language_model = model.language_model
-            print(
-                "Found language_model attribute - running optimization on language model only"
-            )
-        elif hasattr(model, "model") and hasattr(model.model, "language_model"):
-            language_model = model.model.language_model
+        language_model, _ = get_language_model_from_vl(model)
+        if language_model is not None:
            print(
-                "Found language_model in model.model - running optimization on language model only"
+                "Found language_model component - running optimization on language model only"
            )

        if language_model is not None:
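
In the export path the same helper replaces a second copy of the probing logic, and VL detection now goes through the shared `is_multimodal_model` check instead of the old ad-hoc heuristic (config attributes plus a "nemotron" name match). A condensed sketch of the resulting control flow — `resolve_optimization_target` is a name invented for illustration, not a function in the diff:

from modelopt.torch.export.model_utils import (
    get_language_model_from_vl,
    is_multimodal_model,
)

def resolve_optimization_target(model):
    # Non-multimodal models are optimized as a whole.
    if not is_multimodal_model(model):
        return model
    # For VL models, narrow to the language model when one can be found.
    language_model, _ = get_language_model_from_vl(model)
    if language_model is not None:
        print("Found language_model component - running optimization on language model only")
        return language_model
    # VL model without a separable language model: fall back to the full model.
    return model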
