@@ -735,6 +735,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "d4540891389ea895b53b399da6ac824becc30f2fba0e9ddbb98f92e55ca0e97c":
             # ref: https://huggingface.co/Qwen/Qwen3-Embedding-0.6B
             res = "qwen2"
+        if chkhsh == "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273":
+            # ref: https://huggingface.co/alvarobartt/grok-2-tokenizer
+            res = "grok-2"
         if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5":
             # ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B
             res = "llama-bpe"
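For context, each `chkhsh` above fingerprints a model's pre-tokenizer: the converter encodes a fixed probe string and hashes the resulting token ids, then matches the digest against this table to pick `res`. A minimal sketch of the idea, assuming a Hugging Face fast tokenizer; the probe text below is a stand-in for illustration, not the converter's actual `chktxt`:

import torch  # not needed here; shown only if you extend the sketch
from hashlib import sha256
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("alvarobartt/grok-2-tokenizer")
chktxt = "Hello world 123 🦙"  # stand-in probe text, not the real chktxt
chkhsh = sha256(str(tokenizer.encode(chktxt)).encode()).hexdigest()
print(chkhsh)  # compared against the registered hashes above to select `res`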
@@ -2682,57 +2685,109 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         yield (new_name, data_torch)
 
 
-@ModelBase.register("GrokForCausalLM")
+@ModelBase.register("GrokForCausalLM", "Grok1ForCausalLM")
 class GrokModel(TextModel):
     model_arch = gguf.MODEL_ARCH.GROK
 
     def set_vocab(self):
-        self._set_vocab_sentencepiece()
+        if (self.dir_model / 'tokenizer.model').is_file():
+            self._set_vocab_sentencepiece()
+            return
+
+        if not (self.dir_model / 'tokenizer.json').is_file() or not (self.dir_model / 'chat_template.jinja').is_file():
+            logger.error('Error: Missing vocab and chat template, download files from https://huggingface.co/alvarobartt/grok-2-tokenizer')
+            sys.exit(1)
+
+        self._set_vocab_gpt2()
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
 
-    _experts: list[dict[str, Tensor]] | None = None
+        self.gguf_writer.add_attn_logit_softcapping(self.hparams.get("attn_logit_softcapping", 30.0))
+        self.gguf_writer.add_router_logit_softcapping(self.hparams.get("router_logit_softcapping", 30.0))
+        if (final_logit_softcap := self.hparams.get("final_logit_softcapping")):
+            self.gguf_writer.add_final_logit_softcapping(final_logit_softcap)
+
+        if (rope_dim := self.hparams.get("head_dim")) is None:
+            rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
+
+        if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
+            self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
+
+        # Treat "original" as "yarn", seems to have been a mistake
+        if self.hparams.get("rope_type") in ("yarn", "original"):
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+            self.gguf_writer.add_rope_scaling_factor(self.hparams["scaling_factor"])
+            self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["original_max_position_embeddings"])
+            self.gguf_writer.add_rope_scaling_yarn_ext_factor(self.hparams["extrapolation_factor"])
+            self.gguf_writer.add_rope_scaling_yarn_attn_factor(self.hparams["attn_factor"])
+            self.gguf_writer.add_rope_scaling_yarn_beta_fast(self.hparams["beta_fast"])
+            self.gguf_writer.add_rope_scaling_yarn_beta_slow(self.hparams["beta_slow"])
+
+        if temp_len := self.hparams.get("attn_temperature_len"):
+            self.gguf_writer.add_attn_temperature_length(temp_len)
+
+        self.gguf_writer.add_attn_output_scale(self.hparams.get("attn_output_multiplier", rope_dim ** -0.5))
+        self.gguf_writer.add_embedding_scale(self.hparams["embedding_multiplier_scale"])
+        self.gguf_writer.add_logit_scale(self.hparams["output_multiplier_scale"])
+
+    _experts: list[dict[str, list[Tensor]]] | None = None
+    _cur_expert = ""
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        tensors: list[tuple[str, Tensor]] = []
+        is_expert = ".moe." in name or ".block_sparse_moe.experts." in name
+
+        if not is_expert:
+            tensors.append((self.map_tensor_name(name), data_torch))
+
         # process the experts separately
-        if name.find(".moe.") != -1:
+        if is_expert or self._cur_expert:
             n_experts = self.hparams["num_local_experts"]
 
             assert bid is not None
 
             if self._experts is None:
                 self._experts = [{} for _ in range(self.block_count)]
 
-            self._experts[bid][name] = data_torch
+            # concatenate split tensors
+            if name in self._experts[bid]:
+                self._cur_expert = name
+                self._experts[bid][name].append(data_torch)
+                return []
+            elif is_expert:
+                self._cur_expert = name
+                self._experts[bid][name] = [data_torch]
+                return []
+            else:
+                self._cur_expert = ""
 
-            if len(self._experts[bid]) >= n_experts * 3:
-                tensors: list[tuple[str, Tensor]] = []
+            for bid in range(self.block_count):
+                if len(self._experts[bid]) >= n_experts * 3:
+                    # merge the experts into a single 3d tensor
+                    for wid in [("linear", "w1", 0), ("linear_1", "w2", 1), ("linear_v", "w3", 0)]:
+                        datas: list[Tensor] = []
 
-                # merge the experts into a single 3d tensor
-                for wid in ["linear", "linear_1", "linear_v"]:
-                    datas: list[Tensor] = []
+                        for xid in range(n_experts):
+                            ename = f"transformer.decoder_layer.{bid}.moe.{xid}.{wid[0]}.weight"
+                            if ename not in self._experts[bid]:
+                                ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{wid[1]}.weight"
+                            tensor_list = self._experts[bid][ename]
+                            datas.append(torch.cat(tensor_list, dim=wid[2]) if len(tensor_list) > 1 else tensor_list[0])
+                            del self._experts[bid][ename]
 
-                    for xid in range(n_experts):
-                        ename = f"transformer.decoder_layer.{bid}.moe.{xid}.{wid}.weight"
-                        datas.append(self._experts[bid][ename])
-                        del self._experts[bid][ename]
+                        data_torch = torch.stack(datas, dim=0)
 
-                data_torch = torch.stack(datas, dim=0)
+                        merged_name = f"transformer.decoder_layer.{bid}.moe.{wid[0]}.weight"
 
-                merged_name = f"transformer.decoder_layer.{bid}.moe.{wid}.weight"
-
-                new_name = self.map_tensor_name(merged_name)
+                        new_name = self.map_tensor_name(merged_name)
 
-                tensors.append((new_name, data_torch))
-                return tensors
-            else:
-                return []
+                        yield (new_name, data_torch)
 
-            return [(self.map_tensor_name(name), data_torch)]
+        yield from tensors
 
 
 @ModelBase.register("DbrxForCausalLM")
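For readers less familiar with the MoE plumbing, here is an illustrative sketch (names and shapes invented, not part of the PR) of the two transformations this hunk performs: expert weights that arrive as shards are concatenated back together along the dim recorded in each `wid` tuple (0 for w1/w3, 1 for w2), then all experts of a layer are stacked into one 3D tensor. The softcapping metadata written in `set_gguf_parameters` corresponds, as I understand it, to the usual `cap * tanh(x / cap)` squashing.

import torch

n_experts, n_ff, n_embd, n_shards = 4, 8, 6, 2

# simulate sharded w1-style expert weights: each shard is (n_ff // n_shards, n_embd)
experts = [[torch.randn(n_ff // n_shards, n_embd) for _ in range(n_shards)]
           for _ in range(n_experts)]

cat_dim = 0  # w1/w3 ("linear"/"linear_v") rejoin along dim 0; w2 ("linear_1") along dim 1
datas = [torch.cat(shards, dim=cat_dim) if len(shards) > 1 else shards[0]
         for shards in experts]

merged = torch.stack(datas, dim=0)  # one 3D tensor indexed by expert id
assert merged.shape == (n_experts, n_ff, n_embd)

def softcap(x: torch.Tensor, cap: float = 30.0) -> torch.Tensor:
    # logit softcapping: squashes values smoothly into (-cap, cap)
    return cap * torch.tanh(x / cap)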