Skip to content

Commit aa5609d

Browse files
mseeger, pre-commit-ci[bot], and Borda
committed
Fix in convert_hf_checkpoint related to Gemma 3 (Lightning-AI#2062)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jirka B <[email protected]> Co-authored-by: Jirka Borovec <[email protected]>
1 parent d710988 commit aa5609d

File tree

2 files changed

+14
-7
lines changed

2 files changed

+14
-7
lines changed

.azure/gpu-test.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,8 @@ jobs:
9393
env:
9494
PL_RUN_STANDALONE_TESTS: "1"
9595
# NUM_PARALLEL_TESTS: "10"
96+
NCCL_IGNORE_DISABLED_P2P: "1"
97+
NCCL_DEBUG: "INFO"
9698
timeoutInMinutes: "10"
9799
98100
- bash: |

litgpt/scripts/convert_hf_checkpoint.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
import torch
1515
from lightning.fabric.utilities.load import _NotYetLoadedTensor as NotYetLoadedTensor
16+
from lightning_utilities.core.imports import RequirementCache
1617
from safetensors.torch import load_file as load_safetensors
1718
from tqdm import tqdm
1819

@@ -286,11 +287,16 @@ def copy_weights_gemma_2(
286287
pbar.update(progress_per_file)
287288

288289

289-
GEMMA3_LANGUAGE_MODEL_PREFIX = "model.language_model"
290+
_TRANSFORMERS_GREATER_EQUAL_4_52 = RequirementCache("transformers>=4.52.0")
290291

291-
GEMMA3_VISION_MODEL_PREFIX = "model.vision_tower"
292+
GEMMA3_LANGUAGE_MODEL_PREFIX = "model.language_model" if _TRANSFORMERS_GREATER_EQUAL_4_52 else "language_model.model"
293+
294+
GEMMA3_VISION_MODEL_PREFIX = "model.vision_tower" if _TRANSFORMERS_GREATER_EQUAL_4_52 else "vision_tower"
295+
296+
GEMMA3_MM_PROJECTOR_PREFIX = (
297+
"model.multi_modal_projector" if _TRANSFORMERS_GREATER_EQUAL_4_52 else "multi_modal_projector"
298+
)
292299

293-
GEMMA3_MM_PROJECTOR_PREFIX = "model.multi_modal_projector"
294300

295301
def copy_weights_gemma_3(
296302
qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]],
@@ -325,15 +331,14 @@ def copy_weights_gemma_3(
325331
if progress_per_file is not None:
326332
progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights))
327333
# gemma3 4b+ are multimodel models, but we are only loading the text weights
328-
is_multimodal = any(k.startswith(GEMMA3_VISION_MODEL_PREFIX) for k in hf_weights)
334+
is_multimodal = any(k.startswith(GEMMA3_LANGUAGE_MODEL_PREFIX) for k in hf_weights)
329335
if is_multimodal:
330336
warnings.warn("For Gemma3 models only the text component is supported.")
331337
new_weight_map = dict()
332-
prefix = "model."
333-
len_prefix = len(prefix)
338+
prefix = "model"
334339
for k, v in weight_map.items():
335340
if k.startswith(prefix):
336-
k = "model.language_model." + k[len_prefix:]
341+
k = GEMMA3_LANGUAGE_MODEL_PREFIX + k[len(prefix) :]
337342
new_weight_map[k] = v
338343
weight_map = new_weight_map
339344
for from_name, param in hf_weights.items():

0 commit comments

Comments (0)