@@ -6054,10 +6054,6 @@ class GraniteMoeHybridModel(BambaModel, GraniteMoeModel):
     SSM layers"""
     model_arch = gguf.MODEL_ARCH.GRANITE_MOE_HYBRID
 
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self._transformer_model_class = GraniteMoeModel
-
     def get_attn_layres(self):
         if layer_types := self.hparams.get("layer_types"):
             return [
@@ -6069,51 +6065,16 @@ def get_attn_layres(self):
     def modify_tensors(
             self, data_torch: Tensor, name: str, bid: int | None
     ) -> Iterable[tuple[str, Tensor]]:
-
-        # In GraniteMoeHybrid, the mamba layers also have an MoE + Shared expert
-        if name.endswith("block_sparse_moe.input_linear.weight"):
-            ffn_dim = self.hparams["intermediate_size"]
-            assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * intermediate_size"
-            gate, up = data_torch[..., :ffn_dim, :], data_torch[..., ffn_dim:, :]
-            return [
-                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_EXP, bid), gate),
-                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_EXP, bid), up),
-            ]
-        if name.endswith("shared_mlp.input_linear.weight"):
-            ffn_dim = self.hparams["shared_intermediate_size"]
-            assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * shared_intermediate_size"
-            gate, up = data_torch.split(ffn_dim, dim=-2)
-            return [
-                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_SHEXP, bid), gate),
-                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_SHEXP, bid), up),
-            ]
-
+        if (
+            name.endswith("block_sparse_moe.input_linear.weight") or
+            name.endswith("shared_mlp.input_linear.weight")
+        ):
+            return GraniteMoeModel.modify_tensors(self, data_torch, name, bid)
         return super().modify_tensors(data_torch, name, bid)
 
-
     def set_gguf_parameters(self):
-        super().set_gguf_parameters()
-        if attention_scale := self.hparams.get("attention_multiplier"):
-            self.gguf_writer.add_attention_scale(attention_scale)
-            logger.info("gguf: (granite) attention_scale = %s", attention_scale)
-        if embedding_scale := self.hparams.get("embedding_multiplier"):
-            self.gguf_writer.add_embedding_scale(embedding_scale)
-            logger.info("gguf: (granite) embedding_scale = %s", embedding_scale)
-        if residual_scale := self.hparams.get("residual_multiplier"):
-            self.gguf_writer.add_residual_scale(residual_scale)
-            logger.info("gguf: (granite) residual_scale = %s", residual_scale)
-        if logits_scale := self.hparams.get("logits_scaling"):
-            self.gguf_writer.add_logit_scale(logits_scale)
-            logger.info("gguf: (granite) logits_scale = %s", logits_scale)
-        if shared_feed_forward_length := self.hparams.get("shared_intermediate_size"):
-            self.gguf_writer.add_expert_shared_feed_forward_length(shared_feed_forward_length)
-            logger.info("gguf: (granitemoeshared) shared_feed_forward_length = %s", shared_feed_forward_length)
-        if (n_experts := self.hparams.get("num_local_experts")) is not None:
-            self.gguf_writer.add_expert_count(n_experts)
-            logger.info(f"gguf: expert count = {n_experts}")
-        if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
-            self.gguf_writer.add_expert_used_count(n_experts_used)
-            logger.info(f"gguf: experts used count = {n_experts_used}")
+        GraniteMoeModel.set_gguf_parameters(self)
+        BambaModel.set_gguf_parameters(self)
 
     def set_vocab(self):
         self.hparams["pad_vocab_size_multiple"] = 8
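
As a minimal standalone sketch of the delegation pattern this hunk relies on (the stub classes and tensor names below are hypothetical, not the real converter classes from convert_hf_to_gguf.py): super() dispatches through the MRO, which reaches BambaModel first, so the hybrid class names GraniteMoeModel explicitly when it wants that parent's handling, and it calls both parents' set_gguf_parameters so each writes its own GGUF keys.

# Hypothetical stubs illustrating explicit-parent delegation vs. super()/MRO.

class GraniteMoeModel:
    def modify_tensors(self, name):
        return f"GraniteMoeModel split merged gate/up for {name}"

    def set_gguf_parameters(self):
        print("write granitemoe GGUF keys (scales, expert counts, ...)")


class BambaModel:
    def modify_tensors(self, name):
        return f"BambaModel handled {name}"

    def set_gguf_parameters(self):
        print("write bamba/SSM GGUF keys (conv kernel, state size, ...)")


class GraniteMoeHybridModel(BambaModel, GraniteMoeModel):
    def modify_tensors(self, name):
        # Merged MoE / shared-expert input projections need the Granite MoE
        # splitting logic, so bypass the MRO and call that parent directly.
        if name.endswith("input_linear.weight"):
            return GraniteMoeModel.modify_tensors(self, name)
        # Everything else follows the normal MRO, which hits BambaModel first.
        return super().modify_tensors(name)

    def set_gguf_parameters(self):
        # The two parents write disjoint sets of keys, so call both explicitly
        # instead of relying on cooperative super() chaining.
        GraniteMoeModel.set_gguf_parameters(self)
        BambaModel.set_gguf_parameters(self)


m = GraniteMoeHybridModel()
print(m.modify_tensors("model.layers.0.block_sparse_moe.input_linear.weight"))
print(m.modify_tensors("model.layers.1.mamba.in_proj.weight"))
m.set_gguf_parameters()

Presumably GraniteMoeModel.modify_tensors already performs the merged gate/up split that the removed lines reimplemented, so delegating to it avoids keeping two copies of that logic in sync.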