Skip to content

Commit 3d33a05

Browse files
mseeger, pre-commit-ci[bot], and Borda
authored
Fix in convert_hf_checkpoint related to Gemma 3 (#2062)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jirka B <[email protected]> Co-authored-by: Jirka Borovec <[email protected]>
1 parent cd6499e commit 3d33a05

File tree

2 files changed

+24
-4
lines changed

2 files changed

+24
-4
lines changed

.azure/gpu-test.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ jobs:
9393
env:
9494
PL_RUN_STANDALONE_TESTS: "1"
9595
# NUM_PARALLEL_TESTS: "10"
96+
NCCL_IGNORE_DISABLED_P2P: "1"
97+
NCCL_DEBUG: "INFO"
9698
timeoutInMinutes: "10"
9799
98100
- bash: |

litgpt/scripts/convert_hf_checkpoint.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
import torch
1515
from lightning.fabric.utilities.load import _NotYetLoadedTensor as NotYetLoadedTensor
16+
from lightning_utilities.core.imports import RequirementCache
1617
from safetensors.torch import load_file as load_safetensors
1718
from tqdm import tqdm
1819

@@ -286,6 +287,17 @@ def copy_weights_gemma_2(
286287
pbar.update(progress_per_file)
287288

288289

290+
_TRANSFORMERS_GREATER_EQUAL_4_52 = RequirementCache("transformers>=4.52.0")
291+
292+
GEMMA3_LANGUAGE_MODEL_PREFIX = "model.language_model" if _TRANSFORMERS_GREATER_EQUAL_4_52 else "language_model.model"
293+
294+
GEMMA3_VISION_MODEL_PREFIX = "model.vision_tower" if _TRANSFORMERS_GREATER_EQUAL_4_52 else "vision_tower"
295+
296+
GEMMA3_MM_PROJECTOR_PREFIX = (
297+
"model.multi_modal_projector" if _TRANSFORMERS_GREATER_EQUAL_4_52 else "multi_modal_projector"
298+
)
299+
300+
289301
def copy_weights_gemma_3(
290302
qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]],
291303
state_dict: Dict[str, torch.Tensor],
@@ -319,15 +331,21 @@ def copy_weights_gemma_3(
319331
if progress_per_file is not None:
320332
progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights))
321333
# gemma3 4b+ are multimodal models, but we are only loading the text weights
322-
is_multimodal = any(k.startswith("language_model") for k in hf_weights)
334+
is_multimodal = any(k.startswith(GEMMA3_LANGUAGE_MODEL_PREFIX) for k in hf_weights)
323335
if is_multimodal:
324336
warnings.warn("For Gemma3 models only the text component is supported.")
325-
weight_map = {f"language_model.{k}": v for k, v in weight_map.items()}
337+
new_weight_map = dict()
338+
prefix = "model"
339+
for k, v in weight_map.items():
340+
if k.startswith(prefix):
341+
k = GEMMA3_LANGUAGE_MODEL_PREFIX + k[len(prefix) :]
342+
new_weight_map[k] = v
343+
weight_map = new_weight_map
326344
for from_name, param in hf_weights.items():
327-
if from_name.startswith("vision_tower") or from_name.startswith("multi_modal_projector"):
345+
if from_name.startswith(GEMMA3_VISION_MODEL_PREFIX) or from_name.startswith(GEMMA3_MM_PROJECTOR_PREFIX):
328346
continue
329347
name_template, *ids = layer_template(from_name, num_matches=2)
330-
to_name = weight_map[name_template]
348+
to_name = weight_map.get(name_template)
331349
param = load_param(param, from_name, dtype, verbose=debug_mode)
332350
# in multimodal models, the text weights are the first part of the weights
333351
if is_multimodal and to_name == "transformer.wte.weight" and config is not None:

0 commit comments

Comments (0)