@@ -908,6 +908,40 @@ def _set_vocab_llama_hf(self):
         special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
         special_vocab.add_to_gguf(self.gguf_writer)

+    def _set_vocab_rwkv_world(self):
+        assert (self.dir_model / "rwkv_vocab_v20230424.txt").is_file()
+        vocab_size = self.hparams.get("vocab_size", 65536)
+
+        tokens: list[bytes] = ['<s>'.encode("utf-8")]
+        toktypes: list[int] = [gguf.TokenType.CONTROL]
+
+        with open(self.dir_model / "rwkv_vocab_v20230424.txt", "r", encoding="utf-8") as f:
+            lines = f.readlines()
+            for line in lines:
+                parts = line.split(' ')
+                assert len(parts) >= 3
+                token, token_len = ast.literal_eval(' '.join(parts[1:-1])), int(parts[-1])
+                token = token.encode("utf-8") if isinstance(token, str) else token
+                assert isinstance(token, bytes)
+                assert len(token) == token_len
+                token_text: str = repr(token)[2:-1]  # "b'\xff'" -> "\xff"
+                tokens.append(token_text.encode("utf-8"))
+                toktypes.append(gguf.TokenType.NORMAL)
+        remainder = vocab_size - len(tokens)
+        assert remainder >= 0
+        for i in range(len(tokens), vocab_size):
+            tokens.append(f"[PAD{i}]".encode("utf-8"))
+            toktypes.append(gguf.TokenType.UNUSED)
+
+        self.gguf_writer.add_tokenizer_model("rwkv")
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_types(toktypes)
+        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
+        special_vocab.chat_template = "rwkv-world"
+        # hack: Add '\n\n' as the EOT token to make it chat normally
+        special_vocab._set_special_token("eot", 261)
+        special_vocab.add_to_gguf(self.gguf_writer)
+
     def _set_vocab_builtin(self, model_name: Literal["gpt-neox", "llama-spm"], vocab_size: int):
         tokenizer_path = Path(sys.path[0]) / "models" / f"ggml-vocab-{model_name}.gguf"
         logger.warning(f"Using tokenizer from '{os.path.relpath(tokenizer_path, os.getcwd())}'")
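For reference, each line of `rwkv_vocab_v20230424.txt` holds a token id, a Python literal for the token (either a `str` or a `bytes` value), and the token's length in bytes, which is why the loop above re-joins the middle fields before handing them to `ast.literal_eval`. A minimal standalone sketch of that parsing, using made-up sample lines rather than real file contents (id 261 is the `'\n\n'` token referenced by the EOT hack above):

```python
import ast

# Two made-up lines in the rwkv_vocab_v20230424.txt layout:
#   <id> <python-literal token> <byte length>
sample_lines = [
    "257 ' and' 4",     # str literal; inner spaces survive the re-join
    "261 b'\\n\\n' 2",  # bytes literal
]

for line in sample_lines:
    parts = line.split(' ')
    # the token literal can itself contain spaces, so everything between
    # the id and the trailing byte length is re-joined before evaluating
    token = ast.literal_eval(' '.join(parts[1:-1]))
    token = token.encode("utf-8") if isinstance(token, str) else token
    assert len(token) == int(parts[-1])
    # repr(b' and') == "b' and'"; [2:-1] strips the b'...' wrapper,
    # matching the token text the converter stores in the GGUF
    print(repr(token)[2:-1])
```

Storing `repr(token)[2:-1]` keeps arbitrary byte sequences representable as printable text inside the GGUF, which is what the `"b'\xff'" -> "\xff"` comment in the method is getting at.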
@@ -1713,6 +1747,25 @@ def prepare_tensors(self):
                 raise ValueError(f"Unprocessed experts: {experts}")


+@Model.register("Mistral3ForConditionalGeneration")
+class Mistral3Model(LlamaModel):
+    model_arch = gguf.MODEL_ARCH.LLAMA
+
+    # we need to merge the text_config into the root level of hparams
+    def __init__(self, *args, **kwargs):
+        hparams = Model.load_hparams(kwargs["dir_model"])
+        if "text_config" in hparams:
+            hparams = {**hparams, **hparams["text_config"]}
+        kwargs["hparams"] = hparams
+        super().__init__(*args, **kwargs)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
+        name = name.replace("language_model.", "")
+        if "multi_modal_projector" in name or "vision_tower" in name:
+            return []
+        return super().modify_tensors(data_torch, name, bid)
+
+
 @Model.register("DeciLMForCausalLM")
 class DeciModel(Model):
     model_arch = gguf.MODEL_ARCH.DECI
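Since `Mistral3ForConditionalGeneration` checkpoints nest the language model's hyperparameters under `text_config`, the constructor above flattens them to the root so the inherited `LlamaModel` logic finds keys like `hidden_size` where it expects them, while `modify_tensors` drops the vision tower. A toy illustration of the dict merge (the values are made up):

```python
# Hypothetical multimodal config: the text model's hyperparameters
# live under "text_config" rather than at the root.
hparams = {
    "model_type": "mistral3",
    "text_config": {"hidden_size": 5120, "num_hidden_layers": 40},
}

if "text_config" in hparams:
    # the right-hand operand wins on key collisions, so text_config
    # values shadow any root-level duplicates
    hparams = {**hparams, **hparams["text_config"]}

assert hparams["hidden_size"] == 5120  # now visible at the root
```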
@@ -3412,38 +3465,7 @@ class Rwkv6Model(Model):
     model_arch = gguf.MODEL_ARCH.RWKV6

     def set_vocab(self):
-        assert (self.dir_model / "rwkv_vocab_v20230424.txt").is_file()
-        vocab_size = self.hparams.get("vocab_size", 65536)
-
-        tokens: list[bytes] = ['<s>'.encode("utf-8")]
-        toktypes: list[int] = [gguf.TokenType.CONTROL]
-
-        with open(self.dir_model / "rwkv_vocab_v20230424.txt", "r", encoding="utf-8") as f:
-            lines = f.readlines()
-            for line in lines:
-                parts = line.split(' ')
-                assert len(parts) >= 3
-                token, token_len = ast.literal_eval(' '.join(parts[1:-1])), int(parts[-1])
-                token = token.encode("utf-8") if isinstance(token, str) else token
-                assert isinstance(token, bytes)
-                assert len(token) == token_len
-                token_text: str = repr(token)[2:-1]  # "b'\xff'" -> "\xff"
-                tokens.append(token_text.encode("utf-8"))
-                toktypes.append(gguf.TokenType.NORMAL)
-        remainder = vocab_size - len(tokens)
-        assert remainder >= 0
-        for i in range(len(tokens), vocab_size):
-            tokens.append(f"[PAD{i}]".encode("utf-8"))
-            toktypes.append(gguf.TokenType.UNUSED)
-
-        self.gguf_writer.add_tokenizer_model("rwkv")
-        self.gguf_writer.add_token_list(tokens)
-        self.gguf_writer.add_token_types(toktypes)
-        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False)
-        special_vocab.chat_template = "rwkv-world"
-        # hack: Add '\n\n' as the EOT token to make it chat normally
-        special_vocab._set_special_token("eot", 261)
-        special_vocab.add_to_gguf(self.gguf_writer)
+        self._set_vocab_rwkv_world()

     def set_gguf_parameters(self):
         block_count = self.hparams["num_hidden_layers"]
@@ -3565,6 +3587,168 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         yield (new_name, data)


+@Model.register("Rwkv7ForCausalLM", "RWKV7ForCausalLM")
+class Rwkv7Model(Model):
+    model_arch = gguf.MODEL_ARCH.RWKV7
+
+    def set_vocab(self):
+        self._set_vocab_rwkv_world()
+
+    def calc_lora_rank(self, hidden_size, exponent, multiplier):
+        return max(1, round(hidden_size ** exponent * multiplier / 32)) * 32
+
+    def set_gguf_parameters(self):
+        block_count = self.hparams["num_hidden_layers"]
+        try:
+            head_size = self.hparams["head_size"]
+            layer_norm_eps = self.hparams["layer_norm_epsilon"]
+        except KeyError:
+            head_size = self.hparams["head_dim"]
+            layer_norm_eps = self.hparams["norm_eps"]
+        hidden_size = self.hparams["hidden_size"]
+        intermediate_size = self.hparams["intermediate_size"] if self.hparams["intermediate_size"] is not None else (hidden_size * 4)
+
+        # ICLR: In-Context-Learning-Rate
+        try:
+            lora_rank_decay = self.hparams["lora_rank_decay"] if self.hparams["lora_rank_decay"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
+            lora_rank_iclr = self.hparams["lora_rank_iclr"] if self.hparams["lora_rank_iclr"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
+            lora_rank_value_residual_mix = self.hparams["lora_rank_value_residual_mix"] if self.hparams["lora_rank_value_residual_mix"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.3)
+            lora_rank_gate = self.hparams["lora_rank_gate"] if self.hparams["lora_rank_gate"] is not None else self.calc_lora_rank(hidden_size, 0.8, 0.6)
+        except KeyError:
+            lora_rank_decay = self.hparams["decay_low_rank_dim"] if self.hparams["decay_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
+            lora_rank_iclr = self.hparams["a_low_rank_dim"] if self.hparams["a_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.8)
+            lora_rank_value_residual_mix = self.hparams["v_low_rank_dim"] if self.hparams["v_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.5, 1.3)
+            lora_rank_gate = self.hparams["gate_low_rank_dim"] if self.hparams["gate_low_rank_dim"] is not None else self.calc_lora_rank(hidden_size, 0.8, 0.6)
+
+        # RWKV isn't context limited
+        self.gguf_writer.add_context_length(1048576)
+        self.gguf_writer.add_embedding_length(hidden_size)
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_layer_norm_eps(layer_norm_eps)
+        self.gguf_writer.add_wkv_head_size(head_size)
+        self.gguf_writer.add_decay_lora_rank(lora_rank_decay)
+        self.gguf_writer.add_iclr_lora_rank(lora_rank_iclr)
+        self.gguf_writer.add_value_residual_mix_lora_rank(lora_rank_value_residual_mix)
+        self.gguf_writer.add_gate_lora_rank(lora_rank_gate)
+        self.gguf_writer.add_feed_forward_length(intermediate_size)
+        self.gguf_writer.add_file_type(self.ftype)
+
+        # required by llama.cpp, unused
+        self.gguf_writer.add_head_count(0)
+
+    lerp_weights: dict[int, dict[str, Tensor]] = {}
+    lora_needs_transpose: bool = True
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # unify tensor names here to make life easier
+        name = name.replace("blocks", "layers").replace("ffn", "feed_forward")
+        name = name.replace("self_attn", "attention").replace("attn", "attention")
+        name = name.replace("time_mixer.", "")
+        # lora layer names in fla-hub's impl
+        if "_lora.lora" in name:
+            self.lora_needs_transpose = False
+            name = name.replace("_lora.lora.0.weight", "1.weight")
+            name = name.replace("_lora.lora.2.weight", "2.weight")
+            name = name.replace("_lora.lora.2.bias", "0.weight")
+
+        name = name.replace("feed_forward_norm", "ln2")
+        name = name.replace("g_norm", "ln_x")
+
+        if "attention.v" in name and "value" not in self.map_tensor_name(name) and bid == 0:
+            # some models have dummy v0/v1/v2 on first layer while others don't
+            # ignore them all since they are not used
+            return
+
+        wkv_has_gate = self.hparams.get("wkv_has_gate", True)
+        lerp_list = ["r", "w", "k", "v", "a", "g"] if wkv_has_gate else ["r", "w", "k", "v", "a"]
+
+        if bid is not None and "attention.x_" in name:
+            if "attention.x_x" in name:
+                # already concatenated
+                new_name = f"blk.{bid}.time_mix_lerp_fused.weight"
+                data = data_torch.reshape(len(lerp_list), 1, 1, -1)
+                yield (new_name, data)
+            else:
+                try:
+                    self.lerp_weights[bid][name] = data_torch
+                except KeyError:
+                    self.lerp_weights[bid] = {name: data_torch}
+                if all(f"model.layers.{bid}.attention.x_{i}" in self.lerp_weights[bid].keys() for i in lerp_list):
+                    new_name = f"blk.{bid}.time_mix_lerp_fused.weight"
+                    data = torch.stack([self.lerp_weights[bid][f"model.layers.{bid}.attention.x_{i}"] for i in lerp_list], dim=0)
+                    yield (new_name, data)
+            return
+        else:
+            data_torch = data_torch.squeeze()
+            new_name = self.map_tensor_name(name)
+
+            if not (new_name.endswith(".weight") or new_name.endswith(".bias")):
+                new_name += ".weight"
+
+            if self.lora_needs_transpose and any(
+                new_name.endswith(t) for t in [
+                    "time_mix_w1.weight", "time_mix_w2.weight",
+                    "time_mix_a1.weight", "time_mix_a2.weight",
+                    "time_mix_v1.weight", "time_mix_v2.weight",
+                    "time_mix_g1.weight", "time_mix_g2.weight",
+                ]
+            ):
+                data_torch = data_torch.transpose(0, 1)
+
+            if 'r_k' in new_name:
+                data_torch = data_torch.flatten()
+
+            if bid == 0 and "time_mix_a" in new_name:
+                # dummy v0/v1/v2 on first layer
+                # easiest way to make llama.cpp happy
+                yield (new_name.replace("time_mix_a", "time_mix_v"), data_torch)
+
+            yield (new_name, data_torch)
+
+
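`modify_tensors` above buffers the per-channel token-shift tensors (`x_r`, `x_w`, ..., `x_g`) until all of them have been seen for a layer, then emits a single fused `time_mix_lerp_fused` tensor. A standalone sketch of both fusion paths, assuming each per-channel tensor is shaped `(1, 1, hidden)` as in the HF checkpoints (the width 8 is an arbitrary toy value):

```python
import torch

hidden = 8  # arbitrary toy width, not a real model dimension
lerp_list = ["r", "w", "k", "v", "a", "g"]

# stand-ins for model.layers.0.attention.x_{r,w,k,v,a,g}
lerp_weights = {i: torch.randn(1, 1, hidden) for i in lerp_list}

# the converter's fused layout: one (6, 1, 1, hidden) tensor per layer
fused = torch.stack([lerp_weights[i] for i in lerp_list], dim=0)
assert fused.shape == (6, 1, 1, hidden)

# checkpoints that ship a pre-concatenated x_x tensor take the other
# branch: a plain reshape produces the same layout
x_x = torch.randn(6 * hidden)  # hypothetical already-fused weight
assert x_x.reshape(len(lerp_list), 1, 1, -1).shape == fused.shape
```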
+@Model.register("RwkvHybridForCausalLM")
+class ARwkv7Model(Rwkv7Model):
+    model_arch = gguf.MODEL_ARCH.ARWKV7
+
+    def set_vocab(self):
+        try:
+            self._set_vocab_sentencepiece()
+        except FileNotFoundError:
+            self._set_vocab_gpt2()
+
+    def set_gguf_parameters(self):
+        block_count = self.hparams["num_hidden_layers"]
+        hidden_size = self.hparams["hidden_size"]
+        head_size = self.hparams["head_size"]
+        rms_norm_eps = self.hparams["rms_norm_eps"]
+        intermediate_size = self.hparams["intermediate_size"]
+        wkv_has_gate = self.hparams["wkv_has_gate"]
+        assert self.hparams["wkv_version"] == 7
+
+        # ICLR: In-Context-Learning-Rate
+        lora_rank_decay = 64
+        lora_rank_iclr = 64
+        lora_rank_value_residual_mix = 32
+        lora_rank_gate = 128 if wkv_has_gate else 0
+
+        # RWKV isn't context limited
+        self.gguf_writer.add_context_length(1048576)
+        self.gguf_writer.add_embedding_length(hidden_size)
+        self.gguf_writer.add_block_count(block_count)
+        self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
+        self.gguf_writer.add_wkv_head_size(head_size)
+        self.gguf_writer.add_decay_lora_rank(lora_rank_decay)
+        self.gguf_writer.add_iclr_lora_rank(lora_rank_iclr)
+        self.gguf_writer.add_value_residual_mix_lora_rank(lora_rank_value_residual_mix)
+        self.gguf_writer.add_gate_lora_rank(lora_rank_gate)
+        self.gguf_writer.add_feed_forward_length(intermediate_size)
+        self.gguf_writer.add_file_type(self.ftype)
+        self.gguf_writer.add_token_shift_count(1)
+
+        # required by llama.cpp, unused
+        self.gguf_writer.add_head_count(0)
+
+
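When a config omits the low-rank dimensions entirely, `calc_lora_rank` in `Rwkv7Model` above derives them from the hidden size and rounds to a multiple of 32. Working the decay rank through by hand for an assumed hidden size of 1024: 1024 ** 0.5 = 32, times the 1.8 multiplier is 57.6, divided by 32 is 1.8, which rounds to 2 blocks of 32, giving rank 64. A quick check of the formula (the hidden size is an assumed example value):

```python
def calc_lora_rank(hidden_size, exponent, multiplier):
    # same formula as Rwkv7Model.calc_lora_rank above
    return max(1, round(hidden_size ** exponent * multiplier / 32)) * 32

assert calc_lora_rank(1024, 0.5, 1.8) == 64   # decay / iclr ranks
assert calc_lora_rank(1024, 0.5, 1.3) == 32   # value-residual-mix rank
assert calc_lora_rank(1024, 0.8, 0.6) == 160  # gate rank
```

The `max(1, ...)` guard keeps the result at least 32 even for very small hidden sizes.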
 @Model.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM")
 class MambaModel(Model):
     model_arch = gguf.MODEL_ARCH.MAMBA