@@ -6055,10 +6055,6 @@ class GraniteMoeHybridModel(BambaModel, GraniteMoeModel):
     SSM layers"""
     model_arch = gguf.MODEL_ARCH.GRANITE_MOE_HYBRID
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self._transformer_model_class = GraniteMoeModel
-
     def get_attn_layres(self):
         if layer_types := self.hparams.get("layer_types"):
             return [
@@ -6070,51 +6066,16 @@ def get_attn_layres(self):
     def modify_tensors(
         self, data_torch: Tensor, name: str, bid: int | None
     ) -> Iterable[tuple[str, Tensor]]:
-
-        # In GraniteMoeHybrid, the mamba layers also have an MoE + Shared expert
-        if name.endswith("block_sparse_moe.input_linear.weight"):
-            ffn_dim = self.hparams["intermediate_size"]
-            assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * intermediate_size"
-            gate, up = data_torch[..., :ffn_dim, :], data_torch[..., ffn_dim:, :]
-            return [
-                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_EXP, bid), gate),
-                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_EXP, bid), up),
-            ]
-        if name.endswith("shared_mlp.input_linear.weight"):
-            ffn_dim = self.hparams["shared_intermediate_size"]
-            assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * shared_intermediate_size"
-            gate, up = data_torch.split(ffn_dim, dim=-2)
-            return [
-                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_SHEXP, bid), gate),
-                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_SHEXP, bid), up),
-            ]
-
+        if (
+            name.endswith("block_sparse_moe.input_linear.weight") or
+            name.endswith("shared_mlp.input_linear.weight")
+        ):
+            return GraniteMoeModel.modify_tensors(self, data_torch, name, bid)
         return super().modify_tensors(data_torch, name, bid)
 
-
     def set_gguf_parameters(self):
-        super().set_gguf_parameters()
-        if attention_scale := self.hparams.get("attention_multiplier"):
-            self.gguf_writer.add_attention_scale(attention_scale)
-            logger.info("gguf: (granite) attention_scale = %s", attention_scale)
-        if embedding_scale := self.hparams.get("embedding_multiplier"):
-            self.gguf_writer.add_embedding_scale(embedding_scale)
-            logger.info("gguf: (granite) embedding_scale = %s", embedding_scale)
-        if residual_scale := self.hparams.get("residual_multiplier"):
-            self.gguf_writer.add_residual_scale(residual_scale)
-            logger.info("gguf: (granite) residual_scale = %s", residual_scale)
-        if logits_scale := self.hparams.get("logits_scaling"):
-            self.gguf_writer.add_logit_scale(logits_scale)
-            logger.info("gguf: (granite) logits_scale = %s", logits_scale)
-        if shared_feed_forward_length := self.hparams.get("shared_intermediate_size"):
-            self.gguf_writer.add_expert_shared_feed_forward_length(shared_feed_forward_length)
-            logger.info("gguf: (granitemoeshared) shared_feed_forward_length = %s", shared_feed_forward_length)
-        if (n_experts := self.hparams.get("num_local_experts")) is not None:
-            self.gguf_writer.add_expert_count(n_experts)
-            logger.info(f"gguf: expert count = {n_experts}")
-        if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
-            self.gguf_writer.add_expert_used_count(n_experts_used)
-            logger.info(f"gguf: experts used count = {n_experts_used}")
+        GraniteMoeModel.set_gguf_parameters(self)
+        BambaModel.set_gguf_parameters(self)
 
     def set_vocab(self):
         self.hparams["pad_vocab_size_multiple"] = 8
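
Note for reviewers: the two tensor names now routed to GraniteMoeModel.modify_tensors carry a fused gate/up projection stacked along dim -2, which is exactly what the deleted inline branches split apart. A minimal, self-contained sketch of that split with toy sizes (illustration only, not part of the patch; the converter itself reads the sizes from self.hparams and maps names via format_tensor_name):

import torch

# Toy sizes standing in for num_local_experts, intermediate_size, and the hidden size.
n_experts, ffn_dim, n_embd = 4, 8, 16
merged = torch.randn(n_experts, 2 * ffn_dim, n_embd)  # e.g. block_sparse_moe.input_linear.weight

# The merged tensor must hold the gate and up halves stacked along dim -2.
assert merged.shape[-2] == 2 * ffn_dim

# Equivalent to merged[..., :ffn_dim, :] and merged[..., ffn_dim:, :].
gate, up = merged.split(ffn_dim, dim=-2)

assert gate.shape == (n_experts, ffn_dim, n_embd)
assert up.shape == (n_experts, ffn_dim, n_embd)

The delegation assumes GraniteMoeModel.modify_tensors performs this same split for both the routed-expert and shared-expert input_linear weights, so the hybrid subclass no longer needs its own copy.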