@@ -310,6 +310,8 @@ def prepare_tensors(self):
                             gguf.MODEL_TENSOR.POSNET_NORM2,
                             gguf.MODEL_TENSOR.V_ENC_EMBD_POS,
                             gguf.MODEL_TENSOR.A_ENC_EMBD_POS,
+                            gguf.MODEL_TENSOR.ALTUP_CORRECT_COEF,
+                            gguf.MODEL_TENSOR.ALTUP_PREDICT_COEF,
                         )
                     )
                     or not new_name.endswith(".weight")
@@ -320,7 +322,11 @@ def prepare_tensors(self):
                     self.match_model_tensor_name(new_name, key, bid)
                     for key in (
                         gguf.MODEL_TENSOR.TOKEN_EMBD,
+                        gguf.MODEL_TENSOR.PER_LAYER_TOKEN_EMBD,
                         gguf.MODEL_TENSOR.OUTPUT,
+                        gguf.MODEL_TENSOR.ALTUP_ROUTER,
+                        gguf.MODEL_TENSOR.LAUREL_L,
+                        gguf.MODEL_TENSOR.LAUREL_R,
                     )
                 ):
                     if self.ftype in (
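For context: tensors matching the keys in this list are exempted from the most aggressive quantization types, so this hunk extends that allow-list with Gemma3n's per-layer token embeddings, the AltUp router, and the LAUREL projections. A minimal standalone sketch of the pattern, with made-up names (`PROTECTED_KEYS` and `choose_qtype` are illustrative, not the converter's API):

```python
# Hypothetical sketch of the "keep these tensors in high precision" pattern.
PROTECTED_KEYS = {
    "token_embd",            # gguf.MODEL_TENSOR.TOKEN_EMBD
    "per_layer_token_embd",  # new in this PR: gemma3n per-layer embeddings
    "output",
    "altup_router",
    "laurel_l",
    "laurel_r",
}

def choose_qtype(tensor_name: str, requested_qtype: str) -> str:
    # Protected tensors fall back to a wider type instead of the
    # aggressively quantized one requested on the command line.
    base = tensor_name.removesuffix(".weight")
    if base in PROTECTED_KEYS and requested_qtype in ("TQ1_0", "TQ2_0"):
        return "F16"
    return requested_qtype
```

These matrices are small relative to the model, so keeping them wide costs little file size while avoiding tensors that ternary quantization handles poorly.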
@@ -921,13 +927,20 @@ def _create_vocab_sentencepiece(self):
         tokenizer = SentencePieceProcessor()
         tokenizer.LoadFromFile(str(tokenizer_path))
 
-        vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())
+        vocab_size = self.find_hparam([
+            "vocab_size_per_layer_input",  # gemma3n
+            "vocab_size",
+        ], optional=True) or tokenizer.vocab_size()
 
         tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
         scores: list[float] = [-10000.0] * vocab_size
         toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size
 
         for token_id in range(tokenizer.vocab_size()):
+            if token_id >= vocab_size:
+                logger.warning(f'ignore tokens from {token_id}: id is out of range, max={vocab_size - 1}')
+                break
+
             piece = tokenizer.IdToPiece(token_id)
             text = piece.encode("utf-8")
             score = tokenizer.GetScore(token_id)
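The replacement prefers `vocab_size_per_layer_input` when present (Gemma3n declares a per-layer input vocabulary that can differ from the SentencePiece model's size), then `vocab_size`, then the tokenizer itself; SentencePiece ids beyond the chosen size are now dropped with a warning instead of overflowing the token lists. A minimal sketch of the lookup order, assuming a plain dict in place of the converter's `find_hparam` helper:

```python
# Simplified stand-in for find_hparam([...], optional=True) or tokenizer.vocab_size()
def pick_vocab_size(hparams: dict, sp_vocab_size: int) -> int:
    for key in ("vocab_size_per_layer_input", "vocab_size"):  # first match wins
        if key in hparams:
            return hparams[key]
    return sp_vocab_size  # fall back to the SentencePiece model itself
```

One subtlety: the real code chains with `or`, so a stored value of `0` would also fall through to `tokenizer.vocab_size()`, whereas the loop above only falls through on a missing key.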
@@ -2730,6 +2743,52 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         yield from super().modify_tensors(data_torch, name, bid)
 
 
+@ModelBase.register("Ernie4_5_ForCausalLM")
+class Ernie4_5Model(TextModel):
+    model_arch = gguf.MODEL_ARCH.ERNIE4_5
+
+    def set_vocab(self):
+        self._set_vocab_sentencepiece()
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        num_heads = self.hparams["num_attention_heads"]
+        num_kv_heads = self.hparams["num_key_value_heads"]
+        head_dim = self.hparams["head_dim"]
+
+        if "ernie." in name:
+            name = name.replace("ernie.", "model.")
+        # split the qkv weights
+        # qkv_proj shape: [(num_heads + 2 * num_kv_heads) * head_dim, hidden_size]
+        if "qkv_proj" in name:
+            name_q = name.replace("qkv_proj.weight", "q_proj.weight")
+            name_k = name.replace("qkv_proj.weight", "k_proj.weight")
+            name_v = name.replace("qkv_proj.weight", "v_proj.weight")
+            total_q_dim = num_heads * head_dim
+            total_k_dim = num_kv_heads * head_dim
+            total_v_dim = num_kv_heads * head_dim
+            q_proj_weight, k_proj_weight, v_proj_weight = data_torch.split([total_q_dim, total_k_dim, total_v_dim], dim=0)
+            return [
+                (self.map_tensor_name(name_q), q_proj_weight),
+                (self.map_tensor_name(name_k), k_proj_weight),
+                (self.map_tensor_name(name_v), v_proj_weight)
+            ]
+        # split the up_gate_proj into gate and up
+        # up_gate_proj shape: [2 * intermediate_size, hidden_size]
+        if "up_gate_proj" in name:
+            name_up = name.replace("up_gate_proj.weight", "up_proj.weight")
+            name_gate = name.replace("up_gate_proj.weight", "gate_proj.weight")
+            dim_half = data_torch.shape[0] // 2
+            gate_proj_weight, up_proj_weight = data_torch.split(dim_half, dim=0)
+            return [
+                (self.map_tensor_name(name_gate), gate_proj_weight),
+                (self.map_tensor_name(name_up), up_proj_weight)
+            ]
+        return [(self.map_tensor_name(name), data_torch)]
+
+
 @ModelBase.register(
     "Qwen2VLModel",
     "Qwen2VLForConditionalGeneration",
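The fused-weight handling is the interesting part of this new class: Ernie 4.5 checkpoints store attention as a single `qkv_proj` and the MLP as a single `up_gate_proj`, while the GGUF tensor map expects separate projections. A self-contained illustration of both splits with toy dimensions (all sizes here are made up for the example):

```python
import torch

num_heads, num_kv_heads, head_dim, hidden = 4, 2, 8, 32

# qkv_proj rows are laid out as [q | k | v] along dim 0
qkv = torch.randn((num_heads + 2 * num_kv_heads) * head_dim, hidden)
q, k, v = qkv.split(
    [num_heads * head_dim, num_kv_heads * head_dim, num_kv_heads * head_dim],
    dim=0,
)
assert q.shape == (32, 32) and k.shape == (16, 32) and v.shape == (16, 32)

# up_gate_proj stacks [gate | up] along dim 0, so an even split recovers both;
# mapping the first half to gate_proj assumes gate comes first, as the PR code does
intermediate = 24
gate_up = torch.randn(2 * intermediate, hidden)
gate, up = gate_up.split(intermediate, dim=0)
assert gate.shape == up.shape == (intermediate, hidden)
```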
@@ -4217,6 +4276,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
 @ModelBase.register("Gemma3ForCausalLM", "Gemma3ForConditionalGeneration")
 class Gemma3Model(TextModel):
     model_arch = gguf.MODEL_ARCH.GEMMA3
+    norm_shift = 1.0  # Gemma3RMSNorm adds 1.0 to the norm value
 
     def set_vocab(self):
         self._set_vocab_sentencepiece()
@@ -4238,9 +4298,8 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_value_length(hparams.get("head_dim", 256))
         self.gguf_writer.add_file_type(self.ftype)
         self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 1_000_000.0))  # for global layers
-        # both attn_logit_softcapping and final_logit_softcapping are removed in Gemma3
+        # attn_logit_softcapping is removed in Gemma3
         assert hparams.get("attn_logit_softcapping") is None
-        assert hparams.get("final_logit_softcapping") is None
         self.gguf_writer.add_sliding_window(hparams["sliding_window"])
         self.gguf_writer.add_head_count_kv(hparams.get("num_key_value_heads", 4))
         if hparams.get("rope_scaling") is not None:
@@ -4252,7 +4311,7 @@ def set_gguf_parameters(self):
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid  # unused
 
-        if name.startswith("language_model."):
+        if "language_model." in name:
             name = name.replace("language_model.", "")
 
         elif name.startswith("multi_modal_projector.") or name.startswith("vision_tower.") \
@@ -4267,8 +4326,9 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
 
         # ref code in Gemma3RMSNorm
         # output = output * (1.0 + self.weight.float())
+        # note: this is not the case for gemma3n
         if name.endswith("norm.weight"):
-            data_torch = data_torch + 1
+            data_torch = data_torch + self.norm_shift
 
         return [(self.map_tensor_name(name), data_torch)]
 
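Moving the hard-coded `+ 1` behind the `norm_shift` class attribute lets Gemma3n reuse this method with a shift of 0.0, since its norm applies the weight without the `1.0 +` term. The folding itself is a small identity; a sketch of why it is safe, assuming the Gemma3-style RMSNorm quoted in the comment above:

```python
import torch

def rmsnorm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # plain RMSNorm without any learned scale
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

x = torch.randn(4, 8)
w = torch.randn(8)

# Gemma3 reference: output = rmsnorm(x) * (1.0 + weight)
ref = rmsnorm(x) * (1.0 + w)

# what the converter stores: w_stored = weight + norm_shift (1.0 for Gemma3),
# so inference can use the plain y = rmsnorm(x) * w_stored
w_stored = w + 1.0
assert torch.allclose(rmsnorm(x) * w_stored, ref)
```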
@@ -4325,6 +4385,104 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         return []  # skip other tensors
 
 
+@ModelBase.register("Gemma3nForConditionalGeneration")
+class Gemma3NModel(Gemma3Model):
+    model_arch = gguf.MODEL_ARCH.GEMMA3N
+    norm_shift = 0.0  # same value as Gemma3p5RMSNorm's scale_shift in the reference Python code
+
+    _altup_proj: list[Tensor] = []
+    _altup_unembd: list[Tensor] = []
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert self.hparams["altup_num_inputs"] == 4, "Current conversion only supports 4 altup inputs"
+        self._altup_proj = [
+            torch.Tensor(),  # to be replaced
+            torch.Tensor(),  # to be replaced
+            torch.Tensor(),  # to be replaced
+        ]
+        self._altup_unembd = [
+            torch.Tensor(),  # to be replaced
+            torch.Tensor(),  # to be replaced
+            torch.Tensor(),  # to be replaced
+        ]
+
+    def set_vocab(self):
+        with open(self.dir_model / "chat_template.jinja") as f:
+            # quick hack to make sure chat template is added
+            self.gguf_writer.add_chat_template(f.read())
+        super().set_vocab()
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_altup_active_idx(self.hparams["altup_active_idx"])
+        self.gguf_writer.add_altup_num_inputs(self.hparams["altup_num_inputs"])
+        self.gguf_writer.add_embedding_length_per_layer_input(self.hparams["hidden_size_per_layer_input"])
+        self.gguf_writer.add_shared_kv_layers(self.hparams["num_kv_shared_layers"])
+
+        activation_sparsity_scale = []
+        for s in self.hparams["activation_sparsity_pattern"]:
+            normal_dist = torch.distributions.normal.Normal(0, 1)
+            std_multiplier = normal_dist.icdf(torch.tensor(s, dtype=torch.float32))
+            activation_sparsity_scale.append(std_multiplier.item())
+        self.gguf_writer.add_activation_sparsity_scale(activation_sparsity_scale)
+
+        sliding_window_pattern = []
+        for t in self.hparams["layer_types"]:
+            sliding_window_pattern.append(t == "sliding_attention")
+        self.gguf_writer.add_sliding_window_pattern(sliding_window_pattern)
+
+    def _stack_matrices(self, matrices: list[Tensor]) -> Tensor | None:
+        has_all = all(m.numel() > 0 for m in matrices)
+        if not has_all:
+            return None
+        else:
+            return torch.stack(matrices, dim=0)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if name.endswith("_scale"):
+            name = name + ".weight"
+
+        # TODO: implement self.prediction_coefs.weight.clamp_(...)
+
+        if "language_model." not in name:
+            return []  # skip non-language model tensors
+
+        if "altup_unembed_projections" in name:
+            data_torch = data_torch.to(device="cpu")
+            if ".0." in name:
+                self._altup_unembd[0] = data_torch
+            elif ".1." in name:
+                self._altup_unembd[1] = data_torch
+            elif ".2." in name:
+                self._altup_unembd[2] = data_torch
+            else:
+                raise ValueError(f"Unknown name: {name}")
+            out = self._stack_matrices(self._altup_unembd)
+            if out is not None:
+                return [(self.map_tensor_name("model.altup_unembed_projections.weight"), out)]
+            else:
+                return []
+
+        if "altup_projections" in name:
+            data_torch = data_torch.to(device="cpu")
+            if ".0." in name:
+                self._altup_proj[0] = data_torch
+            elif ".1." in name:
+                self._altup_proj[1] = data_torch
+            elif ".2." in name:
+                self._altup_proj[2] = data_torch
+            else:
+                raise ValueError(f"Unknown name: {name}")
+            out = self._stack_matrices(self._altup_proj)
+            if out is not None:
+                return [(self.map_tensor_name("model.altup_projections.weight"), out)]
+            else:
+                return []
+
+        return super().modify_tensors(data_torch, name, bid)
+
+
 @ModelBase.register("Starcoder2ForCausalLM")
 class StarCoder2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.STARCODER2
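One non-obvious bit in `set_gguf_parameters` above is the activation-sparsity conversion: the config stores a target sparsity fraction per layer, but inference wants a standard-deviation multiplier, so the converter bakes in the inverse normal CDF. A sketch of how such a scale is applied, loosely following the gaussian-top-k gating in the Gemma3n reference implementation (values here are illustrative):

```python
import torch

sparsity = 0.95  # target: zero out ~95% of activations on this layer
std_mult = torch.distributions.normal.Normal(0, 1).icdf(torch.tensor(sparsity)).item()
# icdf(0.95) ~= 1.6449: the cutoff sits 1.6449 standard deviations above the mean

x = torch.randn(2, 4096)
cutoff = x.mean(-1, keepdim=True) + std_mult * x.std(-1, keepdim=True)
x_sparse = torch.relu(x - cutoff)  # for ~Gaussian activations, ~95% become exactly 0
```

The other stateful piece, `_stack_matrices`, simply buffers the three `altup_projections.{0,1,2}` matrices (and likewise the unembed projections) and emits nothing until all three have arrived, at which point they are stacked into the single 3-D tensor that GGUF stores.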