13 | 13 |
14 | 14 | import torch
15 | 15 | from lightning.fabric.utilities.load import _NotYetLoadedTensor as NotYetLoadedTensor
   | 16 | +from lightning_utilities.core.imports import RequirementCache
16 | 17 | from safetensors.torch import load_file as load_safetensors
17 | 18 | from tqdm import tqdm
18 | 19 |
@@ -286,11 +287,16 @@ def copy_weights_gemma_2(
286 | 287 |                 pbar.update(progress_per_file)
287 | 288 |
288 | 289 |
289 |     | -GEMMA3_LANGUAGE_MODEL_PREFIX = "model.language_model"
    | 290 | +_TRANSFORMERS_GREATER_EQUAL_4_52 = RequirementCache("transformers>=4.52.0")
290 | 291 |
291 |     | -GEMMA3_VISION_MODEL_PREFIX = "model.vision_tower"
    | 292 | +GEMMA3_LANGUAGE_MODEL_PREFIX = "model.language_model" if _TRANSFORMERS_GREATER_EQUAL_4_52 else "language_model.model"
    | 293 | +
    | 294 | +GEMMA3_VISION_MODEL_PREFIX = "model.vision_tower" if _TRANSFORMERS_GREATER_EQUAL_4_52 else "vision_tower"
    | 295 | +
    | 296 | +GEMMA3_MM_PROJECTOR_PREFIX = (
    | 297 | +    "model.multi_modal_projector" if _TRANSFORMERS_GREATER_EQUAL_4_52 else "multi_modal_projector"
    | 298 | +)
292 | 299 |
293 |     | -GEMMA3_MM_PROJECTOR_PREFIX = "model.multi_modal_projector"
294 | 300 |
295 | 301 | def copy_weights_gemma_3(
296 | 302 |     qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]],
@@ -325,15 +331,14 @@ def copy_weights_gemma_3(
325 | 331 |     if progress_per_file is not None:
326 | 332 |         progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights))
327 | 333 |     # gemma3 4b+ are multimodel models, but we are only loading the text weights
328 |     | -    is_multimodal = any(k.startswith(GEMMA3_VISION_MODEL_PREFIX) for k in hf_weights)
    | 334 | +    is_multimodal = any(k.startswith(GEMMA3_LANGUAGE_MODEL_PREFIX) for k in hf_weights)
329 | 335 |     if is_multimodal:
330 | 336 |         warnings.warn("For Gemma3 models only the text component is supported.")
331 | 337 |         new_weight_map = dict()
332 |     | -        prefix = "model."
333 |     | -        len_prefix = len(prefix)
    | 338 | +        prefix = "model"
334 | 339 |         for k, v in weight_map.items():
335 | 340 |             if k.startswith(prefix):
336 |     | -                k = "model.language_model." + k[len_prefix:]
    | 341 | +                k = GEMMA3_LANGUAGE_MODEL_PREFIX + k[len(prefix) :]
337 | 342 |             new_weight_map[k] = v
338 | 343 |         weight_map = new_weight_map
339 | 344 |     for from_name, param in hf_weights.items():