
 from litgpt.config import Config
 from litgpt.utils import (
-    _TRANSFORMERS_GREATER_EQUAL_4_52,
     extend_checkpoint_dir,
     incremental_save,
     lazy_load,
@@ -292,15 +291,6 @@ def copy_weights_gemma_2(
         pbar.update(progress_per_file)


-GEMMA3_LANGUAGE_MODEL_PREFIX = "model.language_model" if _TRANSFORMERS_GREATER_EQUAL_4_52 else "language_model.model"
-
-GEMMA3_VISION_MODEL_PREFIX = "model.vision_tower" if _TRANSFORMERS_GREATER_EQUAL_4_52 else "vision_tower"
-
-GEMMA3_MM_PROJECTOR_PREFIX = (
-    "model.multi_modal_projector" if _TRANSFORMERS_GREATER_EQUAL_4_52 else "multi_modal_projector"
-)
-
-
 def copy_weights_gemma_3(
     qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]],
     state_dict: Dict[str, torch.Tensor],
@@ -312,6 +302,22 @@ def copy_weights_gemma_3(
     debug_mode: Optional[bool] = False,
     config: Optional[Config] = None,
 ) -> None:
+    GEMMA3_LANGUAGE_MODEL_PREFIX = (
+        "model.language_model"
+        if any(k.startswith("model.language_model") for k in hf_weights)
+        else "language_model.model"
+    )
+
+    GEMMA3_VISION_MODEL_PREFIX = (
+        "model.vision_tower" if any(k.startswith("model.vision_tower") for k in hf_weights) else "vision_tower"
+    )
+
+    GEMMA3_MM_PROJECTOR_PREFIX = (
+        "model.multi_modal_projector"
+        if any(k.startswith("model.multi_modal_projector") for k in hf_weights)
+        else "multi_modal_projector"
+    )
+
     weight_map = {
         "model.embed_tokens.weight": "transformer.wte.weight",
         "model.layers.{}.self_attn.q_proj.weight": None,