@@ -2230,6 +2230,7 @@ def prepare_tensors(self):
     "MixtralForCausalLM",
     "VLlama3ForCausalLM",
     "LlavaForConditionalGeneration",
+    "VoxtralForConditionalGeneration",
     "LlamaModel")
 class LlamaModel(TextModel):
     model_arch = gguf.MODEL_ARCH.LLAMA
@@ -2242,6 +2243,11 @@ def __init__(self, *args, **kwargs):
         self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
 
     def set_vocab(self):
+        path_tekken_json = self.dir_model / "tekken.json"
+        path_tokenizer_json = self.dir_model / "tokenizer.json"
+        if path_tekken_json.is_file() and not path_tokenizer_json.is_file():
+            return self.set_vocab_tekken()
+
         try:
             self._set_vocab_sentencepiece()
         except FileNotFoundError:
@@ -2274,6 +2280,52 @@ def set_vocab(self):
         if self.hparams.get("vocab_size", 32000) == 49152:
             self.gguf_writer.add_add_bos_token(False)
 
+    def set_vocab_tekken(self):
+        vocab = gguf.vocab.MistralVocab(self.dir_model)
+        self.gguf_writer.add_tokenizer_model(vocab.gguf_tokenizer_model)
+
+        tokens = []
+        scores = []
+        toktypes = []
+
+        for text, score, toktype in vocab.all_tokens():
+            tokens.append(text)
+            scores.append(score)
+            toktypes.append(toktype)
+
+        assert len(tokens) == vocab.vocab_size, (
+            f"token count ({len(tokens)}) != vocab size ({vocab.vocab_size})"
+        )
+
+        if vocab.tokenizer_type == gguf.vocab.MistralTokenizerType.tekken:
+            self.gguf_writer.add_tokenizer_pre("tekken")
+            self.gguf_writer.add_token_merges(
+                vocab.extract_vocab_merges_from_model()
+            )
+
+        logger.info(
+            f"Setting bos, eos, unk and pad token IDs to {vocab.bos_id}, {vocab.eos_id}, {vocab.unk_id}, {vocab.pad_id}."
+        )
+
+        self.gguf_writer.add_bos_token_id(vocab.bos_id)
+        self.gguf_writer.add_eos_token_id(vocab.eos_id)
+        self.gguf_writer.add_unk_token_id(vocab.unk_id)
+        self.gguf_writer.add_pad_token_id(vocab.pad_id)
+
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_scores(scores)
+        self.gguf_writer.add_token_types(toktypes)
+        self.gguf_writer.add_vocab_size(vocab.vocab_size)
+
+        self.gguf_writer.add_add_bos_token(True)
+        self.gguf_writer.add_add_eos_token(False)
+
+        script_dir = Path(__file__).parent
+        template_path = script_dir / "models/templates/unsloth-mistral-Devstral-Small-2507.jinja"
+        with open(template_path, "r", encoding="utf-8") as f:
+            template = f.read()
+        self.gguf_writer.add_chat_template(template)
+
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
         hparams = self.hparams
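For reference, the new guard in `set_vocab` only takes the tekken path when a `tekken.json` is present and no `tokenizer.json` exists, so regular Hugging Face checkpoints keep going through the existing SentencePiece/BPE fallbacks. A minimal standalone sketch of that check (the helper name and example path below are illustrative, not part of the patch):

```python
from pathlib import Path

def uses_tekken_vocab(model_dir: Path) -> bool:
    # Mirrors the guard added to LlamaModel.set_vocab: only pick the Mistral
    # tekken tokenizer when no Hugging Face tokenizer.json is available.
    return (model_dir / "tekken.json").is_file() and not (model_dir / "tokenizer.json").is_file()

# Hypothetical local checkpoint directory:
print(uses_tekken_vocab(Path("models/Voxtral-Mini-3B-2507")))
```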
@@ -2301,12 +2353,13 @@ def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         n_head = self.hparams["num_attention_heads"]
         n_kv_head = self.hparams.get("num_key_value_heads")
-        is_vision_tensor = "vision_tower" in name \
+        is_multimodal_tensor = "vision_tower" in name \
             or "vision_model" in name \
+            or "audio_tower" in name \
             or "model.connector" in name \
             or "multi_modal_projector" in name
 
-        if is_vision_tensor:
+        if is_multimodal_tensor:
             return [] # skip vision tensors
         elif self.hf_arch == "LlamaModel":
             name = "model." + name
@@ -7561,9 +7614,10 @@ class WhisperEncoderModel(MmprojModel):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.hparams["hidden_size"] = self.hparams["d_model"]
-        self.hparams["intermediate_size"] = self.hparams["encoder_ffn_dim"]
-        self.hparams["num_attention_heads"] = self.hparams["encoder_attention_heads"]
+        if "hidden_size" not in self.hparams and "intermediate_size" not in self.hparams:
+            self.hparams["hidden_size"] = self.hparams["d_model"]
+            self.hparams["intermediate_size"] = self.hparams["encoder_ffn_dim"]
+            self.hparams["num_attention_heads"] = self.hparams["encoder_attention_heads"]
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
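The new guard lets `WhisperEncoderModel.__init__` handle both config layouts: classic Whisper configs that only carry `d_model` / `encoder_ffn_dim` / `encoder_attention_heads`, and audio configs that already expose the generic keys (as Voxtral's appears to). A rough sketch of the same mapping as a standalone function, with example values that are not taken from the patch:

```python
def normalize_audio_hparams(hparams: dict) -> dict:
    # Same idea as the guard above: only alias the Whisper-specific keys
    # when the generic ones are not already present.
    if "hidden_size" not in hparams and "intermediate_size" not in hparams:
        hparams["hidden_size"] = hparams["d_model"]
        hparams["intermediate_size"] = hparams["encoder_ffn_dim"]
        hparams["num_attention_heads"] = hparams["encoder_attention_heads"]
    return hparams

# Illustrative Whisper large-v3-style values:
print(normalize_audio_hparams({"d_model": 1280, "encoder_ffn_dim": 5120, "encoder_attention_heads": 20}))
```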
@@ -7602,9 +7656,21 @@ class UltravoxWhisperEncoderModel(WhisperEncoderModel):
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.ULTRAVOX)
         self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"])
 
 
+@ModelBase.register("VoxtralForConditionalGeneration")
+class VoxtralWhisperEncoderModel(WhisperEncoderModel):
+    has_vision_encoder = False  # no vision encoder
+    has_audio_encoder = True
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.VOXTRAL)
+        self.gguf_writer.add_audio_stack_factor(4)  # == intermediate_size // hidden_size
+
+
 @ModelBase.register("FalconH1ForCausalLM")
 class FalconH1Model(Mamba2Model):
     model_arch = gguf.MODEL_ARCH.FALCON_H1
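The Voxtral class hard-codes an audio stack factor of 4; per the inline comment this is the `intermediate_size // hidden_size` ratio of the Whisper-style encoder it wraps (e.g. 5120 // 1280 == 4 for a large-v3-sized encoder). A small sketch of that relationship, using illustrative numbers rather than values read from the patch:

```python
def derive_stack_factor(hparams: dict, default: int = 4) -> int:
    # Compute the ratio the comment refers to, falling back to the constant
    # the converter bakes in.
    hidden = hparams.get("hidden_size") or hparams.get("d_model")
    ffn = hparams.get("intermediate_size") or hparams.get("encoder_ffn_dim")
    return ffn // hidden if hidden and ffn else default

print(derive_stack_factor({"d_model": 1280, "encoder_ffn_dim": 5120}))  # -> 4
```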
@@ -7919,6 +7985,88 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         return [(self.map_tensor_name(name), data_torch)]
 
 
+@ModelBase.register("SmallThinkerForCausalLM")
+class SmallThinkerModel(TextModel):
+    model_arch = gguf.MODEL_ARCH.SMALLTHINKER
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        if (n_experts := self.hparams.get("num_experts", self.hparams.get("moe_num_primary_experts"))) is not None:
+            self.gguf_writer.add_expert_count(n_experts)
+        if (n_experts_used := self.hparams.get("num_experts_per_tok", self.hparams.get("moe_num_active_primary_experts"))) is not None:
+            self.gguf_writer.add_expert_used_count(n_experts_used)
+        if (moe_intermediate_size := self.hparams.get("moe_ffn_hidden_size")) is not None:
+            self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
+            self.gguf_writer.add_feed_forward_length(moe_intermediate_size)
+            logger.info(f"gguf: expert feed forward length = {moe_intermediate_size}")
+        if (self.hparams.get('moe_primary_router_apply_softmax')):
+            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
+        else:
+            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
+        # YaRN is not enabled by default
+        # To enable it, please refer to this guide: https://huggingface.co/Qwen/Qwen3-30B-A3B#processing-long-texts
+        rope_scaling = self.hparams.get("rope_scaling") or {}
+        if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+            self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
+            self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
+
+        sliding_window_layout = self.hparams.get("sliding_window_layout")
+        if sliding_window_layout:
+            for i in sliding_window_layout:
+                if i != 0:
+                    sliding_window = self.hparams.get("sliding_window_size")
+                    if sliding_window:
+                        self.gguf_writer.add_sliding_window(sliding_window)
+                    break
+
+    _experts: list[dict[str, Tensor]] | None = None
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # process the experts separately
+        if name.find("experts") != -1:
+            n_experts = self.hparams.get("num_experts", self.hparams.get("moe_num_primary_experts"))
+            assert bid is not None
+
+            if self._experts is None:
+                self._experts = [{} for _ in range(self.block_count)]
+
+            self._experts[bid][name] = data_torch
+
+            if len(self._experts[bid]) >= n_experts * 3:
+                tensors: list[tuple[str, Tensor]] = []
+
+                # merge the experts into a single 3d tensor
+                for w_name in ["down", "gate", "up"]:
+                    datas: list[Tensor] = []
+
+                    for xid in range(n_experts):
+                        ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{w_name}.weight"
+                        datas.append(self._experts[bid][ename])
+                        del self._experts[bid][ename]
+
+                    data_torch = torch.stack(datas, dim=0)
+
+                    merged_name = f"model.layers.{bid}.block_sparse_moe.experts.{w_name}.weight"
+
+                    new_name = self.map_tensor_name(merged_name)
+
+                    tensors.append((new_name, data_torch))
+                return tensors
+            else:
+                return []
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+    def prepare_tensors(self):
+        super().prepare_tensors()
+
+        if self._experts is not None:
+            # flatten `list[dict[str, Tensor]]` into `list[str]`
+            experts = [k for d in self._experts for k in d.keys()]
+            if len(experts) > 0:
+                raise ValueError(f"Unprocessed experts: {experts}")
+
+
 ###### CONVERSION LOGIC ######
 
 
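For context on the SmallThinker expert handling: `modify_tensors` buffers the per-expert `down`/`gate`/`up` weights for each layer and, once all `n_experts * 3` tensors of a layer have arrived, stacks them into one 3-D tensor per projection; `prepare_tensors` then fails loudly if anything was left unmerged. A toy illustration of the stacking step (shapes are made up, not taken from a real checkpoint):

```python
import torch

n_experts, n_ff, n_embd = 4, 64, 32

# One 2-D weight per expert, as they stream out of the checkpoint shards...
per_expert = [torch.randn(n_ff, n_embd) for _ in range(n_experts)]

# ...stacked the same way modify_tensors does, yielding one 3-D tensor per projection.
merged = torch.stack(per_expert, dim=0)
print(merged.shape)  # torch.Size([4, 64, 32]) -> (n_expert, n_ff, n_embd)
```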