6767 lm_head = None ,
6868)
6969
# Tensor-name mapping for checkpoints whose decoder weights are nested under a
# "language_model." prefix (HF-style multimodal 4B export). Each "{}" is filled
# with the layer index. lm_head is None — presumably the output projection is
# tied to the embedding table; TODO confirm against the loader.
TENSOR_NAMES_HF_4B = loading_utils.ModelLoader.TensorNames(
    ff_up_proj="language_model.model.layers.{}.mlp.up_proj",
    ff_down_proj="language_model.model.layers.{}.mlp.down_proj",
    ff_gate_proj="language_model.model.layers.{}.mlp.gate_proj",
    attn_query_proj="language_model.model.layers.{}.self_attn.q_proj",
    attn_key_proj="language_model.model.layers.{}.self_attn.k_proj",
    attn_value_proj="language_model.model.layers.{}.self_attn.v_proj",
    attn_output_proj="language_model.model.layers.{}.self_attn.o_proj",
    attn_query_norm="language_model.model.layers.{}.self_attn.q_norm",
    attn_key_norm="language_model.model.layers.{}.self_attn.k_norm",
    pre_attn_norm="language_model.model.layers.{}.input_layernorm",
    post_attn_norm="language_model.model.layers.{}.post_attention_layernorm",
    pre_ff_norm="language_model.model.layers.{}.pre_feedforward_layernorm",
    post_ff_norm="language_model.model.layers.{}.post_feedforward_layernorm",
    embedding="language_model.model.embed_tokens",
    final_norm="language_model.model.norm",
    lm_head=None,
)
88+
# Registry of known checkpoint tensor-name layouts. The build_model_* helpers
# iterate over the values and use the first mapping that matches the
# checkpoint (a non-matching mapping raises KeyError and the next is tried).
TENSOR_NAMES_DICT = {
    "safetensors": TENSOR_NAMES_SEP_QKV,
    "kaggle": TENSOR_NAMES_FUSED_QKV,
    "hf_4b": TENSOR_NAMES_HF_4B,
}
7494
7595
@@ -445,6 +465,60 @@ def get_block_config(idx: int) -> cfg.TransformerBlockConfig:
445465 return config
446466
447467
def get_decoder_config_4b() -> cfg.ModelConfig:
  """Returns the model config for a Gemma3 4B model."""
  # One shared RMSNorm config is reused for every norm site in the model.
  norm_config = cfg.NormalizationConfig(
      type=cfg.NormalizationType.RMS_NORM, epsilon=1e-6, zero_centered=True,
  )
  feedforward_config = cfg.FeedForwardConfig(
      type=cfg.FeedForwardType.GATED,
      activation=cfg.ActivationConfig(cfg.ActivationType.GELU_TANH),
      intermediate_size=10240,
      pre_ff_norm_config=norm_config,
      post_ff_norm_config=norm_config,
  )

  def make_block_config(layer_idx: int) -> cfg.TransformerBlockConfig:
    # Every sixth layer uses global attention with a larger rotary base;
    # all other layers use local sliding-window attention.
    uses_global_attn = (layer_idx + 1) % 6 == 0
    attention_config = cfg.AttentionConfig(
        num_heads=8,
        head_dim=256,
        num_query_groups=1,
        rotary_base=1_000_000 if uses_global_attn else 10_000,
        rotary_percentage=1.0,
        qkv_transpose_before_split=True,
        query_norm_config=norm_config,
        key_norm_config=norm_config,
        logit_softcap=None,
        sliding_window_size=1024,
        attn_type=(
            cfg.AttentionType.GLOBAL
            if uses_global_attn
            else cfg.AttentionType.LOCAL_SLIDING
        ),
    )
    return cfg.TransformerBlockConfig(
        attn_config=attention_config,
        ff_config=feedforward_config,
        pre_attention_norm_config=norm_config,
        post_attention_norm_config=norm_config,
    )

  num_layers = 34
  embedding_dim = 2560
  # Embeddings are scaled by sqrt(embedding_dim); no final logit softcap.
  return cfg.ModelConfig(
      vocab_size=262_208,
      num_layers=num_layers,
      max_seq_len=32_768,
      embedding_dim=embedding_dim,
      embedding_scale=embedding_dim**0.5,
      block_configs=[make_block_config(i) for i in range(num_layers)],
      final_norm_config=norm_config,
      lm_head_use_bias=False,
      final_logit_softcap=None,
  )
520+
521+
448522def get_fake_decoder_config_1b () -> cfg .ModelConfig :
449523 """Returns a fake model config for a Gemma3 1B model."""
450524 config = get_decoder_config_1b ()
@@ -481,6 +555,10 @@ def build_model_1b(
481555 )
482556 except KeyError as ke :
483557 continue
558+ raise RuntimeError (
559+ f"Failed to build model from checkpoint at { checkpoint_path } . "
560+ "None of the known tensor name mappings matched the checkpoint."
561+ )
484562
485563
486564def build_model_270m (
@@ -503,3 +581,33 @@ def build_model_270m(
503581 )
504582 except KeyError as _ :
505583 continue
584+ raise RuntimeError (
585+ f"Failed to build model from checkpoint at { checkpoint_path } . "
586+ "None of the known tensor name mappings matched the checkpoint."
587+ )
588+
589+
def build_model_4b(
    checkpoint_path: str,
    custom_loader: Callable[[str], Dict[str, torch.Tensor]] = None,
    mask_cache_size: int = 0,
) -> nn.Module:
  """Builds a Gemma3 4B model from a checkpoint.

  Each known tensor-name mapping in TENSOR_NAMES_DICT is tried in turn; the
  first one that matches the checkpoint's tensor names is used.

  Args:
    checkpoint_path: Path to the checkpoint to load.
    custom_loader: Optional callable mapping a checkpoint path to a tensor
      state dict. When None, the default loader is used.
    mask_cache_size: Size of the attention mask cache.

  Returns:
    The constructed Gemma3 4B decoder module.

  Raises:
    RuntimeError: If none of the known tensor-name mappings match the
      checkpoint.
  """
  # TODO(b/403644647): Better error handling for loading checkpoints with
  # different tensor names.
  for tensor_names in TENSOR_NAMES_DICT.values():
    try:
      return model_builder.build_decoder_only_model(
          checkpoint_path=checkpoint_path,
          config=get_decoder_config_4b(),
          tensor_names=tensor_names,
          model_class=Decoder,
          custom_loader=custom_loader,
          mask_cache_size=mask_cache_size,
      )
    except KeyError:
      # This mapping doesn't fit the checkpoint's tensor names; try the next.
      continue
  raise RuntimeError(
      f"Failed to build model from checkpoint at {checkpoint_path}. "
      "None of the known tensor name mappings matched the checkpoint."
  )
0 commit comments