@@ -5056,6 +5056,9 @@ def __init__(self, dir_model: Path, *args, **kwargs):
         with open(dir_model / "config.json", "r", encoding="utf-8") as f:
             hparams = json.load(f)
         super().__init__(dir_model, *args, hparams=hparams, **kwargs)
+        self.d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
+        self.d_inner = self.find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or 2 * self.d_model
+        self.n_group = self.find_hparam(["n_groups"], optional=True) or 1

     def set_vocab(self):
         vocab_size = self.hparams["vocab_size"]
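
Note: the new attributes rely on the converter's find_hparam([...], optional=True) or <default> pattern. A minimal sketch of the presumed lookup behaviour (the real helper is defined elsewhere in this file; find_hparam_sketch and the sizes below are purely illustrative):

    def find_hparam_sketch(hparams: dict, keys: list[str], optional: bool = False):
        # Return the value of the first candidate key present in config.json.
        for key in keys:
            if key in hparams:
                return hparams[key]
        if optional:
            return None
        raise KeyError(f"could not find any of: {keys}")

    # With only "hidden_size" present, d_inner falls back to 2 * d_model:
    hparams = {"hidden_size": 1024}
    d_model = find_hparam_sketch(hparams, ["hidden_size", "d_model", "dim"])
    d_inner = find_hparam_sketch(hparams, ["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or 2 * d_model
    assert (d_model, d_inner) == (1024, 2048)
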
@@ -5078,32 +5081,29 @@ def set_vocab(self):
         self._set_vocab_builtin("gpt-neox", vocab_size)

     def set_gguf_parameters(self):
-        d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
-        d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
-        d_inner = self.find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or 2 * d_model
-        d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 128
-        head_dim = self.find_hparam(["mamba_d_head", "head_dim"], optional=True) or 64
-        n_group = self.find_hparam(["n_groups"], optional=True) or 1
+        d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
+        d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 128
+        head_dim = self.find_hparam(["mamba_d_head", "head_dim"], optional=True) or 64

         rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5

         # Fail early for models which don't have a block expansion factor of 2
         # TODO: does this really matter?
         # skip the assertion for FalconH1 Model
         if self.model_arch != gguf.MODEL_ARCH.FALCON_H1:
-            assert d_inner == 2 * d_model
-            assert d_inner % head_dim == 0
+            assert self.d_inner == 2 * self.d_model
+            assert self.d_inner % head_dim == 0

         self.gguf_writer.add_context_length(2**20)  # arbitrary value; for those who use the default
-        self.gguf_writer.add_embedding_length(d_model)
+        self.gguf_writer.add_embedding_length(self.d_model)
         self.gguf_writer.add_feed_forward_length(0)  # unused, but seemingly required when loading
         self.gguf_writer.add_head_count(0)  # unused, but seemingly required when loading
         self.gguf_writer.add_block_count(self.block_count)
         self.gguf_writer.add_ssm_conv_kernel(d_conv)
-        self.gguf_writer.add_ssm_inner_size(d_inner)
+        self.gguf_writer.add_ssm_inner_size(self.d_inner)
         self.gguf_writer.add_ssm_state_size(d_state)
-        self.gguf_writer.add_ssm_time_step_rank(d_inner // head_dim)
-        self.gguf_writer.add_ssm_group_count(n_group)
+        self.gguf_writer.add_ssm_time_step_rank(self.d_inner // head_dim)
+        self.gguf_writer.add_ssm_group_count(self.n_group)
         self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
         self.gguf_writer.add_file_type(self.ftype)

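
Note: add_ssm_time_step_rank is fed self.d_inner // head_dim, which is effectively the Mamba2 head count rather than a classic dt_rank (the NOTE in GraniteHybridModel.set_gguf_parameters further down makes the same point). A toy calculation with invented sizes:

    d_model = 2048
    d_inner = 2 * d_model    # 4096, matching the block-expansion-factor-2 assertion above
    head_dim = 64
    assert d_inner // head_dim == 64   # the value written as the "time step rank"
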
@@ -5128,10 +5128,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
             # (D is also unsqueezed, but for more straightforward broadcast internally)
             data_torch = data_torch.reshape((*data_torch.shape, 1))
         elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_NORM, bid):
-            d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
-            d_inner = self.find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or 2 * d_model
-            n_group = self.hparams.get("n_groups", 1)
-            data_torch = data_torch.reshape((n_group, d_inner // n_group))
+            data_torch = data_torch.reshape((self.n_group, self.d_inner // self.n_group))

         if name.endswith(".A_log"):
             logger.debug("A_log --> A ==> " + new_name)
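
Note: the SSM_NORM branch now reuses the cached self.n_group and self.d_inner; the reshape itself is unchanged. A toy example with made-up sizes showing what the reshape does:

    import torch

    n_group, d_inner = 8, 4096          # hypothetical values
    ssm_norm = torch.ones(d_inner)      # flat per-channel norm weight
    grouped = ssm_norm.reshape((n_group, d_inner // n_group))
    assert grouped.shape == (8, 512)    # one row of d_inner // n_group channels per group
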
@@ -6618,18 +6615,148 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
                 (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_EXP, bid), up),
             ]

+        has_experts = bool(self.hparams.get('num_local_experts'))
+
         if name.endswith("shared_mlp.input_linear.weight"):
             ffn_dim = self.hparams["shared_intermediate_size"]
             assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * shared_intermediate_size"
             gate, up = data_torch.split(ffn_dim, dim=-2)
+            if has_experts:
+                return [
+                    (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_SHEXP, bid), gate),
+                    (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_SHEXP, bid), up),
+                ]
             return [
-                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_SHEXP, bid), gate),
-                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_SHEXP, bid), up),
+                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE, bid), gate),
+                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP, bid), up),
+            ]
+
+        if not has_experts and name.endswith("shared_mlp.output_linear.weight"):
+            return [
+                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_DOWN, bid), data_torch)
             ]

         return super().modify_tensors(data_torch, name, bid)


+@ModelBase.register("GraniteMoeHybridForCausalLM", "BambaForCausalLM")
+class GraniteHybridModel(Mamba2Model, GraniteMoeModel):
+    """GraniteHybrid is a hybrid SSM + Attention model that uses Mamba2 SSM
+    layers and optionally uses MoE w/ a shared expert"""
+    model_arch = gguf.MODEL_ARCH.GRANITE_HYBRID
+    undo_permute = True
+
+    def __init__(self, *args, **kwargs):
+
+        # Hybrid mamba models use a prefix for the mamba-specific params.
+        # TODO: Extend this if the prefix(es) need to be configurable
+        self.hparam_prefixes = ["mamba"]
+
+        super().__init__(*args, **kwargs)
+
+        # Lists of which layers use ssm vs attention
+        self._attn_layers = self.get_attn_layers()
+        self._ssm_layers = [
+            i for i in range(self.block_count)
+            if i not in self._attn_layers
+        ]
+
+        # n_group and d_inner are used during reshape_tensors for mamba2
+        self.d_model = self.find_hparam(["hidden_size", "d_model"])
+        self.n_group = self.find_hparam(["n_groups"])
+        self.d_inner = self.find_hparam(["expand"]) * self.d_model
+
+    def get_attn_layers(self):
+        # Explicit list of layer type names
+        if layer_types := self.hparams.get("layer_types"):
+            return [
+                i for i, typ in enumerate(layer_types)
+                if typ == "attention"
+            ]
+
+        # Layer types indicated by index or period
+        attn_layers = self.hparams.get("attn_layer_indices", [])
+        if not attn_layers:
+            attn_period = self.hparams.get("attn_layer_period")
+            assert attn_period, "Didn't find attn_layer_indices or attn_layer_period"
+            attn_offset = self.hparams.get("attn_layer_offset")
+            assert attn_offset is not None, "No attention layer offset set with attn_layer_period"
+            attn_layers = [
+                i for i in range(self.block_count)
+                if i % attn_period == attn_offset
+            ]
+        return attn_layers
+
+    def find_hparam(self, keys: Iterable[str], *args, **kwargs) -> Any:
+        prefixed = []
+        for pfx in self.hparam_prefixes:
+            prefixed.extend(
+                "_".join([pfx, k])
+                for k in keys
+            )
+        keys = list(keys) + prefixed
+        return Mamba2Model.find_hparam(self, keys, *args, **kwargs)
+
+    def modify_tensors(
+        self, data_torch: Tensor, name: str, bid: int | None
+    ) -> Iterable[tuple[str, Tensor]]:
+        if (
+            name.endswith("block_sparse_moe.input_linear.weight")
+            or "shared_mlp" in name
+        ):
+            return GraniteMoeModel.modify_tensors(self, data_torch, name, bid)
+
+        # Determine whether this is a mamba layer or an attention layer
+        if bid in self._ssm_layers:
+            return Mamba2Model.modify_tensors(self, data_torch, name, bid)
+        elif bid in self._attn_layers:
+            return GraniteMoeModel.modify_tensors(self, data_torch, name, bid)
+        return [(self.map_tensor_name(name), data_torch)]
+
+    def set_gguf_parameters(self):
+        """This method merges params from both parents and some that are
+        specific to this model. The result is some duplication of how the params
+        get set. The following warnings are expected during conversion:
+
+        WARNING:Duplicated key name 'granitehybrid.attention.head_count_kv'
+        WARNING:Duplicated key name 'granitehybrid.context_length'
+        """
+        GraniteMoeModel.set_gguf_parameters(self)
+
+        ## Mamba mixer params ##
+        self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"]))
+        self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state"]))
+        self.gguf_writer.add_ssm_group_count(self.n_group)
+        self.gguf_writer.add_ssm_inner_size(self.d_inner)
+        # NOTE: The mamba_dt_rank is _not_ the right field for how this is used
+        # in llama.cpp
+        self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads"]))
+
+        ## Attention params ##
+        head_count_kv = self.find_hparam(["num_key_value_heads", "n_head_kv"])
+        head_count_kv_vec = [
+            head_count_kv if i in self._attn_layers else 0 for i in range(self.block_count)
+        ]
+        if rope_dim := self.hparams.get("attn_rotary_emb"):
+            self.gguf_writer.add_rope_dimension_count(rope_dim)
+        self.gguf_writer.add_head_count_kv(head_count_kv_vec)
+
+        ## If Bamba, use rope, otherwise don't
+        use_rope = "BambaForCausalLM" in self.hparams["architectures"]
+        self.gguf_writer.add_rope_scaling_finetuned(use_rope)
+        if not use_rope:
+            self.gguf_writer.add_context_length(2**20)
+
+        ## Validation ##
+        d_head = self.find_hparam(["d_head"], optional=True) or 64
+        assert self.hparams.get("hidden_act") in [None, "silu"], "Only SILU activation supported"
+        assert self.d_inner % d_head == 0, f"SSM inner size {self.d_inner} not a multiple of head dim {d_head}"
+
+    def set_vocab(self):
+        self.hparams["pad_vocab_size_multiple"] = 8
+        Mamba2Model.set_vocab(self)
+
+
 @ModelBase.register("BailingMoeForCausalLM")
 class BailingMoeModel(TextModel):
     model_arch = gguf.MODEL_ARCH.BAILINGMOE
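
Note: two details of GraniteHybridModel worth spelling out. get_attn_layers() partitions the blocks into attention and SSM layers (via explicit layer_types, explicit indices, or a period/offset rule), and the find_hparam() override also tries "mamba_"-prefixed key names before delegating to Mamba2Model.find_hparam. A small sketch with invented config values:

    # Period/offset rule: every attn_layer_period-th block, starting at
    # attn_layer_offset, is an attention layer; the rest are Mamba2 layers.
    block_count, attn_period, attn_offset = 12, 4, 1
    attn_layers = [i for i in range(block_count) if i % attn_period == attn_offset]
    ssm_layers = [i for i in range(block_count) if i not in attn_layers]
    assert attn_layers == [1, 5, 9]
    assert ssm_layers == [0, 2, 3, 4, 6, 7, 8, 10, 11]

    # Prefixed hparam lookup: candidate keys are extended with "mamba_" variants.
    keys = ["conv_kernel", "d_conv"]
    prefixed = ["_".join([pfx, k]) for pfx in ["mamba"] for k in keys]
    assert keys + prefixed == ["conv_kernel", "d_conv", "mamba_conv_kernel", "mamba_d_conv"]
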
@@ -6853,7 +6980,7 @@ def __init__(self, *args, **kwargs):
         # Use Llama conversion for attention
         self._transformer_model_class = LlamaModel

-        # n_group and d_inner are used during reshape_tensors for mamaba2
+        # n_group and d_inner are used during reshape_tensors for mamba2
         self.n_group = self.find_hparam(["n_groups"])
         self.d_inner = self.find_hparam(["mamba_d_ssm"])
         self.d_head = self.find_hparam(["d_head"])