@@ -818,6 +818,21 @@ def get_vocab_base_pre(self, tokenizer) -> str:
818818 if chkhsh == "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664" :
819819 # ref: https://huggingface.co/tencent/Hunyuan-A13B-Instruct
820820 res = "hunyuan"
821+ if chkhsh == "b0a6b1c0bd5998ebd9df08611efde34a4ff03faed45ae09c43e6b31ebd4b94cf" :
822+ # ref: https://huggingface.co/skt/A.X-4.0
823+ res = "a.x-4.0"
824+ if chkhsh == "a6b57017d60e6edb4d88ecc2845188e0eb333a70357e45dcc9b53964a73bbae6" :
825+ # ref: https://huggingface.co/tiiuae/Falcon-H1-0.5B-Base
826+ res = "falcon-h1"
827+ if chkhsh == "60476e1243776c4fb1b993dbd7a5f15ac22f83c80afdf425fa5ae01c8d44ef86" :
828+ # ref: https://huggingface.co/tiiuae/Falcon-H1-1B-Base
829+ res = "falcon-h1"
830+ if chkhsh == "3eda48b4c4dc7de733d1a8b3e3b4a85243dbbf704da2ee9d42c6beced8897896" :
831+ # ref: https://huggingface.co/tiiuae/Falcon-H1-7B-Base
832+ res = "falcon-h1"
833+ if chkhsh == "48f8e02c0359c0bbdd82f26909171fac1c18a457bb47573ed1fe3bbb2c1cfd4b" :
834+ # ref: https://huggingface.co/tiiuae/Falcon-H1-34B-Base
835+ res = "falcon-h1"
821836
822837 if res is None :
823838 logger .warning ("\n " )
@@ -4899,17 +4914,19 @@ def set_vocab(self):
48994914 def set_gguf_parameters (self ):
49004915 d_model = self .find_hparam (["hidden_size" , "d_model" , "dim" ])
49014916 d_conv = self .find_hparam (["conv_kernel" , "d_conv" ], optional = True ) or 4
4902- d_inner = self .find_hparam (["intermediate_size" , "d_inner" ], optional = True ) or 2 * d_model
4917+ d_inner = self .find_hparam (["mamba_d_ssm" , " intermediate_size" , "d_inner" ], optional = True ) or 2 * d_model
49034918 d_state = self .find_hparam (["state_size" , "d_state" ], optional = True ) or 128
4904- head_dim = self .find_hparam (["head_dim" ], optional = True ) or 64
4919+ head_dim = self .find_hparam (["mamba_d_head" , " head_dim" ], optional = True ) or 64
49054920 n_group = self .find_hparam (["n_groups" ], optional = True ) or 1
49064921
49074922 rms_norm_eps = self .find_hparam (["layer_norm_epsilon" , "rms_norm_eps" ], optional = True ) or 1e-5
49084923
49094924 # Fail early for models which don't have a block expansion factor of 2
49104925 # TODO: does this really matter?
4911- assert d_inner == 2 * d_model
4912- assert d_inner % head_dim == 0
4926+ # skip the assertion for FalconH1 Model
4927+ if self .model_arch != gguf .MODEL_ARCH .FALCON_H1 :
4928+ assert d_inner == 2 * d_model
4929+ assert d_inner % head_dim == 0
49134930
49144931 self .gguf_writer .add_context_length (2 ** 20 ) # arbitrary value; for those who use the default
49154932 self .gguf_writer .add_embedding_length (d_model )
@@ -4946,7 +4963,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
49464963 data_torch = data_torch .reshape ((* data_torch .shape , 1 ))
49474964 elif self .match_model_tensor_name (new_name , gguf .MODEL_TENSOR .SSM_NORM , bid ):
49484965 d_model = self .find_hparam (["hidden_size" , "d_model" , "dim" ])
4949- d_inner = self .find_hparam (["intermediate_size" , "d_inner" ], optional = True ) or 2 * d_model
4966+ d_inner = self .find_hparam (["mamba_d_ssm" , " intermediate_size" , "d_inner" ], optional = True ) or 2 * d_model
49504967 n_group = self .hparams .get ("n_groups" , 1 )
49514968 data_torch = data_torch .reshape ((n_group , d_inner // n_group ))
49524969
@@ -6656,6 +6673,113 @@ def set_gguf_parameters(self):
66566673 self .gguf_writer .add_audio_stack_factor (self .global_config ["stack_factor" ])
66576674
66586675
6676+ @ModelBase .register ("FalconH1ForCausalLM" )
6677+ class FalconH1Model (Mamba2Model ):
6678+ model_arch = gguf .MODEL_ARCH .FALCON_H1
6679+
6680+ def __init__ (self , * args , ** kwargs ):
6681+ # Set the hparam prefixes for Falcon Mamba2
6682+ self .hparam_prefixes = ["mamba" ]
6683+
6684+ # Initialize the base Mamba2Model
6685+ super ().__init__ (* args , ** kwargs )
6686+
6687+ # Use Llama conversion for attention
6688+ self ._transformer_model_class = LlamaModel
6689+
6690+ # n_group and d_inner are used during reshape_tensors for mamaba2
6691+ self .n_group = self .find_hparam (["n_groups" ])
6692+ self .d_inner = self .find_hparam (["mamba_d_ssm" ])
6693+ self .d_head = self .find_hparam (["d_head" ])
6694+
6695+ # Initialize any Falcon Mamba2 specific attributes
6696+ self .has_attention = True # Falcon Mamba2 has attention components
6697+
6698+ # Load Falcon-H1 multipliers from hyperparameters
6699+ self .attention_in_multiplier = self .find_hparam (["attention_in_multiplier" ], optional = True )
6700+ self .attention_out_multiplier = self .find_hparam (["attention_out_multiplier" ], optional = True )
6701+ self .ssm_in_multiplier = self .find_hparam (["ssm_in_multiplier" ], optional = True )
6702+ self .ssm_out_multiplier = self .find_hparam (["ssm_out_multiplier" ], optional = True )
6703+ self .mlp_multipliers = self .find_hparam (["mlp_multipliers" ], optional = True )
6704+ self .ssm_multipliers = self .find_hparam (["ssm_multipliers" ], optional = True )
6705+ self .intermediate_size = self .find_hparam (["intermediate_size" ])
6706+ self .key_multiplier = self .find_hparam (["key_multiplier" ], optional = True )
6707+
6708+ def find_hparam (self , keys : Iterable [str ], * args , ** kwargs ) -> Any :
6709+ prefixed = []
6710+ for pfx in self .hparam_prefixes :
6711+ prefixed .extend (
6712+ "_" .join ([pfx , k ])
6713+ for k in keys
6714+ )
6715+ keys = list (keys ) + prefixed
6716+ return super ().find_hparam (keys , * args , ** kwargs )
6717+
6718+ def set_vocab (self ):
6719+ self ._set_vocab_gpt2 ()
6720+
6721+ def modify_tensors (self , data_torch : Tensor , name : str , bid : int | None ) -> Iterable [tuple [str , Tensor ]]:
6722+ tensors = list (super ().modify_tensors (data_torch , name , bid ))
6723+ tensor = tensors [0 ][1 ]
6724+
6725+ if "down_proj" in name :
6726+ tensor = tensor * self .mlp_multipliers [1 ]
6727+ elif "gate_proj" in name :
6728+ tensor = tensor * self .mlp_multipliers [0 ]
6729+ elif "k_proj" in name :
6730+ tensor = tensor * self .key_multiplier * self .attention_in_multiplier
6731+ elif "q_proj" in name :
6732+ tensor = tensor * self .attention_in_multiplier
6733+ elif "v_proj" in name :
6734+ tensor = tensor * self .attention_in_multiplier
6735+ elif "o_proj" in name :
6736+ tensor = tensor * self .attention_out_multiplier
6737+ elif "out_proj" in name :
6738+ tensor = tensor * self .ssm_out_multiplier
6739+ elif "in_proj" in name :
6740+ tensor = tensor * self .ssm_in_multiplier
6741+ zxbcdt_multipliers = self .hparams ["ssm_multipliers" ]
6742+ intermediate_size = self .hparams ["mamba_d_ssm" ]
6743+ groups_time_state_size = self .hparams ["mamba_n_groups" ] * self .hparams ["mamba_d_state" ]
6744+ tensor [:intermediate_size , :] *= zxbcdt_multipliers [0 ]
6745+ tensor [intermediate_size :2 * intermediate_size , :] *= zxbcdt_multipliers [1 ]
6746+ tensor [2 * intermediate_size :2 * intermediate_size + groups_time_state_size , :] *= zxbcdt_multipliers [2 ]
6747+ tensor [2 * intermediate_size + groups_time_state_size :2 * intermediate_size + 2 * groups_time_state_size , :] *= zxbcdt_multipliers [3 ]
6748+ tensor [2 * intermediate_size + 2 * groups_time_state_size :, :] *= zxbcdt_multipliers [4 ]
6749+ elif "lm_head" in name :
6750+ tensor = tensor * self .hparams ["lm_head_multiplier" ]
6751+ elif "embed_tokens" in name :
6752+ tensor = tensor * self .hparams ["embedding_multiplier" ]
6753+ elif "mamba.norm" in name :
6754+ tensor = tensor .reshape (self .n_group , self .d_inner // self .n_group )
6755+
6756+ tensors = [(tensors [0 ][0 ], tensor )]
6757+ return tensors
6758+
6759+ def set_gguf_parameters (self ):
6760+ super ().set_gguf_parameters ()
6761+
6762+ ## General Params ##
6763+ self .gguf_writer .add_vocab_size (self .hparams ["vocab_size" ])
6764+ # Override some Mamba2 defaults
6765+ self .gguf_writer .add_block_count (self .block_count )
6766+ self .gguf_writer .add_context_length (self .hparams .get ("max_position_embeddings" , 0 ))
6767+ self .gguf_writer .add_feed_forward_length (self .hparams ["intermediate_size" ])
6768+
6769+ ## Attention params ##
6770+ self .gguf_writer .add_head_count (self .hparams ["num_attention_heads" ]) # Override value 0 from Mamba2
6771+ self .gguf_writer .add_head_count_kv (self .hparams ["num_key_value_heads" ])
6772+ self .gguf_writer .add_key_length (self .hparams ["head_dim" ])
6773+ self .gguf_writer .add_value_length (self .hparams ["head_dim" ])
6774+
6775+ ## Validation ##
6776+ assert self .hparams .get ("hidden_act" ) in [None , "silu" ], "Only SILU activation supported"
6777+ assert self .d_inner % self .d_head == 0 , f"SSM inner size { self .d_inner } not a multiple of head dim { self .d_head } "
6778+
6779+ # Add any other Falcon Mamba2 specific configuration
6780+ self .gguf_writer .add_rope_freq_base (self .find_hparam (["rope_theta" ]))
6781+
6782+
66596783@ModelBase .register ("HunYuanMoEV1ForCausalLM" )
66606784class HunYuanMoEModel (TextModel ):
66616785 model_arch = gguf .MODEL_ARCH .HUNYUAN_MOE
@@ -6809,6 +6933,16 @@ def prepare_tensors(self):
68096933class SmolLM3Model (LlamaModel ):
68106934 model_arch = gguf .MODEL_ARCH .SMOLLM3
68116935
6936+ def set_vocab (self ):
6937+ super ().set_vocab ()
6938+ # remove unsupported array slicing in chat template
6939+ # ref: https://huggingface.co/ggml-org/SmolLM3-3B-GGUF/discussions/1
6940+ from transformers import AutoTokenizer
6941+ tokenizer = AutoTokenizer .from_pretrained (self .dir_model )
6942+ if tokenizer .chat_template is not None :
6943+ chat_template = tokenizer .chat_template .replace ("[:]" , "" )
6944+ self .gguf_writer .add_chat_template (chat_template )
6945+
68126946###### CONVERSION LOGIC ######
68136947
68146948
0 commit comments