Merged
24 changes: 17 additions & 7 deletions examples/llm_ptq/hf_ptq.py
@@ -25,7 +25,9 @@
 from accelerate.hooks import remove_hook_from_module
 from example_utils import apply_kv_cache_quant, get_model, get_processor, get_tokenizer, is_enc_dec
 from transformers import (
+    AutoConfig,
     AutoModelForCausalLM,
+    AutoProcessor,
     PreTrainedTokenizer,
     PreTrainedTokenizerFast,
     WhisperProcessor,
@@ -39,6 +41,7 @@
     export_tensorrt_llm_checkpoint,
     get_model_type,
 )
+from modelopt.torch.export.model_utils import is_multimodal_model
 from modelopt.torch.quantization.config import need_calibration
 from modelopt.torch.quantization.plugins.accelerate import init_quantized_weights
 from modelopt.torch.quantization.utils import is_quantized
@@ -567,19 +570,26 @@ def output_decode(generated_ids, input_shape):

     export_path = args.export_path
 
-    if hasattr(full_model, "language_model"):
-        # Save original model config and the preprocessor config to the export path for VLMs.
-        from transformers import AutoConfig, AutoProcessor
+    # Check if the model is a multimodal/VLM model
+    is_vlm = is_multimodal_model(full_model)
 
-        print(f"Saving original model and processor configs to {export_path}")
+    if is_vlm:
+        # Save original model config and the processor config to the export path for VLMs.
+        print(f"Saving original model config to {export_path}")
 
         AutoConfig.from_pretrained(
             args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code
         ).save_pretrained(export_path)
 
-        AutoProcessor.from_pretrained(
-            args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code
-        ).save_pretrained(export_path)
+        # Try to save processor config if available
+        try:
+            print(f"Saving processor config to {export_path}")
+            AutoProcessor.from_pretrained(
+                args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code
+            ).save_pretrained(export_path)
+        except Exception as e:
+            print(f"Warning: Could not save processor config: {e}")
+            print("This is normal for some VLM architectures that don't use AutoProcessor")
 
     if model_type == "mllama":
         full_model_config = model.config
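Taken together, the hf_ptq.py change reduces to the pattern below. This is a minimal sketch, not the script's full export logic: the helper name `save_vlm_configs` is hypothetical, and `full_model`, `args` (with `pyt_ckpt_path` and `trust_remote_code`), and `export_path` stand in for the script's own variables.

```python
from transformers import AutoConfig, AutoProcessor

from modelopt.torch.export.model_utils import is_multimodal_model


def save_vlm_configs(full_model, args, export_path):
    """Sketch of the new export logic: preserve original configs for VLMs."""
    if not is_multimodal_model(full_model):
        return

    # The model config is expected to exist for any HF checkpoint, so this is unguarded.
    AutoConfig.from_pretrained(
        args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code
    ).save_pretrained(export_path)

    # The processor is optional: some VLM architectures don't ship an AutoProcessor,
    # so a failure here is downgraded to a warning instead of aborting the export.
    try:
        AutoProcessor.from_pretrained(
            args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code
        ).save_pretrained(export_path)
    except Exception as e:
        print(f"Warning: Could not save processor config: {e}")
```

The try/except is the substance of the change: previously a VLM without an AutoProcessor would crash the export; now the processor save is best-effort.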
42 changes: 41 additions & 1 deletion modelopt/torch/export/model_utils.py
@@ -60,7 +60,7 @@
 {MODEL_NAME_TO_TYPE=}
 """
 
-__all__ = ["get_model_type"]
+__all__ = ["get_model_type", "is_multimodal_model"]
 
 
 def get_model_type(model):
@@ -69,3 +69,43 @@ def get_model_type(model):
         if k.lower() in type(model).__name__.lower():
             return v
     return None
+
+
+def is_multimodal_model(model):
+    """Check if a model is a Vision-Language Model (VLM) or multimodal model.
+
+    This function detects various multimodal model architectures by checking for:
+    - Standard vision configurations (vision_config)
+    - Language model attributes (language_model)
+    - Specific multimodal model types (phi4mm)
+    - Vision LoRA configurations
+    - Audio processing capabilities
+    - Image embedding layers
+
+    Args:
+        model: The HuggingFace model instance to check
+
+    Returns:
+        bool: True if the model is detected as multimodal, False otherwise
+
+    Examples:
+        >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
+        >>> is_multimodal_model(model)
+        True
+
+        >>> model = AutoModelForCausalLM.from_pretrained("microsoft/Phi-4-multimodal-instruct")
+        >>> is_multimodal_model(model)
+        True
+    """
+    config = model.config
+
+    return (
+        hasattr(config, "vision_config")  # Standard vision config (e.g., Qwen2.5-VL)
+        or hasattr(model, "language_model")  # Language model attribute (e.g., LLaVA)
+        or getattr(config, "model_type", "") == "phi4mm"  # Phi-4 multimodal
+        or hasattr(config, "vision_lora")  # Vision LoRA configurations
+        or hasattr(config, "audio_processor")  # Audio processing capabilities
+        or (
+            hasattr(config, "embd_layer") and hasattr(config.embd_layer, "image_embd_layer")
+        )  # Image embedding layers
+    )
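For reference, a hedged usage sketch of the new helper, reusing the docstring's own example checkpoint (the loading kwargs are illustrative, not prescribed by the PR):

```python
from transformers import AutoModelForCausalLM

from modelopt.torch.export.model_utils import is_multimodal_model

# Load any HF checkpoint; the Qwen2.5-VL name below comes from the docstring example.
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-VL-7B-Instruct", trust_remote_code=True
)

# True for VLM/multimodal architectures, False for text-only models.
print(is_multimodal_model(model))
```

Because the check is a disjunction of attribute probes rather than a model-type allowlist, it covers architectures (LLaVA-style `language_model` wrappers, Phi-4's `embd_layer`, vision-LoRA and audio variants) without needing a per-model registry entry.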