@@ -13,6 +13,7 @@
 
 import torch
 from lightning.fabric.utilities.load import _NotYetLoadedTensor as NotYetLoadedTensor
+from lightning_utilities.core.imports import RequirementCache
 from safetensors.torch import load_file as load_safetensors
 from tqdm import tqdm
 
@@ -286,6 +287,17 @@ def copy_weights_gemma_2(
             pbar.update(progress_per_file)
 
 
+_TRANSFORMERS_GREATER_EQUAL_4_52 = RequirementCache("transformers>=4.52.0")
+
+GEMMA3_LANGUAGE_MODEL_PREFIX = "model.language_model" if _TRANSFORMERS_GREATER_EQUAL_4_52 else "language_model.model"
+
+GEMMA3_VISION_MODEL_PREFIX = "model.vision_tower" if _TRANSFORMERS_GREATER_EQUAL_4_52 else "vision_tower"
+
+GEMMA3_MM_PROJECTOR_PREFIX = (
+    "model.multi_modal_projector" if _TRANSFORMERS_GREATER_EQUAL_4_52 else "multi_modal_projector"
+)
+
+
 def copy_weights_gemma_3(
     qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]],
     state_dict: Dict[str, torch.Tensor],
@@ -319,15 +331,21 @@ def copy_weights_gemma_3(
     if progress_per_file is not None:
         progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights))
     # gemma3 4b+ are multimodal models, but we are only loading the text weights
-    is_multimodal = any(k.startswith("language_model") for k in hf_weights)
+    is_multimodal = any(k.startswith(GEMMA3_LANGUAGE_MODEL_PREFIX) for k in hf_weights)
     if is_multimodal:
         warnings.warn("For Gemma3 models only the text component is supported.")
-        weight_map = {f"language_model.{k}": v for k, v in weight_map.items()}
+        new_weight_map = dict()
+        prefix = "model"
+        for k, v in weight_map.items():
+            if k.startswith(prefix):
+                k = GEMMA3_LANGUAGE_MODEL_PREFIX + k[len(prefix) :]
+            new_weight_map[k] = v
+        weight_map = new_weight_map
     for from_name, param in hf_weights.items():
-        if from_name.startswith("vision_tower") or from_name.startswith("multi_modal_projector"):
+        if from_name.startswith(GEMMA3_VISION_MODEL_PREFIX) or from_name.startswith(GEMMA3_MM_PROJECTOR_PREFIX):
             continue
         name_template, *ids = layer_template(from_name, num_matches=2)
-        to_name = weight_map[name_template]
+        to_name = weight_map.get(name_template)
         param = load_param(param, from_name, dtype, verbose=debug_mode)
         # in multimodal models, the text weights are the first part of the weights
         if is_multimodal and to_name == "transformer.wte.weight" and config is not None:
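
The remapping loop replaces the old blanket `f"language_model.{k}"` prefixing: only `weight_map` keys rooted at `model` are re-rooted under the version-dependent language-model prefix, and the switch from `weight_map[name_template]` to `weight_map.get(name_template)` lets any remaining unmapped key resolve to `None` instead of raising `KeyError`. A self-contained sketch of the rewrite; the sample entries are hypothetical, chosen only to show which keys move:

```python
GEMMA3_LANGUAGE_MODEL_PREFIX = "model.language_model"  # transformers>=4.52 layout

# Hypothetical weight_map excerpt (HF key template -> litgpt key template).
weight_map = {
    "model.embed_tokens.weight": "transformer.wte.weight",  # rooted at "model": re-rooted
    "lm_head.weight": "lm_head.weight",  # not rooted at "model": kept as-is
}

prefix = "model"
new_weight_map = {}
for k, v in weight_map.items():
    if k.startswith(prefix):
        # "model.embed_tokens.weight" -> "model.language_model.embed_tokens.weight"
        k = GEMMA3_LANGUAGE_MODEL_PREFIX + k[len(prefix) :]
    new_weight_map[k] = v

assert new_weight_map == {
    "model.language_model.embed_tokens.weight": "transformer.wte.weight",
    "lm_head.weight": "lm_head.weight",
}

# .get() tolerates keys that exist only in the vision tower / projector,
# which copy_weights_gemma_3 skips rather than converts.
assert new_weight_map.get("model.vision_tower.embeddings.patch_embedding.weight") is None
```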