From c0d16612fe76100a22d669feea89ab67dfe75bcc Mon Sep 17 00:00:00 2001 From: Zhiyu Cheng Date: Thu, 4 Sep 2025 18:27:47 +0000 Subject: [PATCH 1/3] make vlm detection more robust in ptq workflow Signed-off-by: Zhiyu Cheng --- examples/llm_ptq/hf_ptq.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 25052b61a..7f87dc0b1 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -25,7 +25,9 @@ from accelerate.hooks import remove_hook_from_module from example_utils import apply_kv_cache_quant, get_model, get_processor, get_tokenizer, is_enc_dec from transformers import ( + AutoConfig, AutoModelForCausalLM, + AutoProcessor, PreTrainedTokenizer, PreTrainedTokenizerFast, WhisperProcessor, @@ -567,9 +569,18 @@ def output_decode(generated_ids, input_shape): export_path = args.export_path - if hasattr(full_model, "language_model"): + # Check for VLMs by looking for vision_config in model config or language_model attribute + is_vlm = False + try: + is_vlm = hasattr(full_model.config, "vision_config") or hasattr( + full_model, "language_model" + ) + except Exception: + # Fallback to the original check if config access fails + is_vlm = hasattr(full_model, "language_model") + + if is_vlm: # Save original model config and the preprocessor config to the export path for VLMs. - from transformers import AutoConfig, AutoProcessor print(f"Saving original model and processor configs to {export_path}") From 3f3e23f7d8994e44cfed6802fe3e32563a1db31e Mon Sep 17 00:00:00 2001 From: Zhiyu Cheng Date: Thu, 4 Sep 2025 20:54:45 +0000 Subject: [PATCH 2/3] make vlm detection more robust in ptq workflow Signed-off-by: Zhiyu Cheng --- examples/llm_ptq/hf_ptq.py | 38 +++++++++++++++++++++++--------------- 1 file changed, 23 insertions(+), 15 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 7f87dc0b1..5501e97aa 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -569,28 +569,36 @@ def output_decode(generated_ids, input_shape): export_path = args.export_path - # Check for VLMs by looking for vision_config in model config or language_model attribute - is_vlm = False - try: - is_vlm = hasattr(full_model.config, "vision_config") or hasattr( - full_model, "language_model" - ) - except Exception: - # Fallback to the original check if config access fails - is_vlm = hasattr(full_model, "language_model") + # Check for VLMs by looking for various multimodal indicators in model config + config = full_model.config + is_vlm = ( + hasattr(config, "vision_config") # Standard vision config (e.g., Qwen2.5-VL) + or hasattr(full_model, "language_model") # Language model attribute (e.g., LLaVA) + or getattr(config, "model_type", "") == "phi4mm" # Phi-4 multimodal + or hasattr(config, "vision_lora") # Vision LoRA configurations + or hasattr(config, "audio_processor") # Audio processing capabilities + or ( + hasattr(config, "embd_layer") and hasattr(config.embd_layer, "image_embd_layer") + ) # Image embedding layers + ) if is_vlm: - # Save original model config and the preprocessor config to the export path for VLMs. - - print(f"Saving original model and processor configs to {export_path}") + # Save original model config and the processor config to the export path for VLMs. 
+ print(f"Saving original model config to {export_path}") AutoConfig.from_pretrained( args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code ).save_pretrained(export_path) - AutoProcessor.from_pretrained( - args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code - ).save_pretrained(export_path) + # Try to save processor config if available + try: + print(f"Saving processor config to {export_path}") + AutoProcessor.from_pretrained( + args.pyt_ckpt_path, trust_remote_code=args.trust_remote_code + ).save_pretrained(export_path) + except Exception as e: + print(f"Warning: Could not save processor config: {e}") + print("This is normal for some VLM architectures that don't use AutoProcessor") if model_type == "mllama": full_model_config = model.config From 2e75ee38c70cb4688586f41a3b66f2506fe6eea2 Mon Sep 17 00:00:00 2001 From: Zhiyu Cheng Date: Fri, 5 Sep 2025 00:21:50 +0000 Subject: [PATCH 3/3] minor, add a util function Signed-off-by: Zhiyu Cheng --- examples/llm_ptq/hf_ptq.py | 15 ++-------- modelopt/torch/export/model_utils.py | 42 +++++++++++++++++++++++++++- 2 files changed, 44 insertions(+), 13 deletions(-) diff --git a/examples/llm_ptq/hf_ptq.py b/examples/llm_ptq/hf_ptq.py index 5501e97aa..4657c0f32 100755 --- a/examples/llm_ptq/hf_ptq.py +++ b/examples/llm_ptq/hf_ptq.py @@ -41,6 +41,7 @@ export_tensorrt_llm_checkpoint, get_model_type, ) +from modelopt.torch.export.model_utils import is_multimodal_model from modelopt.torch.quantization.config import need_calibration from modelopt.torch.quantization.plugins.accelerate import init_quantized_weights from modelopt.torch.quantization.utils import is_quantized @@ -569,18 +570,8 @@ def output_decode(generated_ids, input_shape): export_path = args.export_path - # Check for VLMs by looking for various multimodal indicators in model config - config = full_model.config - is_vlm = ( - hasattr(config, "vision_config") # Standard vision config (e.g., Qwen2.5-VL) - or hasattr(full_model, "language_model") # Language model attribute (e.g., LLaVA) - or getattr(config, "model_type", "") == "phi4mm" # Phi-4 multimodal - or hasattr(config, "vision_lora") # Vision LoRA configurations - or hasattr(config, "audio_processor") # Audio processing capabilities - or ( - hasattr(config, "embd_layer") and hasattr(config.embd_layer, "image_embd_layer") - ) # Image embedding layers - ) + # Check if the model is a multimodal/VLM model + is_vlm = is_multimodal_model(full_model) if is_vlm: # Save original model config and the processor config to the export path for VLMs. diff --git a/modelopt/torch/export/model_utils.py b/modelopt/torch/export/model_utils.py index 4bd4762e7..5ce630168 100755 --- a/modelopt/torch/export/model_utils.py +++ b/modelopt/torch/export/model_utils.py @@ -60,7 +60,7 @@ {MODEL_NAME_TO_TYPE=} """ -__all__ = ["get_model_type"] +__all__ = ["get_model_type", "is_multimodal_model"] def get_model_type(model): @@ -69,3 +69,43 @@ def get_model_type(model): if k.lower() in type(model).__name__.lower(): return v return None + + +def is_multimodal_model(model): + """Check if a model is a Vision-Language Model (VLM) or multimodal model. 
+ + This function detects various multimodal model architectures by checking for: + - Standard vision configurations (vision_config) + - Language model attributes (language_model) + - Specific multimodal model types (phi4mm) + - Vision LoRA configurations + - Audio processing capabilities + - Image embedding layers + + Args: + model: The HuggingFace model instance to check + + Returns: + bool: True if the model is detected as multimodal, False otherwise + + Examples: + >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct") + >>> is_multimodal_model(model) + True + + >>> model = AutoModelForCausalLM.from_pretrained("microsoft/Phi-4-multimodal-instruct") + >>> is_multimodal_model(model) + True + """ + config = model.config + + return ( + hasattr(config, "vision_config") # Standard vision config (e.g., Qwen2.5-VL) + or hasattr(model, "language_model") # Language model attribute (e.g., LLaVA) + or getattr(config, "model_type", "") == "phi4mm" # Phi-4 multimodal + or hasattr(config, "vision_lora") # Vision LoRA configurations + or hasattr(config, "audio_processor") # Audio processing capabilities + or ( + hasattr(config, "embd_layer") and hasattr(config.embd_layer, "image_embd_layer") + ) # Image embedding layers + )
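
A minimal usage sketch follows (editorial illustration, not part of the patches above). It exercises the new is_multimodal_model helper from modelopt.torch.export.model_utils using lightweight SimpleNamespace stand-ins instead of real HuggingFace checkpoints, so no model download is required; the stand-in objects and their attribute values are assumptions chosen only to trigger or miss the detection heuristics.

    # Illustrative sketch (assumes modelopt is installed with the patch above applied).
    from types import SimpleNamespace

    from modelopt.torch.export.model_utils import is_multimodal_model

    # Stand-in "model" whose config carries a vision_config, mimicking a VLM such as Qwen2.5-VL.
    vlm_like = SimpleNamespace(config=SimpleNamespace(vision_config={"hidden_size": 1024}))

    # Stand-in text-only "model" with none of the multimodal indicators.
    llm_like = SimpleNamespace(config=SimpleNamespace(model_type="llama"))

    print(is_multimodal_model(vlm_like))  # True: vision_config is present
    print(is_multimodal_model(llm_like))  # False: no multimodal indicators found

Because the helper is a chain of or-ed attribute checks, any single indicator (vision_config, a language_model attribute, the phi4mm model type, vision_lora, audio_processor, or an image_embd_layer under embd_layer) is enough to classify the model as multimodal; only a model with none of them is treated as text-only.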