diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py
index 25052b61a..4657c0f32 100755
--- a/examples/llm_ptq/hf_ptq.py
+++ b/examples/llm_ptq/hf_ptq.py
@@ -25,7 +25,9 @@
 from accelerate.hooks import remove_hook_from_module
 from example_utils import apply_kv_cache_quant, get_model, get_processor, get_tokenizer, is_enc_dec
 from transformers import (
+    AutoConfig,
     AutoModelForCausalLM,
+    AutoProcessor,
     PreTrainedTokenizer,
     PreTrainedTokenizerFast,
     WhisperProcessor,
@@ -39,6 +41,7 @@
     export_tensorrt_llm_checkpoint,
     get_model_type,
 )
+from modelopt.torch.export.model_utils import is_multimodal_model
 from modelopt.torch.quantization.config import need_calibration
 from modelopt.torch.quantization.plugins.accelerate import init_quantized_weights
 from modelopt.torch.quantization.utils import is_quantized
@@ -567,19 +570,26 @@ def output_decode(generated_ids, input_shape):
 
         export_path = args.export_path
 
-        if hasattr(full_model, "language_model"):
-            # Save original model config and the preprocessor config to the export path for VLMs.
-            from transformers import AutoConfig, AutoProcessor
+        # Check if the model is a multimodal/VLM model
+        is_vlm = is_multimodal_model(full_model)
 
-            print(f"Saving original model and processor configs to {export_path}")
+        if is_vlm:
+            # Save original model config and the processor config to the export path for VLMs.
+            print(f"Saving original model config to {export_path}")
 
             AutoConfig.from_pretrained(
                 args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code
             ).save_pretrained(export_path)
 
-            AutoProcessor.from_pretrained(
-                args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code
-            ).save_pretrained(export_path)
+            # Try to save processor config if available
+            try:
+                print(f"Saving processor config to {export_path}")
+                AutoProcessor.from_pretrained(
+                    args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code
+                ).save_pretrained(export_path)
+            except Exception as e:
+                print(f"Warning: Could not save processor config: {e}")
+                print("This is normal for some VLM architectures that don't use AutoProcessor")
 
         if model_type == "mllama":
             full_model_config = model.config
diff --git a/modelopt/torch/export/model_utils.py b/modelopt/torch/export/model_utils.py
index 4bd4762e7..5ce630168 100755
--- a/modelopt/torch/export/model_utils.py
+++ b/modelopt/torch/export/model_utils.py
@@ -60,7 +60,7 @@
 {MODEL_NAME_TO_TYPE=}
 """
 
-__all__ = ["get_model_type"]
+__all__ = ["get_model_type", "is_multimodal_model"]
 
 
 def get_model_type(model):
@@ -69,3 +69,43 @@ def get_model_type(model):
         if k.lower() in type(model).__name__.lower():
             return v
     return None
+
+
+def is_multimodal_model(model):
+    """Check if a model is a Vision-Language Model (VLM) or multimodal model.
+
+    This function detects various multimodal model architectures by checking for:
+    - Standard vision configurations (vision_config)
+    - Language model attributes (language_model)
+    - Specific multimodal model types (phi4mm)
+    - Vision LoRA configurations
+    - Audio processing capabilities
+    - Image embedding layers
+
+    Args:
+        model: The HuggingFace model instance to check
+
+    Returns:
+        bool: True if the model is detected as multimodal, False otherwise
+
+    Examples:
+        >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
+        >>> is_multimodal_model(model)
+        True
+
+        >>> model = AutoModelForCausalLM.from_pretrained("microsoft/Phi-4-multimodal-instruct")
+        >>> is_multimodal_model(model)
+        True
+    """
+    config = model.config
+
+    return (
+        hasattr(config, "vision_config")  # Standard vision config (e.g., Qwen2.5-VL)
+        or hasattr(model, "language_model")  # Language model attribute (e.g., LLaVA)
+        or getattr(config, "model_type", "") == "phi4mm"  # Phi-4 multimodal
+        or hasattr(config, "vision_lora")  # Vision LoRA configurations
+        or hasattr(config, "audio_processor")  # Audio processing capabilities
+        or (
+            hasattr(config, "embd_layer") and hasattr(config.embd_layer, "image_embd_layer")
+        )  # Image embedding layers
+    )
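For reference, a minimal sketch of how the new `is_multimodal_model` heuristic can be exercised without loading any checkpoints. The `SimpleNamespace` stand-ins below are illustrative only and not part of this change; they just mimic the config attributes the helper inspects:

```python
# Illustrative only: lightweight stand-ins for model objects, mimicking the
# attributes that is_multimodal_model() checks. No checkpoints are loaded.
from types import SimpleNamespace

from modelopt.torch.export.model_utils import is_multimodal_model

# Config exposing vision_config (e.g. Qwen2.5-VL-style models).
vlm_like = SimpleNamespace(config=SimpleNamespace(vision_config=object()))

# Phi-4-multimodal-style config, detected via model_type == "phi4mm".
phi4mm_like = SimpleNamespace(config=SimpleNamespace(model_type="phi4mm"))

# Plain text-only config with none of the multimodal markers.
text_only = SimpleNamespace(config=SimpleNamespace(model_type="llama"))

print(is_multimodal_model(vlm_like))     # True  (vision_config present)
print(is_multimodal_model(phi4mm_like))  # True  (model_type == "phi4mm")
print(is_multimodal_model(text_only))    # False (no multimodal markers)
```

Since the helper only reads attributes from `model.config` (plus the `language_model` attribute on the model itself), it can be smoke-tested cheaply like this; the doctest examples in the function show the intended usage with real HuggingFace checkpoints.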