@@ -4894,6 +4894,9 @@ def __init__(self, dir_model: Path, *args, **kwargs):
        with open(dir_model / "config.json", "r", encoding="utf-8") as f:
            hparams = json.load(f)
        super().__init__(dir_model, *args, hparams=hparams, **kwargs)
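+        # Cache these dims on the instance so set_gguf_parameters and
+        # modify_tensors can reuse them without re-deriving them from hparams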
+        self.d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
+        self.d_inner = self.find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or 2 * self.d_model
+        self.n_group = self.find_hparam(["n_groups"], optional=True) or 1

    def set_vocab(self):
        vocab_size = self.hparams["vocab_size"]
@@ -4916,32 +4919,29 @@ def set_vocab(self):
            self._set_vocab_builtin("gpt-neox", vocab_size)

    def set_gguf_parameters(self):
-        d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
-        d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
-        d_inner = self.find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or 2 * d_model
-        d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 128
-        head_dim = self.find_hparam(["mamba_d_head", "head_dim"], optional=True) or 64
-        n_group = self.find_hparam(["n_groups"], optional=True) or 1
+        d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
+        d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 128
+        head_dim = self.find_hparam(["mamba_d_head", "head_dim"], optional=True) or 64

        rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5

        # Fail early for models which don't have a block expansion factor of 2
        # TODO: does this really matter?
        # skip the assertion for FalconH1 Model
        if self.model_arch != gguf.MODEL_ARCH.FALCON_H1:
-            assert d_inner == 2 * d_model
-            assert d_inner % head_dim == 0
+            assert self.d_inner == 2 * self.d_model
+            assert self.d_inner % head_dim == 0

        self.gguf_writer.add_context_length(2**20)  # arbitrary value; for those who use the default
-        self.gguf_writer.add_embedding_length(d_model)
+        self.gguf_writer.add_embedding_length(self.d_model)
        self.gguf_writer.add_feed_forward_length(0)  # unused, but seemingly required when loading
        self.gguf_writer.add_head_count(0)  # unused, but seemingly required when loading
        self.gguf_writer.add_block_count(self.block_count)
        self.gguf_writer.add_ssm_conv_kernel(d_conv)
-        self.gguf_writer.add_ssm_inner_size(d_inner)
+        self.gguf_writer.add_ssm_inner_size(self.d_inner)
        self.gguf_writer.add_ssm_state_size(d_state)
-        self.gguf_writer.add_ssm_time_step_rank(d_inner // head_dim)
-        self.gguf_writer.add_ssm_group_count(n_group)
+        self.gguf_writer.add_ssm_time_step_rank(self.d_inner // head_dim)
+        self.gguf_writer.add_ssm_group_count(self.n_group)
        self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
        self.gguf_writer.add_file_type(self.ftype)

@@ -4966,10 +4966,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
            # (D is also unsqueezed, but for more straightforward broadcast internally)
            data_torch = data_torch.reshape((*data_torch.shape, 1))
        elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_NORM, bid):
-            d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
-            d_inner = self.find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or 2 * d_model
-            n_group = self.hparams.get("n_groups", 1)
-            data_torch = data_torch.reshape((n_group, d_inner // n_group))
+            data_torch = data_torch.reshape((self.n_group, self.d_inner // self.n_group))

        if name.endswith(".A_log"):
            logger.debug("A_log --> A ==> " + new_name)
@@ -6456,18 +6453,148 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_EXP, bid), up),
            ]

+        has_experts = bool(self.hparams.get('num_local_experts'))
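+        # When there are no local experts, the shared MLP below maps to the plain
+        # FFN tensors instead of the shared-expert (SHEXP) tensors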
+
        if name.endswith("shared_mlp.input_linear.weight"):
            ffn_dim = self.hparams["shared_intermediate_size"]
            assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * shared_intermediate_size"
            gate, up = data_torch.split(ffn_dim, dim=-2)
+            if has_experts:
+                return [
+                    (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_SHEXP, bid), gate),
+                    (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_SHEXP, bid), up),
+                ]
            return [
-                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_SHEXP, bid), gate),
-                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_SHEXP, bid), up),
+                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE, bid), gate),
+                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP, bid), up),
+            ]
+
+        if not has_experts and name.endswith("shared_mlp.output_linear.weight"):
+            return [
+                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_DOWN, bid), data_torch)
            ]

        return super().modify_tensors(data_torch, name, bid)


+@ModelBase.register("GraniteMoeHybridForCausalLM", "BambaForCausalLM")
+class GraniteHybridModel(Mamba2Model, GraniteMoeModel):
+    """GraniteHybrid is a hybrid SSM + Attention model that uses Mamba2 SSM
+    layers and optionally uses MoE w/ a shared expert"""
+    model_arch = gguf.MODEL_ARCH.GRANITE_HYBRID
+    undo_permute = True
+
+    def __init__(self, *args, **kwargs):
+
+        # Hybrid mamba models use a prefix for the mamba-specific params.
+        # TODO: Extend this if the prefix(es) need to be configurable
+        self.hparam_prefixes = ["mamba"]
+
+        super().__init__(*args, **kwargs)
+
+        # Lists of which layers use ssm vs attention
+        self._attn_layers = self.get_attn_layers()
+        self._ssm_layers = [
+            i for i in range(self.block_count)
+            if i not in self._attn_layers
+        ]
+
+        # n_group and d_inner are used during reshape_tensors for mamba2
+        self.d_model = self.find_hparam(["hidden_size", "d_model"])
+        self.n_group = self.find_hparam(["n_groups"])
+        self.d_inner = self.find_hparam(["expand"]) * self.d_model
+
+    def get_attn_layers(self):
+        # Explicit list of layer type names
+        if layer_types := self.hparams.get("layer_types"):
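+            # e.g. layer_types = ["mamba", "attention", "mamba", ...] (illustrative);
+            # entries other than "attention" are treated as SSM layers here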
+            return [
+                i for i, typ in enumerate(layer_types)
+                if typ == "attention"
+            ]
+
+        # Layer types indicated by index or period
+        attn_layers = self.hparams.get("attn_layer_indices", [])
+        if not attn_layers:
+            attn_period = self.hparams.get("attn_layer_period")
+            assert attn_period, "Didn't find attn_layer_indices or attn_layer_period"
+            attn_offset = self.hparams.get("attn_layer_offset")
+            assert attn_offset is not None, "No attention layer offset set with attn_layer_period"
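+            # e.g. attn_layer_period=6 with attn_layer_offset=5 would put attention
+            # at layers 5, 11, 17, ... (illustrative values)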
+            attn_layers = [
+                i for i in range(self.block_count)
+                if i % attn_period == attn_offset
+            ]
+        return attn_layers
+
+    def find_hparam(self, keys: Iterable[str], *args, **kwargs) -> Any:
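+        # Try each key both as-is and with the "mamba" prefix, so that e.g.
+        # "d_state" also matches a "mamba_d_state" entry in the config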
+        prefixed = []
+        for pfx in self.hparam_prefixes:
+            prefixed.extend(
+                "_".join([pfx, k])
+                for k in keys
+            )
+        keys = list(keys) + prefixed
+        return Mamba2Model.find_hparam(self, keys, *args, **kwargs)
+
+    def modify_tensors(
+        self, data_torch: Tensor, name: str, bid: int | None
+    ) -> Iterable[tuple[str, Tensor]]:
+        if (
+            name.endswith("block_sparse_moe.input_linear.weight")
+            or "shared_mlp" in name
+        ):
+            return GraniteMoeModel.modify_tensors(self, data_torch, name, bid)
+
+        # Determine whether this is a mamba layer or an attention layer
+        if bid in self._ssm_layers:
+            return Mamba2Model.modify_tensors(self, data_torch, name, bid)
+        elif bid in self._attn_layers:
+            return GraniteMoeModel.modify_tensors(self, data_torch, name, bid)
+        return [(self.map_tensor_name(name), data_torch)]
+
+    def set_gguf_parameters(self):
+        """This method merges params from both parents and some that are
+        specific to this model. The result is some duplication of how the params
+        get set. The following warnings are expected during conversion:
+
+        WARNING:Duplicated key name 'granitehybrid.attention.head_count_kv'
+        WARNING:Duplicated key name 'granitehybrid.context_length'
+        """
+        GraniteMoeModel.set_gguf_parameters(self)
+
+        ## Mamba mixer params ##
+        self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"]))
+        self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state"]))
+        self.gguf_writer.add_ssm_group_count(self.n_group)
+        self.gguf_writer.add_ssm_inner_size(self.d_inner)
+        # NOTE: The mamba_dt_rank is _not_ the right field for how this is used
+        # in llama.cpp
+        self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads"]))
+
+        ## Attention params ##
+        head_count_kv = self.find_hparam(["num_key_value_heads", "n_head_kv"])
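+        # Per-layer KV head counts: attention layers get head_count_kv, SSM layers get 0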
+        head_count_kv_vec = [
+            head_count_kv if i in self._attn_layers else 0 for i in range(self.block_count)
+        ]
+        if rope_dim := self.hparams.get("attn_rotary_emb"):
+            self.gguf_writer.add_rope_dimension_count(rope_dim)
+        self.gguf_writer.add_head_count_kv(head_count_kv_vec)
+
+        ## If Bamba, use rope, otherwise don't
+        use_rope = "BambaForCausalLM" in self.hparams["architectures"]
+        self.gguf_writer.add_rope_scaling_finetuned(use_rope)
+        if not use_rope:
+            self.gguf_writer.add_context_length(2**20)
+
+        ## Validation ##
+        d_head = self.find_hparam(["d_head"], optional=True) or 64
+        assert self.hparams.get("hidden_act") in [None, "silu"], "Only SILU activation supported"
+        assert self.d_inner % d_head == 0, f"SSM inner size {self.d_inner} not a multiple of head dim {d_head}"
+
+    def set_vocab(self):
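+        # pad_vocab_size_multiple is read by the Mamba vocab handling to round
+        # the vocab size up to a multiple of 8 (assumption based on the hparam name)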
+        self.hparams["pad_vocab_size_multiple"] = 8
+        Mamba2Model.set_vocab(self)
+
+
@ModelBase.register("BailingMoeForCausalLM")
class BailingMoeModel(TextModel):
    model_arch = gguf.MODEL_ARCH.BAILINGMOE
@@ -6691,7 +6818,7 @@ def __init__(self, *args, **kwargs):
        # Use Llama conversion for attention
        self._transformer_model_class = LlamaModel

-        # n_group and d_inner are used during reshape_tensors for mamaba2
+        # n_group and d_inner are used during reshape_tensors for mamba2
        self.n_group = self.find_hparam(["n_groups"])
        self.d_inner = self.find_hparam(["mamba_d_ssm"])
        self.d_head = self.find_hparam(["d_head"])