@@ -6260,10 +6260,6 @@ class GraniteMoeHybridModel(BambaModel, GraniteMoeModel):
     SSM layers"""
     model_arch = gguf.MODEL_ARCH.GRANITE_MOE_HYBRID
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self._transformer_model_class = GraniteMoeModel
-
     def get_attn_layres(self):
         if layer_types := self.hparams.get("layer_types"):
             return [
@@ -6275,51 +6271,16 @@ def get_attn_layres(self):
     def modify_tensors(
         self, data_torch: Tensor, name: str, bid: int | None
     ) -> Iterable[tuple[str, Tensor]]:
-
-        # In GraniteMoeHybrid, the mamba layers also have an MoE + Shared expert
-        if name.endswith("block_sparse_moe.input_linear.weight"):
-            ffn_dim = self.hparams["intermediate_size"]
-            assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * intermediate_size"
-            gate, up = data_torch[..., :ffn_dim, :], data_torch[..., ffn_dim:, :]
-            return [
-                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_EXP, bid), gate),
-                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_EXP, bid), up),
-            ]
-        if name.endswith("shared_mlp.input_linear.weight"):
-            ffn_dim = self.hparams["shared_intermediate_size"]
-            assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * shared_intermediate_size"
-            gate, up = data_torch.split(ffn_dim, dim=-2)
-            return [
-                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_SHEXP, bid), gate),
-                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_SHEXP, bid), up),
-            ]
-
+        if (
+            name.endswith("block_sparse_moe.input_linear.weight") or
+            name.endswith("shared_mlp.input_linear.weight")
+        ):
+            return GraniteMoeModel.modify_tensors(self, data_torch, name, bid)
         return super().modify_tensors(data_torch, name, bid)
 
-
     def set_gguf_parameters(self):
-        super().set_gguf_parameters()
-        if attention_scale := self.hparams.get("attention_multiplier"):
-            self.gguf_writer.add_attention_scale(attention_scale)
-            logger.info("gguf: (granite) attention_scale = %s", attention_scale)
-        if embedding_scale := self.hparams.get("embedding_multiplier"):
-            self.gguf_writer.add_embedding_scale(embedding_scale)
-            logger.info("gguf: (granite) embedding_scale = %s", embedding_scale)
-        if residual_scale := self.hparams.get("residual_multiplier"):
-            self.gguf_writer.add_residual_scale(residual_scale)
-            logger.info("gguf: (granite) residual_scale = %s", residual_scale)
-        if logits_scale := self.hparams.get("logits_scaling"):
-            self.gguf_writer.add_logit_scale(logits_scale)
-            logger.info("gguf: (granite) logits_scale = %s", logits_scale)
-        if shared_feed_forward_length := self.hparams.get("shared_intermediate_size"):
-            self.gguf_writer.add_expert_shared_feed_forward_length(shared_feed_forward_length)
-            logger.info("gguf: (granitemoeshared) shared_feed_forward_length = %s", shared_feed_forward_length)
-        if (n_experts := self.hparams.get("num_local_experts")) is not None:
-            self.gguf_writer.add_expert_count(n_experts)
-            logger.info(f"gguf: expert count = {n_experts}")
-        if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
-            self.gguf_writer.add_expert_used_count(n_experts_used)
-            logger.info(f"gguf: experts used count = {n_experts_used}")
+        GraniteMoeModel.set_gguf_parameters(self)
+        BambaModel.set_gguf_parameters(self)
 
     def set_vocab(self):
         self.hparams["pad_vocab_size_multiple"] = 8
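The new bodies above rely on explicit delegation: instead of a single super() chain, the hybrid class invokes a chosen base directly as an unbound method with self, so the MoE-specific and SSM-specific handling can each be reused from the right parent. A minimal sketch of that pattern, using hypothetical class names rather than the real converter hierarchy:

# Sketch only: HybridModel/MoeBase/SsmBase are illustrative, not llama.cpp classes.
class MoeBase:
    def set_params(self, out: dict) -> None:
        out["expert_count"] = 8          # MoE-specific metadata

class SsmBase:
    def set_params(self, out: dict) -> None:
        out["ssm_state_size"] = 128      # SSM-specific metadata

class HybridModel(SsmBase, MoeBase):
    def set_params(self, out: dict) -> None:
        # Call each base explicitly so both sets of metadata are written,
        # mirroring how set_gguf_parameters() above calls GraniteMoeModel
        # and then BambaModel in turn.
        MoeBase.set_params(self, out)
        SsmBase.set_params(self, out)

params: dict = {}
HybridModel().set_params(params)
print(params)  # {'expert_count': 8, 'ssm_state_size': 128}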