From f4826ef86907e905bf8e00be1eaab58d4ae47270 Mon Sep 17 00:00:00 2001 From: Oscar Blazejewski Date: Tue, 11 Nov 2025 16:44:32 +0100 Subject: [PATCH 1/2] fix: handle gemma 3 weights prefix during hf conversion --- litgpt/scripts/convert_hf_checkpoint.py | 18 +++++++++--------- litgpt/utils.py | 1 - 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/litgpt/scripts/convert_hf_checkpoint.py b/litgpt/scripts/convert_hf_checkpoint.py index 7a39c14a58..27eb7686c7 100644 --- a/litgpt/scripts/convert_hf_checkpoint.py +++ b/litgpt/scripts/convert_hf_checkpoint.py @@ -18,7 +18,6 @@ from litgpt.config import Config from litgpt.utils import ( - _TRANSFORMERS_GREATER_EQUAL_4_52, extend_checkpoint_dir, incremental_save, lazy_load, @@ -292,14 +291,6 @@ def copy_weights_gemma_2( pbar.update(progress_per_file) -GEMMA3_LANGUAGE_MODEL_PREFIX = "model.language_model" if _TRANSFORMERS_GREATER_EQUAL_4_52 else "language_model.model" - -GEMMA3_VISION_MODEL_PREFIX = "model.vision_tower" if _TRANSFORMERS_GREATER_EQUAL_4_52 else "vision_tower" - -GEMMA3_MM_PROJECTOR_PREFIX = ( - "model.multi_modal_projector" if _TRANSFORMERS_GREATER_EQUAL_4_52 else "multi_modal_projector" -) - def copy_weights_gemma_3( qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]], @@ -312,6 +303,15 @@ def copy_weights_gemma_3( debug_mode: Optional[bool] = False, config: Optional[Config] = None, ) -> None: + + GEMMA3_LANGUAGE_MODEL_PREFIX = "model.language_model" if any(k.startswith("model.language_model") for k in hf_weights) else "language_model.model" + + GEMMA3_VISION_MODEL_PREFIX = "model.vision_tower" if any(k.startswith("model.vision_tower") for k in hf_weights) else "vision_tower" + + GEMMA3_MM_PROJECTOR_PREFIX = ( + "model.multi_modal_projector" if any(k.startswith("model.multi_modal_projector") for k in hf_weights) else "multi_modal_projector" + ) + weight_map = { "model.embed_tokens.weight": "transformer.wte.weight", "model.layers.{}.self_attn.q_proj.weight": None, diff --git a/litgpt/utils.py b/litgpt/utils.py index 073076dd55..303c4d1bf1 100644 --- a/litgpt/utils.py +++ b/litgpt/utils.py @@ -47,7 +47,6 @@ _LITDATA_AVAILABLE = RequirementCache("litdata") _LITSERVE_AVAILABLE = RequirementCache("litserve") _JINJA2_AVAILABLE = RequirementCache("jinja2") -_TRANSFORMERS_GREATER_EQUAL_4_52 = RequirementCache("transformers>=4.52.0") _SAFETENSORS_AVAILABLE = RequirementCache("safetensors") _HF_TRANSFER_AVAILABLE = RequirementCache("hf_transfer") From b55b7ad43e3b627605155eaa85caae933451df2b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 11 Nov 2025 15:51:28 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- litgpt/scripts/convert_hf_checkpoint.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/litgpt/scripts/convert_hf_checkpoint.py b/litgpt/scripts/convert_hf_checkpoint.py index 27eb7686c7..7b266f6587 100644 --- a/litgpt/scripts/convert_hf_checkpoint.py +++ b/litgpt/scripts/convert_hf_checkpoint.py @@ -291,7 +291,6 @@ def copy_weights_gemma_2( pbar.update(progress_per_file) - def copy_weights_gemma_3( qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]], state_dict: Dict[str, torch.Tensor], @@ -303,13 +302,20 @@ def copy_weights_gemma_3( debug_mode: Optional[bool] = False, config: Optional[Config] = None, ) -> None: + GEMMA3_LANGUAGE_MODEL_PREFIX = ( + "model.language_model" + if any(k.startswith("model.language_model") for k in hf_weights) + else "language_model.model" + ) - GEMMA3_LANGUAGE_MODEL_PREFIX = "model.language_model" if any(k.startswith("model.language_model") for k in hf_weights) else "language_model.model" - - GEMMA3_VISION_MODEL_PREFIX = "model.vision_tower" if any(k.startswith("model.vision_tower") for k in hf_weights) else "vision_tower" + GEMMA3_VISION_MODEL_PREFIX = ( + "model.vision_tower" if any(k.startswith("model.vision_tower") for k in hf_weights) else "vision_tower" + ) GEMMA3_MM_PROJECTOR_PREFIX = ( - "model.multi_modal_projector" if any(k.startswith("model.multi_modal_projector") for k in hf_weights) else "multi_modal_projector" + "model.multi_modal_projector" + if any(k.startswith("model.multi_modal_projector") for k in hf_weights) + else "multi_modal_projector" ) weight_map = {