
Commit 78008ac

Make vlm detection more robust in ptq workflow (#286)

Signed-off-by: Zhiyu Cheng <[email protected]>

1 parent: 29b4cf2

File tree

2 files changed: +58 −8 lines


examples/llm_ptq/hf_ptq.py

Lines changed: 17 additions & 7 deletions

@@ -25,7 +25,9 @@
 from accelerate.hooks import remove_hook_from_module
 from example_utils import apply_kv_cache_quant, get_model, get_processor, get_tokenizer, is_enc_dec
 from transformers import (
+    AutoConfig,
     AutoModelForCausalLM,
+    AutoProcessor,
     PreTrainedTokenizer,
     PreTrainedTokenizerFast,
     WhisperProcessor,

@@ -39,6 +41,7 @@
     export_tensorrt_llm_checkpoint,
     get_model_type,
 )
+from modelopt.torch.export.model_utils import is_multimodal_model
 from modelopt.torch.quantization.config import need_calibration
 from modelopt.torch.quantization.plugins.accelerate import init_quantized_weights
 from modelopt.torch.quantization.utils import is_quantized

@@ -567,19 +570,26 @@ def output_decode(generated_ids, input_shape):
 
     export_path = args.export_path
 
-    if hasattr(full_model, "language_model"):
-        # Save original model config and the preprocessor config to the export path for VLMs.
-        from transformers import AutoConfig, AutoProcessor
+    # Check if the model is a multimodal/VLM model
+    is_vlm = is_multimodal_model(full_model)
 
-        print(f"Saving original model and processor configs to {export_path}")
+    if is_vlm:
+        # Save original model config and the processor config to the export path for VLMs.
+        print(f"Saving original model config to {export_path}")
 
         AutoConfig.from_pretrained(
             args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code
         ).save_pretrained(export_path)
 
-        AutoProcessor.from_pretrained(
-            args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code
-        ).save_pretrained(export_path)
+        # Try to save processor config if available
+        try:
+            print(f"Saving processor config to {export_path}")
+            AutoProcessor.from_pretrained(
+                args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code
+            ).save_pretrained(export_path)
+        except Exception as e:
+            print(f"Warning: Could not save processor config: {e}")
+            print("This is normal for some VLM architectures that don't use AutoProcessor")
 
     if model_type == "mllama":
         full_model_config = model.config
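
For reference, the patched export flow reads more easily outside the diff. The sketch below mirrors the new logic as a standalone function; save_vlm_configs and its parameters are hypothetical stand-ins for the script's argparse values (args.pyt_ckpt_path, args.trust_remote_code, args.export_path), not part of the commit.

# Minimal sketch of the patched export flow (not the full hf_ptq.py script).
# `ckpt_path`, `export_path`, and `trust_remote_code` stand in for the
# argparse values used in the real script.
from transformers import AutoConfig, AutoProcessor

from modelopt.torch.export.model_utils import is_multimodal_model


def save_vlm_configs(full_model, ckpt_path, export_path, trust_remote_code=False):
    """Save the original model config, and the processor config when one exists."""
    if not is_multimodal_model(full_model):
        return

    # The original model config is always saved for VLMs.
    print(f"Saving original model config to {export_path}")
    AutoConfig.from_pretrained(
        ckpt_path, trust_remote_code=trust_remote_code
    ).save_pretrained(export_path)

    # The processor save is best-effort: as the patch notes, some VLM
    # architectures don't use AutoProcessor, and that should not abort export.
    try:
        print(f"Saving processor config to {export_path}")
        AutoProcessor.from_pretrained(
            ckpt_path, trust_remote_code=trust_remote_code
        ).save_pretrained(export_path)
    except Exception as e:
        print(f"Warning: Could not save processor config: {e}")

The design choice is that a missing processor is downgraded from a hard failure to a warning, so quantization results are still exported for architectures without an AutoProcessor mapping.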

modelopt/torch/export/model_utils.py

Lines changed: 41 additions & 1 deletion

@@ -60,7 +60,7 @@
 {MODEL_NAME_TO_TYPE=}
 """
 
-__all__ = ["get_model_type"]
+__all__ = ["get_model_type", "is_multimodal_model"]
 
 
 def get_model_type(model):

@@ -69,3 +69,43 @@ def get_model_type(model):
         if k.lower() in type(model).__name__.lower():
             return v
     return None
+
+
+def is_multimodal_model(model):
+    """Check if a model is a Vision-Language Model (VLM) or multimodal model.
+
+    This function detects various multimodal model architectures by checking for:
+    - Standard vision configurations (vision_config)
+    - Language model attributes (language_model)
+    - Specific multimodal model types (phi4mm)
+    - Vision LoRA configurations
+    - Audio processing capabilities
+    - Image embedding layers
+
+    Args:
+        model: The HuggingFace model instance to check
+
+    Returns:
+        bool: True if the model is detected as multimodal, False otherwise
+
+    Examples:
+        >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
+        >>> is_multimodal_model(model)
+        True
+
+        >>> model = AutoModelForCausalLM.from_pretrained("microsoft/Phi-4-multimodal-instruct")
+        >>> is_multimodal_model(model)
+        True
+    """
+    config = model.config
+
+    return (
+        hasattr(config, "vision_config")  # Standard vision config (e.g., Qwen2.5-VL)
+        or hasattr(model, "language_model")  # Language model attribute (e.g., LLaVA)
+        or getattr(config, "model_type", "") == "phi4mm"  # Phi-4 multimodal
+        or hasattr(config, "vision_lora")  # Vision LoRA configurations
+        or hasattr(config, "audio_processor")  # Audio processing capabilities
+        or (
+            hasattr(config, "embd_layer") and hasattr(config.embd_layer, "image_embd_layer")
+        )  # Image embedding layers
+    )
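
The docstring examples above exercise the function against real checkpoints; a lighter way to sanity-check each heuristic is with dummy objects that carry only the inspected attributes. The SimpleNamespace stand-ins below are illustrative and not part of the commit.

# Exercise is_multimodal_model with dummy objects instead of real checkpoints.
# Only the attributes the function actually probes need to exist.
from types import SimpleNamespace

from modelopt.torch.export.model_utils import is_multimodal_model

# Qwen2.5-VL-style model: detected via the nested vision_config attribute.
vlm = SimpleNamespace(config=SimpleNamespace(vision_config=SimpleNamespace()))
print(is_multimodal_model(vlm))  # True

# Phi-4 multimodal: matched purely by its model_type string.
phi4 = SimpleNamespace(config=SimpleNamespace(model_type="phi4mm"))
print(is_multimodal_model(phi4))  # True

# Text-only config: none of the heuristics fire.
llm = SimpleNamespace(config=SimpleNamespace(model_type="llama"))
print(is_multimodal_model(llm))  # False

Because every check is a cheap hasattr/getattr probe on the config or model object, the detection also stays robust across transformers versions that rename wrapper classes, which is what the previous hasattr(full_model, "language_model") check alone could not guarantee.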
