@@ -821,6 +821,18 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "b0a6b1c0bd5998ebd9df08611efde34a4ff03faed45ae09c43e6b31ebd4b94cf":
             # ref: https://huggingface.co/skt/A.X-4.0
             res = "a.x-4.0"
+        if chkhsh == "a6b57017d60e6edb4d88ecc2845188e0eb333a70357e45dcc9b53964a73bbae6":
+            # ref: https://huggingface.co/tiiuae/Falcon-H1-0.5B-Base
+            res = "falcon-h1"
+        if chkhsh == "60476e1243776c4fb1b993dbd7a5f15ac22f83c80afdf425fa5ae01c8d44ef86":
+            # ref: https://huggingface.co/tiiuae/Falcon-H1-1B-Base
+            res = "falcon-h1"
+        if chkhsh == "3eda48b4c4dc7de733d1a8b3e3b4a85243dbbf704da2ee9d42c6beced8897896":
+            # ref: https://huggingface.co/tiiuae/Falcon-H1-7B-Base
+            res = "falcon-h1"
+        if chkhsh == "48f8e02c0359c0bbdd82f26909171fac1c18a457bb47573ed1fe3bbb2c1cfd4b":
+            # ref: https://huggingface.co/tiiuae/Falcon-H1-34B-Base
+            res = "falcon-h1"

         if res is None:
             logger.warning("\n")
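Each chkhsh above fingerprints how a model's tokenizer encodes a fixed probe string, so models that share a vocabulary but pre-tokenize differently map to different res values. A minimal sketch of that idea, assuming a Hugging Face tokenizer; the helper name and probe text are illustrative, not the converter's exact code:

# Sketch only: hash the token ids produced for a fixed probe string.
from hashlib import sha256
from transformers import AutoTokenizer

def tokenizer_fingerprint(model_dir: str, probe_text: str = "Hello, y'all! 11 22 333") -> str:
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    token_ids = tokenizer.encode(probe_text)  # pre-tokenizer differences change these ids
    return sha256(str(token_ids).encode()).hexdigest()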
@@ -4902,17 +4914,19 @@ def set_vocab(self):
     def set_gguf_parameters(self):
         d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
         d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
-        d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
+        d_inner = self.find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or 2 * d_model
         d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 128
-        head_dim = self.find_hparam(["head_dim"], optional=True) or 64
+        head_dim = self.find_hparam(["mamba_d_head", "head_dim"], optional=True) or 64
         n_group = self.find_hparam(["n_groups"], optional=True) or 1

         rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5

         # Fail early for models which don't have a block expansion factor of 2
         # TODO: does this really matter?
-        assert d_inner == 2 * d_model
-        assert d_inner % head_dim == 0
+        # skip the assertion for FalconH1 Model
+        if self.model_arch != gguf.MODEL_ARCH.FALCON_H1:
+            assert d_inner == 2 * d_model
+            assert d_inner % head_dim == 0

         self.gguf_writer.add_context_length(2 ** 20)  # arbitrary value; for those who use the default
         self.gguf_writer.add_embedding_length(d_model)
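The substantive change in this hunk is the lookup order: mamba_d_ssm and mamba_d_head now take precedence over the older keys, with the `or` defaults as a last resort. A hedged sketch of that first-match-wins behaviour, using a made-up hparams dict rather than the class's real find_hparam:

# Illustration only: the first key present wins, otherwise the `or` default applies.
hparams = {"hidden_size": 2048, "mamba_d_ssm": 1024}  # made-up values

def find_hparam(keys, optional=False):
    for k in keys:
        if k in hparams:
            return hparams[k]
    if optional:
        return None
    raise KeyError(f"could not find any of: {keys}")

d_model = find_hparam(["hidden_size", "d_model", "dim"])                                              # 2048
d_inner = find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or 2 * d_model  # 1024
head_dim = find_hparam(["mamba_d_head", "head_dim"], optional=True) or 64                             # 64 (default)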
@@ -4949,7 +4963,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
             data_torch = data_torch.reshape((*data_torch.shape, 1))
         elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_NORM, bid):
             d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
-            d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
+            d_inner = self.find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or 2 * d_model
             n_group = self.hparams.get("n_groups", 1)
             data_torch = data_torch.reshape((n_group, d_inner // n_group))

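For context, the SSM_NORM branch only reshapes the flat norm weight into one row per group; a tiny illustration with toy sizes:

# Toy sizes, illustration only.
import torch

d_inner, n_group = 8, 2
w = torch.arange(d_inner, dtype=torch.float32)        # flat SSM norm weight
w_grouped = w.reshape((n_group, d_inner // n_group))  # shape (2, 4): one row per group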
@@ -6542,6 +6556,113 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"])


+@ModelBase.register("FalconH1ForCausalLM")
+class FalconH1Model(Mamba2Model):
+    model_arch = gguf.MODEL_ARCH.FALCON_H1
+
+    def __init__(self, *args, **kwargs):
+        # Set the hparam prefixes for Falcon Mamba2
+        self.hparam_prefixes = ["mamba"]
+
+        # Initialize the base Mamba2Model
+        super().__init__(*args, **kwargs)
+
+        # Use Llama conversion for attention
+        self._transformer_model_class = LlamaModel
+
+        # n_group and d_inner are used during reshape_tensors for mamba2
+        self.n_group = self.find_hparam(["n_groups"])
+        self.d_inner = self.find_hparam(["mamba_d_ssm"])
+        self.d_head = self.find_hparam(["d_head"])
+
+        # Initialize any Falcon Mamba2 specific attributes
+        self.has_attention = True  # Falcon Mamba2 has attention components
+
+        # Load Falcon-H1 multipliers from hyperparameters
+        self.attention_in_multiplier = self.find_hparam(["attention_in_multiplier"], optional=True)
+        self.attention_out_multiplier = self.find_hparam(["attention_out_multiplier"], optional=True)
+        self.ssm_in_multiplier = self.find_hparam(["ssm_in_multiplier"], optional=True)
+        self.ssm_out_multiplier = self.find_hparam(["ssm_out_multiplier"], optional=True)
+        self.mlp_multipliers = self.find_hparam(["mlp_multipliers"], optional=True)
+        self.ssm_multipliers = self.find_hparam(["ssm_multipliers"], optional=True)
+        self.intermediate_size = self.find_hparam(["intermediate_size"])
+        self.key_multiplier = self.find_hparam(["key_multiplier"], optional=True)
+
+    def find_hparam(self, keys: Iterable[str], *args, **kwargs) -> Any:
+        prefixed = []
+        for pfx in self.hparam_prefixes:
+            prefixed.extend(
+                "_".join([pfx, k])
+                for k in keys
+            )
+        keys = list(keys) + prefixed
+        return super().find_hparam(keys, *args, **kwargs)
+
+    def set_vocab(self):
+        self._set_vocab_gpt2()
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        tensors = list(super().modify_tensors(data_torch, name, bid))
+        tensor = tensors[0][1]
+
+        if "down_proj" in name:
+            tensor = tensor * self.mlp_multipliers[1]
+        elif "gate_proj" in name:
+            tensor = tensor * self.mlp_multipliers[0]
+        elif "k_proj" in name:
+            tensor = tensor * self.key_multiplier * self.attention_in_multiplier
+        elif "q_proj" in name:
+            tensor = tensor * self.attention_in_multiplier
+        elif "v_proj" in name:
+            tensor = tensor * self.attention_in_multiplier
+        elif "o_proj" in name:
+            tensor = tensor * self.attention_out_multiplier
+        elif "out_proj" in name:
+            tensor = tensor * self.ssm_out_multiplier
+        elif "in_proj" in name:
+            tensor = tensor * self.ssm_in_multiplier
+            zxbcdt_multipliers = self.hparams["ssm_multipliers"]
+            intermediate_size = self.hparams["mamba_d_ssm"]
+            groups_time_state_size = self.hparams["mamba_n_groups"] * self.hparams["mamba_d_state"]
+            tensor[:intermediate_size, :] *= zxbcdt_multipliers[0]
+            tensor[intermediate_size:2 * intermediate_size, :] *= zxbcdt_multipliers[1]
+            tensor[2 * intermediate_size:2 * intermediate_size + groups_time_state_size, :] *= zxbcdt_multipliers[2]
+            tensor[2 * intermediate_size + groups_time_state_size:2 * intermediate_size + 2 * groups_time_state_size, :] *= zxbcdt_multipliers[3]
+            tensor[2 * intermediate_size + 2 * groups_time_state_size:, :] *= zxbcdt_multipliers[4]
+        elif "lm_head" in name:
+            tensor = tensor * self.hparams["lm_head_multiplier"]
+        elif "embed_tokens" in name:
+            tensor = tensor * self.hparams["embedding_multiplier"]
+        elif "mamba.norm" in name:
+            tensor = tensor.reshape(self.n_group, self.d_inner // self.n_group)
+
+        tensors = [(tensors[0][0], tensor)]
+        return tensors
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+
+        ## General Params ##
+        self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
+        # Override some Mamba2 defaults
+        self.gguf_writer.add_block_count(self.block_count)
+        self.gguf_writer.add_context_length(self.hparams.get("max_position_embeddings", 0))
+        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
+
+        ## Attention params ##
+        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])  # Override value 0 from Mamba2
+        self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"])
+        self.gguf_writer.add_key_length(self.hparams["head_dim"])
+        self.gguf_writer.add_value_length(self.hparams["head_dim"])
+
+        ## Validation ##
+        assert self.hparams.get("hidden_act") in [None, "silu"], "Only SILU activation supported"
+        assert self.d_inner % self.d_head == 0, f"SSM inner size {self.d_inner} not a multiple of head dim {self.d_head}"
+
+        # Add any other Falcon Mamba2 specific configuration
+        self.gguf_writer.add_rope_freq_base(self.find_hparam(["rope_theta"]))
+
+
 @ModelBase.register("HunYuanMoEV1ForCausalLM")
 class HunYuanMoEModel(TextModel):
     model_arch = gguf.MODEL_ARCH.HUNYUAN_MOE
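The in_proj branch of FalconH1Model.modify_tensors relies on Mamba2's packed input projection, whose output rows stack [z | x | B | C | dt]; the five ssm_multipliers scale those segments in that order. A sketch of that layout with toy hyperparameters (all sizes made up):

# Illustration of the packed row layout the in_proj slicing assumes (toy sizes).
import torch

d_ssm, n_groups, d_state, n_heads = 4, 2, 3, 2
gts = n_groups * d_state                                  # groups_time_state_size
rows = 2 * d_ssm + 2 * gts + n_heads                      # z + x + B + C + dt
in_proj = torch.ones(rows, 8)

mult = [0.5, 1.0, 2.0, 3.0, 4.0]                          # stand-in ssm_multipliers
in_proj[:d_ssm] *= mult[0]                                # z
in_proj[d_ssm:2 * d_ssm] *= mult[1]                       # x
in_proj[2 * d_ssm:2 * d_ssm + gts] *= mult[2]             # B
in_proj[2 * d_ssm + gts:2 * d_ssm + 2 * gts] *= mult[3]   # C
in_proj[2 * d_ssm + 2 * gts:] *= mult[4]                  # dt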
@@ -6695,6 +6816,16 @@ def prepare_tensors(self):
 class SmolLM3Model(LlamaModel):
     model_arch = gguf.MODEL_ARCH.SMOLLM3

+    def set_vocab(self):
+        super().set_vocab()
+        # remove unsupported array slicing in chat template
+        # ref: https://huggingface.co/ggml-org/SmolLM3-3B-GGUF/discussions/1
+        from transformers import AutoTokenizer
+        tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
+        if tokenizer.chat_template is not None:
+            chat_template = tokenizer.chat_template.replace("[:]", "")
+            self.gguf_writer.add_chat_template(chat_template)
+
 ###### CONVERSION LOGIC ######

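The SmolLM3 override simply strips Jinja array slicing ([:]), which the linked GGUF discussion reports as unsupported by the template engine; its effect looks roughly like this (the snippet is illustrative, not the real SmolLM3 template):

template = "{%- for message in messages[:] %}{{ message.content }}{%- endfor %}"
print(template.replace("[:]", ""))
# {%- for message in messages %}{{ message.content }}{%- endfor %}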