@@ -2635,19 +2635,82 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
26352635 yield (new_name , data_torch )
26362636
26372637
@ModelBase.register("GrokForCausalLM", "Grok1ForCausalLM")
class GrokModel(TextModel):
    """Converter for Grok checkpoints (registered under both HF architecture names)."""

    model_arch = gguf.MODEL_ARCH.GROK
26412641
26422642 def set_vocab (self ):
2643- self ._set_vocab_sentencepiece ()
2643+ if (self .dir_model / 'tokenizer.model' ).is_file ():
2644+ self ._set_vocab_sentencepiece ()
2645+ return
2646+
2647+ tokenizer_path = self .dir_model / 'tokenizer.tok.json'
2648+ with open (tokenizer_path , "r" , encoding = "utf-8" ) as f :
2649+ tokenizer = json .load (f )
2650+
2651+ vocab_size = tokenizer ["vocab_size" ]
2652+ tokens : list [bytes ] = [f"[PAD{ i } ]" .encode ("utf-8" ) for i in range (vocab_size )]
2653+ scores : list [float ] = [- 10000.0 ] * vocab_size
2654+ toktypes : list [int ] = [gguf .TokenType .UNUSED ] * vocab_size
2655+
2656+ def decode_grok_token (token : dict , toktype : gguf .TokenType ) -> tuple [gguf .TokenType , int , str ]:
2657+ tokid = token ["token" ]
2658+ tokb = token ["bytes" ]
2659+ try :
2660+ tokc = bytes (tokb ).decode ("utf-8" )
2661+ except :
2662+ tokc = None
2663+ if len (tokb ) == 1 or not tokc :
2664+ return gguf .TokenType .BYTE , tokid , "<0x{:02X}>" .format (tokb [0 ])
2665+ else :
2666+ return toktype , tokid , tokc
2667+
2668+ for token in tokenizer ["special_tokens" ]:
2669+ toktype , tokid , tokc = decode_grok_token (token , gguf .TokenType .CONTROL )
2670+ tokens [tokid ] = tokc
2671+ toktypes [tokid ] = toktype
2672+ scores [tokid ] = 0.0
2673+
2674+ score = - 0.0
2675+ for token in tokenizer ["regular_tokens" ]:
2676+ toktype , tokid , tokc = decode_grok_token (token , gguf .TokenType .NORMAL )
2677+ tokens [tokid ] = tokc
2678+ toktypes [tokid ] = toktype
2679+ scores [tokid ] = score
2680+ score -= 1.0
2681+
2682+ self .gguf_writer .add_tokenizer_model ("llama" )
2683+ self .gguf_writer .add_tokenizer_pre ("default" )
2684+ self .gguf_writer .add_token_list (tokens )
2685+ self .gguf_writer .add_token_scores (scores )
2686+ self .gguf_writer .add_token_types (toktypes )
2687+
2688+ self .gguf_writer .add_add_bos_token (False )
2689+
2690+ special_vocab = gguf .SpecialVocab (self .dir_model , n_vocab = len (tokens ))
2691+ special_vocab .special_token_ids ["pad" ] = 0
2692+ special_vocab .special_token_ids ["sep" ] = 1
2693+ special_vocab .special_token_ids ["eos" ] = 2
2694+ special_vocab .add_to_gguf (self .gguf_writer )
26442695
26452696 def __init__ (self , * args , ** kwargs ):
26462697 super ().__init__ (* args , ** kwargs )
26472698
26482699 def set_gguf_parameters (self ):
26492700 super ().set_gguf_parameters ()
26502701
2702+ self .gguf_writer .add_attn_logit_softcapping (self .hparams .get ("attn_logit_softcapping" , 30.0 ))
2703+ self .gguf_writer .add_router_logit_softcapping (self .hparams .get ("router_logit_softcapping" , 30.0 ))
2704+ if (final_logit_softcap := self .hparams .get ("final_logit_softcapping" )):
2705+ self .gguf_writer .add_final_logit_softcapping (final_logit_softcap )
2706+
2707+ if (rope_dim := self .hparams .get ("head_dim" )) is None :
2708+ rope_dim = self .hparams ["hidden_size" ] // self .hparams ["num_attention_heads" ]
2709+
2710+ self .gguf_writer .add_attn_output_scale (self .hparams .get ("attn_output_multiplier" , rope_dim ** - 0.5 ))
2711+ self .gguf_writer .add_embedding_scale (self .hparams ["embedding_multiplier_scale" ])
2712+ self .gguf_writer .add_logit_scale (self .hparams ["output_multiplier_scale" ])
2713+
    # Per-layer accumulator for expert weight tensors; None until first use.
    # Presumably filled by modify_tensors while gathering all experts of a
    # layer before stacking — confirm in modify_tensors (not fully visible here).
    _experts: list[dict[str, Tensor]] | None = None
26532716 def modify_tensors (self , data_torch : Tensor , name : str , bid : int | None ) -> Iterable [tuple [str , Tensor ]]:
0 commit comments