@@ -735,6 +735,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "d4540891389ea895b53b399da6ac824becc30f2fba0e9ddbb98f92e55ca0e97c":
             # ref: https://huggingface.co/Qwen/Qwen3-Embedding-0.6B
             res = "qwen2"
+        if chkhsh == "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273":
+            # ref: https://huggingface.co/alvarobartt/grok-2-tokenizer
+            res = "grok-2"
         if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5":
             # ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B
             res = "llama-bpe"
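
For context: each chkhsh above is a fingerprint of the tokenizer's pre-tokenization behavior. get_vocab_base_pre() encodes a fixed probe string and hashes the resulting token IDs, so any BPE variant that splits text differently yields a distinct hash. A minimal sketch of the idea, using the tokenizer repo referenced above; the probe string here is illustrative, not the one the script actually uses:

    # Sketch: fingerprint a tokenizer by hashing the token IDs it produces
    # for a fixed probe text (the probe below is a stand-in).
    from hashlib import sha256
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("alvarobartt/grok-2-tokenizer")
    chktxt = "Hello, world! 123"  # illustrative only
    chkhsh = sha256(str(tokenizer.encode(chktxt)).encode()).hexdigest()
    print(chkhsh)  # compared against the known hashes above
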
@@ -2685,57 +2688,109 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
             yield (new_name, data_torch)
 
 
-@ModelBase.register("GrokForCausalLM")
+@ModelBase.register("GrokForCausalLM", "Grok1ForCausalLM")
 class GrokModel(TextModel):
     model_arch = gguf.MODEL_ARCH.GROK
 
     def set_vocab(self):
-        self._set_vocab_sentencepiece()
+        if (self.dir_model / 'tokenizer.model').is_file():
+            self._set_vocab_sentencepiece()
+            return
+
+        if not (self.dir_model / 'tokenizer.json').is_file() or not (self.dir_model / 'chat_template.jinja').is_file():
+            logger.error('Error: Missing vocab and chat template, download files from https://huggingface.co/alvarobartt/grok-2-tokenizer')
+            sys.exit(1)
+
+        self._set_vocab_gpt2()
 
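
Note the fallback order in set_vocab() above: a Grok-1 checkpoint with a SentencePiece tokenizer.model converts as before, while Grok-2 needs the BPE tokenizer.json plus chat_template.jinja from the linked repo. A quick pre-flight check mirroring that logic; the "grok-2" directory name is an assumption for illustration:

    # Pre-flight check mirroring set_vocab()'s fallback order; the model
    # directory name is hypothetical.
    from pathlib import Path

    model_dir = Path("grok-2")
    if (model_dir / "tokenizer.model").is_file():
        print("SentencePiece vocab found (Grok-1 style)")
    elif (model_dir / "tokenizer.json").is_file() and (model_dir / "chat_template.jinja").is_file():
        print("BPE vocab and chat template found (Grok-2 style)")
    else:
        print("missing files; see https://huggingface.co/alvarobartt/grok-2-tokenizer")
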
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
 
-    _experts: list[dict[str, Tensor]] | None = None
+        self.gguf_writer.add_attn_logit_softcapping(self.hparams.get("attn_logit_softcapping", 30.0))
+        self.gguf_writer.add_router_logit_softcapping(self.hparams.get("router_logit_softcapping", 30.0))
+        if (final_logit_softcap := self.hparams.get("final_logit_softcapping")):
+            self.gguf_writer.add_final_logit_softcapping(final_logit_softcap)
+
+        if (rope_dim := self.hparams.get("head_dim")) is None:
+            rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
+
+        if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
+            self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
+
+        # Treat "original" as "yarn", seems to have been a mistake
+        if self.hparams.get("rope_type") in ("yarn", "original"):
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+            self.gguf_writer.add_rope_scaling_factor(self.hparams["scaling_factor"])
+            self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["original_max_position_embeddings"])
+            self.gguf_writer.add_rope_scaling_yarn_ext_factor(self.hparams["extrapolation_factor"])
+            self.gguf_writer.add_rope_scaling_yarn_attn_factor(self.hparams["attn_factor"])
+            self.gguf_writer.add_rope_scaling_yarn_beta_fast(self.hparams["beta_fast"])
+            self.gguf_writer.add_rope_scaling_yarn_beta_slow(self.hparams["beta_slow"])
+
+        if temp_len := self.hparams.get("attn_temperature_len"):
+            self.gguf_writer.add_attn_temperature_length(temp_len)
+
+        self.gguf_writer.add_attn_output_scale(self.hparams.get("attn_output_multiplier", rope_dim ** -0.5))
+        self.gguf_writer.add_embedding_scale(self.hparams["embedding_multiplier_scale"])
+        self.gguf_writer.add_logit_scale(self.hparams["output_multiplier_scale"])
+
+    _experts: list[dict[str, list[Tensor]]] | None = None
+    _cur_expert = ""
 
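
The soft-capping values written above bound attention and router logits with a scaled tanh; 30.0 is Grok's default cap. A numeric sketch of the squashing, assuming the standard cap * tanh(x / cap) form of logit soft-capping (an illustration, not a quote of llama.cpp's actual kernel):

    # Logit soft-capping: squash x into (-cap, cap) while staying roughly
    # linear near zero. 30.0 matches the default used above.
    import math

    def softcap(x: float, cap: float = 30.0) -> float:
        return cap * math.tanh(x / cap)

    print(softcap(10.0))   # ~9.6: small logits pass almost unchanged
    print(softcap(100.0))  # ~29.9: large logits saturate at the cap
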
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        tensors: list[tuple[str, Tensor]] = []
+        is_expert = ".moe." in name or ".block_sparse_moe.experts." in name
+
+        if not is_expert:
+            tensors.append((self.map_tensor_name(name), data_torch))
+
         # process the experts separately
-        if name.find(".moe.") != -1:
+        if is_expert or self._cur_expert:
             n_experts = self.hparams["num_local_experts"]
 
             assert bid is not None
 
             if self._experts is None:
                 self._experts = [{} for _ in range(self.block_count)]
 
-            self._experts[bid][name] = data_torch
+            # concatenate split tensors
+            if name in self._experts[bid]:
+                self._cur_expert = name
+                self._experts[bid][name].append(data_torch)
+                return []
+            elif is_expert:
+                self._cur_expert = name
+                self._experts[bid][name] = [data_torch]
+                return []
+            else:
+                self._cur_expert = ""
 
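+            # At this point the current tensor is the first non-expert tensor
+            # after a run of expert tensors (the early returns above catch all
+            # expert tensors), so completed expert sets are merged and flushed.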
-            if len(self._experts[bid]) >= n_experts * 3:
-                tensors: list[tuple[str, Tensor]] = []
+            for bid in range(self.block_count):
+                if len(self._experts[bid]) >= n_experts * 3:
+                    # merge the experts into a single 3d tensor
+                    for wid in [("linear", "w1", 0), ("linear_1", "w2", 1), ("linear_v", "w3", 0)]:
+                        datas: list[Tensor] = []
 
-                # merge the experts into a single 3d tensor
-                for wid in ["linear", "linear_1", "linear_v"]:
-                    datas: list[Tensor] = []
+                        for xid in range(n_experts):
+                            ename = f"transformer.decoder_layer.{bid}.moe.{xid}.{wid[0]}.weight"
+                            if ename not in self._experts[bid]:
+                                ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{wid[1]}.weight"
+                            tensor_list = self._experts[bid][ename]
+                            datas.append(torch.cat(tensor_list, dim=wid[2]) if len(tensor_list) > 1 else tensor_list[0])
+                            del self._experts[bid][ename]
 
-                    for xid in range(n_experts):
-                        ename = f"transformer.decoder_layer.{bid}.moe.{xid}.{wid}.weight"
-                        datas.append(self._experts[bid][ename])
-                        del self._experts[bid][ename]
+                        data_torch = torch.stack(datas, dim=0)
 
-                    data_torch = torch.stack(datas, dim=0)
+                        merged_name = f"transformer.decoder_layer.{bid}.moe.{wid[0]}.weight"
 
-                    merged_name = f"transformer.decoder_layer.{bid}.moe.{wid}.weight"
-
-                    new_name = self.map_tensor_name(merged_name)
+                        new_name = self.map_tensor_name(merged_name)
 
-                    tensors.append((new_name, data_torch))
-                return tensors
-            else:
-                return []
+                        yield (new_name, data_torch)
 
-        return [(self.map_tensor_name(name), data_torch)]
+        yield from tensors
 
 
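
One more note on the expert handling: the list-of-Tensors bookkeeping exists because the same expert weight can arrive split across several files (hence the "concatenate split tensors" step). Shards are first concatenated along the split axis (wid[2] above), and the per-expert results are then stacked into one [n_expert, rows, cols] tensor. A toy reproduction of that merge, with made-up shapes:

    # Toy version of the expert merge: concatenate per-expert shards along
    # the split axis, then stack experts into a single 3D tensor. All shapes
    # here are invented for illustration.
    import torch

    n_experts, rows, cols = 2, 4, 6
    experts = []
    for _ in range(n_experts):
        shard_a = torch.randn(rows, cols // 2)  # first half of a split weight
        shard_b = torch.randn(rows, cols // 2)  # second half
        experts.append(torch.cat([shard_a, shard_b], dim=1))

    merged = torch.stack(experts, dim=0)
    print(merged.shape)  # torch.Size([2, 4, 6]) -> [n_expert, rows, cols]
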
 @ModelBase.register("DbrxForCausalLM")