|
60 | 60 | {MODEL_NAME_TO_TYPE=} |
61 | 61 | """ |
62 | 62 |
|
63 | | -__all__ = ["get_model_type", "is_multimodal_model"] |
| 63 | +__all__ = ["get_model_type", "is_multimodal_model", "get_language_model_from_vl"] |
64 | 64 |
|
65 | 65 |
|
66 | 66 | def get_model_type(model): |
@@ -109,3 +109,49 @@ def is_multimodal_model(model): |
109 | 109 | hasattr(config, "embd_layer") and hasattr(config.embd_layer, "image_embd_layer") |
110 | 110 | ) # Image embedding layers |
111 | 111 | ) |
| 112 | + |
| 113 | + |
def _nested_language_model_owner(model):
    """Return ``model.model`` if it exposes ``language_model``, otherwise ``None``."""
    inner = getattr(model, "model", None)
    if inner is not None and hasattr(inner, "language_model"):
        return inner
    return None


def get_language_model_from_vl(model):
    """Extract the language model component from a Vision-Language Model (VLM).

    Two locations are probed, in this order:

    1. ``model.language_model``
    2. ``model.model.language_model``

    When ``language_model`` is a ``property`` on the model's class (including one
    inherited from a base class) and ``model.model.language_model`` also exists,
    the nested object is returned together with ``model.model`` as its parent.
    The property is then treated as a forwarding alias, so the returned parent is
    the object that actually stores the attribute and can be assigned to.

    Args:
        model: The VLM model instance to extract the language model from.

    Returns:
        tuple: ``(language_model, parent_model)`` where:
            - language_model: The extracted language model component, or ``None``
              if neither location exists.
            - parent_model: The object on which ``language_model`` was found, or
              ``None`` if nothing was found.

    Examples:
        >>> lang_model, parent = get_language_model_from_vl(vlm_model)
        >>> if lang_model is not None:
        ...     # Work with the language model component
        ...     quantized_lang_model = quantize(lang_model)
        ...     # Update the parent model
        ...     parent.language_model = quantized_lang_model
    """
    # Pattern 1: ``language_model`` reachable directly on the model.
    if hasattr(model, "language_model"):
        # Look the name up along the whole MRO: checking only
        # ``type(model).__dict__`` misses a property declared on a base class,
        # which would hand callers a parent whose attribute cannot be reassigned.
        class_attr = next(
            (
                vars(klass)["language_model"]
                for klass in type(model).__mro__
                if "language_model" in vars(klass)
            ),
            None,
        )
        if isinstance(class_attr, property):
            # A property may just forward to model.model.language_model; prefer
            # the object that really holds the attribute when it exists.
            owner = _nested_language_model_owner(model)
            if owner is not None:
                return owner.language_model, owner
        # Plain attribute, or a property with no nested structure behind it.
        return model.language_model, model

    # Pattern 2: only nested as model.model.language_model.
    owner = _nested_language_model_owner(model)
    if owner is not None:
        return owner.language_model, owner

    # Pattern 3: no language_model found.
    return None, None
0 commit comments