@@ -310,6 +310,8 @@ def prepare_tensors(self):
                             gguf.MODEL_TENSOR.POSNET_NORM2,
                             gguf.MODEL_TENSOR.V_ENC_EMBD_POS,
                             gguf.MODEL_TENSOR.A_ENC_EMBD_POS,
+                            gguf.MODEL_TENSOR.ALTUP_CORRECT_COEF,
+                            gguf.MODEL_TENSOR.ALTUP_PREDICT_COEF,
                         )
                     )
                     or not new_name.endswith(".weight")
@@ -320,7 +322,11 @@ def prepare_tensors(self):
                     self.match_model_tensor_name(new_name, key, bid)
                     for key in (
                         gguf.MODEL_TENSOR.TOKEN_EMBD,
+                        gguf.MODEL_TENSOR.PER_LAYER_TOKEN_EMBD,
                         gguf.MODEL_TENSOR.OUTPUT,
+                        gguf.MODEL_TENSOR.ALTUP_ROUTER,
+                        gguf.MODEL_TENSOR.LAUREL_L,
+                        gguf.MODEL_TENSOR.LAUREL_R,
                     )
                 ):
                     if self.ftype in (
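
The two lists extended above feed the existing type-selection logic in `prepare_tensors`: tensors matching the first list are always stored in F32 (the AltUp correction/prediction coefficients are small, precision-sensitive per-layer matrices), while embedding-like tensors in the second list get a higher-precision fallback when the requested file type (e.g. the ternary TQ types) cannot represent them well. A minimal sketch of that decision flow, with simplified name sets standing in for the real `match_model_tensor_name` matching:

    # Sketch only: approximates the qtype decision the lists above plug into;
    # the real logic lives in ModelBase.prepare_tensors in convert_hf_to_gguf.py.
    ALWAYS_F32 = {"altup_correct_coef", "altup_predict_coef", "posnet_norm2"}
    EMBD_LIKE = {"token_embd", "per_layer_token_embd", "output",
                 "altup_router", "laurel_l", "laurel_r"}

    def choose_qtype(name: str, n_dims: int, target: str) -> str:
        if name in ALWAYS_F32 or n_dims <= 1:
            return "F32"    # precision-critical or 1d tensors stay in float32
        if name in EMBD_LIKE and target in ("TQ1_0", "TQ2_0"):
            return "F16"    # fall back for embedding-like tensors under ternary types
        return target       # everything else follows the requested file type

    print(choose_qtype("altup_router", 2, "TQ1_0"))  # -> F16
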
@@ -921,13 +927,16 @@ def _create_vocab_sentencepiece(self):
         tokenizer = SentencePieceProcessor()
         tokenizer.LoadFromFile(str(tokenizer_path))

-        vocab_size = self.hparams.get('vocab_size', tokenizer.vocab_size())
+        vocab_size = self.find_hparam([
+            "vocab_size_per_layer_input",  # gemma3n
+            "vocab_size",
+        ], optional=True) or tokenizer.vocab_size()

         tokens: list[bytes] = [f"[PAD{i}]".encode("utf-8") for i in range(vocab_size)]
         scores: list[float] = [-10000.0] * vocab_size
         toktypes: list[int] = [SentencePieceTokenTypes.UNUSED] * vocab_size

-        for token_id in range(tokenizer.vocab_size()):
+        for token_id in range(vocab_size):
             piece = tokenizer.IdToPiece(token_id)
             text = piece.encode("utf-8")
             score = tokenizer.GetScore(token_id)
@@ -942,6 +951,10 @@ def _create_vocab_sentencepiece(self):
             elif tokenizer.IsByte(token_id):
                 toktype = SentencePieceTokenTypes.BYTE

+            if token_id >= vocab_size:
+                logger.warning(f'ignore tokens from {token_id}: id is out of range, max={vocab_size - 1}')
+                break
+
             tokens[token_id] = text
             scores[token_id] = score
             toktypes[token_id] = toktype
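
For gemma3n the per-layer input vocabulary (`vocab_size_per_layer_input`) can be smaller than the full SentencePiece vocabulary, so the conversion now sizes its token arrays from the hparams and skips trailing tokenizer entries. A rough sketch of the `find_hparam` fallback semantics assumed above (the real method lives on ModelBase; the numbers are illustrative, not from an actual config):

    # First key present wins; with optional=True a miss returns None instead of raising.
    def find_hparam(hparams: dict, keys: list[str], optional: bool = False):
        for key in keys:
            if key in hparams:
                return hparams[key]
        if optional:
            return None
        raise KeyError(f"could not find any of {keys}")

    hparams = {"vocab_size_per_layer_input": 262144, "vocab_size": 262400}
    size = find_hparam(hparams, ["vocab_size_per_layer_input", "vocab_size"], optional=True)
    print(size)  # -> 262144; a model with neither key would fall back to tokenizer.vocab_size()
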
@@ -4217,6 +4230,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
 @ModelBase.register("Gemma3ForCausalLM", "Gemma3ForConditionalGeneration")
 class Gemma3Model(TextModel):
     model_arch = gguf.MODEL_ARCH.GEMMA3
+    norm_shift = 1.0  # Gemma3RMSNorm adds 1.0 to the norm value

     def set_vocab(self):
         self._set_vocab_sentencepiece()
@@ -4238,9 +4252,8 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_value_length(hparams.get("head_dim", 256))
         self.gguf_writer.add_file_type(self.ftype)
         self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 1_000_000.0))  # for global layers
-        # both attn_logit_softcapping and final_logit_softcapping are removed in Gemma3
+        # attn_logit_softcapping is removed in Gemma3
         assert hparams.get("attn_logit_softcapping") is None
-        assert hparams.get("final_logit_softcapping") is None
         self.gguf_writer.add_sliding_window(hparams["sliding_window"])
         self.gguf_writer.add_head_count_kv(hparams.get("num_key_value_heads", 4))
         if hparams.get("rope_scaling") is not None:
@@ -4252,7 +4265,7 @@ def set_gguf_parameters(self):
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         del bid  # unused

-        if name.startswith("language_model."):
+        if "language_model." in name:
             name = name.replace("language_model.", "")

         elif name.startswith("multi_modal_projector.") or name.startswith("vision_tower.") \
@@ -4267,8 +4280,9 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:

         # ref code in Gemma3RMSNorm
         # output = output * (1.0 + self.weight.float())
+        # note: this is not the case for gemma3n
         if name.endswith("norm.weight"):
-            data_torch = data_torch + 1
+            data_torch = data_torch + self.norm_shift

         return [(self.map_tensor_name(name), data_torch)]

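The `norm_shift` fold works because Gemma3RMSNorm multiplies the normalized activations by `(1.0 + weight)`; adding the shift to the stored tensor once at conversion time lets inference apply a plain `weight * x_norm`. A quick check of that identity (standalone sketch, not llama.cpp code):

    import torch

    def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        # plain RMS normalization without any learned scale
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

    x, w = torch.randn(4, 8), torch.randn(8)
    ref = rms_norm(x) * (1.0 + w)       # Gemma3RMSNorm reference formula
    folded = rms_norm(x) * (w + 1.0)    # weight as stored with norm_shift = 1.0
    assert torch.allclose(ref, folded)  # exact; gemma3n stores the weight as-is (shift 0.0)
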
@@ -4325,6 +4339,104 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         return []  # skip other tensors


+@ModelBase.register("Gemma3nForConditionalGeneration")
+class Gemma3NModel(Gemma3Model):
+    model_arch = gguf.MODEL_ARCH.GEMMA3N
+    norm_shift = 0.0  # same value as Gemma3p5RMSNorm scale_shift in the reference Python code
+
+    _altup_proj: list[Tensor] = []
+    _altup_unembd: list[Tensor] = []
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        assert self.hparams["altup_num_inputs"] == 4, "Current conversion only supports 4 altup inputs"
+        self._altup_proj = [
+            torch.Tensor(),  # to be replaced
+            torch.Tensor(),  # to be replaced
+            torch.Tensor(),  # to be replaced
+        ]
+        self._altup_unembd = [
+            torch.Tensor(),  # to be replaced
+            torch.Tensor(),  # to be replaced
+            torch.Tensor(),  # to be replaced
+        ]
+
+    def set_vocab(self):
+        with open(self.dir_model / "chat_template.jinja") as f:
+            # quick hack to make sure chat template is added
+            self.gguf_writer.add_chat_template(f.read())
+        super().set_vocab()
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_altup_active_idx(self.hparams["altup_active_idx"])
+        self.gguf_writer.add_altup_num_inputs(self.hparams["altup_num_inputs"])
+        self.gguf_writer.add_embedding_length_per_layer_input(self.hparams["hidden_size_per_layer_input"])
+        self.gguf_writer.add_shared_kv_layers(self.hparams["num_kv_shared_layers"])
+
+        activation_sparsity_scale = []
+        for s in self.hparams["activation_sparsity_pattern"]:
+            normal_dist = torch.distributions.normal.Normal(0, 1)
+            std_multiplier = normal_dist.icdf(torch.tensor(s, dtype=torch.float32))
+            activation_sparsity_scale.append(std_multiplier.item())
+        self.gguf_writer.add_activation_sparsity_scale(activation_sparsity_scale)
+
+        sliding_window_pattern = []
+        for t in self.hparams["layer_types"]:
+            sliding_window_pattern.append(t == "sliding_attention")
+        self.gguf_writer.add_sliding_window_pattern(sliding_window_pattern)
+
+    def _stack_matrices(self, matrices: list[Tensor]) -> Tensor | None:
+        has_all = all(m.numel() > 0 for m in matrices)
+        if not has_all:
+            return None
+        else:
+            return torch.stack(matrices, dim=0)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if name.endswith("_scale"):
+            name = name + ".weight"
+
+        # TODO: implement self.prediction_coefs.weight.clamp_(...)
+
+        if "language_model." not in name:
+            return []  # skip non-language model tensors
+
+        if "altup_unembed_projections" in name:
+            data_torch = data_torch.to(device="cpu")
+            if ".0." in name:
+                self._altup_unembd[0] = data_torch
+            elif ".1." in name:
+                self._altup_unembd[1] = data_torch
+            elif ".2." in name:
+                self._altup_unembd[2] = data_torch
+            else:
+                raise ValueError(f"Unknown name: {name}")
+            out = self._stack_matrices(self._altup_unembd)
+            if out is not None:
+                return [(self.map_tensor_name("model.altup_unembed_projections.weight"), out)]
+            else:
+                return []
+
+        if "altup_projections" in name:
+            data_torch = data_torch.to(device="cpu")
+            if ".0." in name:
+                self._altup_proj[0] = data_torch
+            elif ".1." in name:
+                self._altup_proj[1] = data_torch
+            elif ".2." in name:
+                self._altup_proj[2] = data_torch
+            else:
+                raise ValueError(f"Unknown name: {name}")
+            out = self._stack_matrices(self._altup_proj)
+            if out is not None:
+                return [(self.map_tensor_name("model.altup_projections.weight"), out)]
+            else:
+                return []
+
+        return super().modify_tensors(data_torch, name, bid)
+
+
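Two details of the class above are worth spelling out. The altup projection handling buffers the three per-layer matrices (`.0.`/`.1.`/`.2.`) and emits a single stacked tensor only once all three have been seen, which is why `_stack_matrices` returns None until then. And the activation-sparsity conversion turns each per-layer target sparsity s into the standard-normal z-score below which a fraction s of activations falls, i.e. Phi^-1(s). A worked check (values illustrative, not taken from a real gemma3n config):

    import torch

    normal_dist = torch.distributions.normal.Normal(0, 1)
    for s in (0.95, 0.0):
        z = normal_dist.icdf(torch.tensor(s, dtype=torch.float32)).item()
        print(f"target sparsity {s} -> std multiplier {z:.4f}")
    # target sparsity 0.95 -> std multiplier 1.6449 (keep activations above mean + 1.64*std)
    # target sparsity 0.0  -> std multiplier -inf   (no sparsification for that layer)
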
 @ModelBase.register("Starcoder2ForCausalLM")
 class StarCoder2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.STARCODER2