
Commit 0fb371d

scascar, pre-commit-ci[bot], and bhimrazy authored
fix: handle gemma 3 weights prefix during hf conversion (#2156)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Bhimraj Yadav <[email protected]>
1 parent 39f9b3e commit 0fb371d

File tree

2 files changed: +16 -11 lines changed


litgpt/scripts/convert_hf_checkpoint.py

Lines changed: 16 additions & 10 deletions

@@ -18,7 +18,6 @@
 
 from litgpt.config import Config
 from litgpt.utils import (
-    _TRANSFORMERS_GREATER_EQUAL_4_52,
     extend_checkpoint_dir,
     incremental_save,
     lazy_load,
@@ -292,15 +291,6 @@ def copy_weights_gemma_2(
     pbar.update(progress_per_file)
 
 
-GEMMA3_LANGUAGE_MODEL_PREFIX = "model.language_model" if _TRANSFORMERS_GREATER_EQUAL_4_52 else "language_model.model"
-
-GEMMA3_VISION_MODEL_PREFIX = "model.vision_tower" if _TRANSFORMERS_GREATER_EQUAL_4_52 else "vision_tower"
-
-GEMMA3_MM_PROJECTOR_PREFIX = (
-    "model.multi_modal_projector" if _TRANSFORMERS_GREATER_EQUAL_4_52 else "multi_modal_projector"
-)
-
-
 def copy_weights_gemma_3(
     qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]],
     state_dict: Dict[str, torch.Tensor],
@@ -312,6 +302,22 @@ def copy_weights_gemma_3(
     debug_mode: Optional[bool] = False,
     config: Optional[Config] = None,
 ) -> None:
+    GEMMA3_LANGUAGE_MODEL_PREFIX = (
+        "model.language_model"
+        if any(k.startswith("model.language_model") for k in hf_weights)
+        else "language_model.model"
+    )
+
+    GEMMA3_VISION_MODEL_PREFIX = (
+        "model.vision_tower" if any(k.startswith("model.vision_tower") for k in hf_weights) else "vision_tower"
+    )
+
+    GEMMA3_MM_PROJECTOR_PREFIX = (
+        "model.multi_modal_projector"
+        if any(k.startswith("model.multi_modal_projector") for k in hf_weights)
+        else "multi_modal_projector"
+    )
+
     weight_map = {
         "model.embed_tokens.weight": "transformer.wte.weight",
         "model.layers.{}.self_attn.q_proj.weight": None,

litgpt/utils.py

Lines changed: 0 additions & 1 deletion

@@ -47,7 +47,6 @@
 _LITDATA_AVAILABLE = RequirementCache("litdata")
 _LITSERVE_AVAILABLE = RequirementCache("litserve")
 _JINJA2_AVAILABLE = RequirementCache("jinja2")
-_TRANSFORMERS_GREATER_EQUAL_4_52 = RequirementCache("transformers>=4.52.0")
 _SAFETENSORS_AVAILABLE = RequirementCache("safetensors")
 _HF_TRANSFER_AVAILABLE = RequirementCache("hf_transfer")
 
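For context, the deleted _TRANSFORMERS_GREATER_EQUAL_4_52 flag was a RequirementCache: it is truthy only when the locally installed transformers satisfies the version specifier, so it says nothing about which transformers version actually serialized a given checkpoint. A rough sketch of that pattern, assuming lightning_utilities is available:

# Rough sketch of the removed version gate (assumes lightning_utilities is installed).
# The flag reflects the *installed* transformers version, not the version that
# produced a particular checkpoint, which is why the converter now inspects keys.
from lightning_utilities.core.imports import RequirementCache

_TRANSFORMERS_GREATER_EQUAL_4_52 = RequirementCache("transformers>=4.52.0")

if _TRANSFORMERS_GREATER_EQUAL_4_52:
    print("installed transformers satisfies >=4.52.0")
else:
    print("installed transformers is older (or missing)")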
