Commit 1fd1154

minor, add a util function
Signed-off-by: Zhiyu Cheng <[email protected]>
1 parent fb5276b commit 1fd1154

2 files changed: +44 additions, -13 deletions

examples/llm_ptq/hf_ptq.py

Lines changed: 3 additions & 12 deletions
@@ -41,6 +41,7 @@
     export_tensorrt_llm_checkpoint,
     get_model_type,
 )
+from modelopt.torch.export.model_utils import is_multimodal_model
 from modelopt.torch.quantization.config import need_calibration
 from modelopt.torch.quantization.plugins.accelerate import init_quantized_weights
 from modelopt.torch.quantization.utils import is_quantized
@@ -569,18 +570,8 @@ def output_decode(generated_ids, input_shape):

     export_path = args.export_path

-    # Check for VLMs by looking for various multimodal indicators in model config
-    config = full_model.config
-    is_vlm = (
-        hasattr(config, "vision_config")  # Standard vision config (e.g., Qwen2.5-VL)
-        or hasattr(full_model, "language_model")  # Language model attribute (e.g., LLaVA)
-        or getattr(config, "model_type", "") == "phi4mm"  # Phi-4 multimodal
-        or hasattr(config, "vision_lora")  # Vision LoRA configurations
-        or hasattr(config, "audio_processor")  # Audio processing capabilities
-        or (
-            hasattr(config, "embd_layer") and hasattr(config.embd_layer, "image_embd_layer")
-        )  # Image embedding layers
-    )
+    # Check if the model is a multimodal/VLM model
+    is_vlm = is_multimodal_model(full_model)

     if is_vlm:
         # Save original model config and the processor config to the export path for VLMs.

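For context, the "if is_vlm:" branch that follows this hunk saves the original model config and the processor config next to the exported checkpoint. A minimal, hypothetical sketch of that pattern built on the new helper (the function name save_vlm_export_assets and the surrounding flow are illustrative, not part of this commit):

from transformers import AutoConfig, AutoProcessor

from modelopt.torch.export.model_utils import is_multimodal_model


def save_vlm_export_assets(full_model, model_name_or_path, export_path):
    # Hypothetical sketch: persist the original config and processor config
    # alongside the quantized checkpoint so an exported VLM stays loadable.
    if not is_multimodal_model(full_model):
        return  # Text-only models need no extra assets.
    AutoConfig.from_pretrained(model_name_or_path).save_pretrained(export_path)
    AutoProcessor.from_pretrained(model_name_or_path).save_pretrained(export_path)
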
modelopt/torch/export/model_utils.py

Lines changed: 41 additions & 1 deletion
@@ -60,7 +60,7 @@
 {MODEL_NAME_TO_TYPE=}
 """

-__all__ = ["get_model_type"]
+__all__ = ["get_model_type", "is_multimodal_model"]


 def get_model_type(model):
@@ -69,3 +69,43 @@ def get_model_type(model):
         if k.lower() in type(model).__name__.lower():
             return v
     return None
+
+
+def is_multimodal_model(model):
+    """Check if a model is a Vision-Language Model (VLM) or multimodal model.
+
+    This function detects various multimodal model architectures by checking for:
+    - Standard vision configurations (vision_config)
+    - Language model attributes (language_model)
+    - Specific multimodal model types (phi4mm)
+    - Vision LoRA configurations
+    - Audio processing capabilities
+    - Image embedding layers
+
+    Args:
+        model: The HuggingFace model instance to check
+
+    Returns:
+        bool: True if the model is detected as multimodal, False otherwise
+
+    Examples:
+        >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
+        >>> is_multimodal_model(model)
+        True
+
+        >>> model = AutoModelForCausalLM.from_pretrained("microsoft/Phi-4-multimodal-instruct")
+        >>> is_multimodal_model(model)
+        True
+    """
+    config = model.config
+
+    return (
+        hasattr(config, "vision_config")  # Standard vision config (e.g., Qwen2.5-VL)
+        or hasattr(model, "language_model")  # Language model attribute (e.g., LLaVA)
+        or getattr(config, "model_type", "") == "phi4mm"  # Phi-4 multimodal
+        or hasattr(config, "vision_lora")  # Vision LoRA configurations
+        or hasattr(config, "audio_processor")  # Audio processing capabilities
+        or (
+            hasattr(config, "embd_layer") and hasattr(config.embd_layer, "image_embd_layer")
+        )  # Image embedding layers
+    )
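
A quick way to sanity-check the new helper without downloading any checkpoints is to call it on lightweight stand-in objects. A minimal sketch, assuming modelopt is installed; the SimpleNamespace dummies below are illustrative substitutes for real HuggingFace models and are not part of this commit:

from types import SimpleNamespace

from modelopt.torch.export.model_utils import is_multimodal_model

# Dummy stand-ins that only mimic the attributes the helper inspects.
text_only = SimpleNamespace(config=SimpleNamespace(model_type="llama"))
vlm_like = SimpleNamespace(config=SimpleNamespace(model_type="qwen2_5_vl", vision_config=object()))

print(is_multimodal_model(text_only))  # False: no multimodal indicators found
print(is_multimodal_model(vlm_like))  # True: the config exposes a vision_config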
