@@ -4302,6 +4302,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
 class Mamba2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.MAMBA2
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
+        self.d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * self.d_model
+        self.n_group = self.hparams.get("n_groups", 1)
+
     def set_vocab(self):
         vocab_size = self.hparams["vocab_size"]
         # Round vocab size to next multiple of 16
@@ -4371,10 +4377,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
             # (D is also unsqueezed, but for more straightforward broadcast internally)
             data_torch = data_torch.reshape((*data_torch.shape, 1))
         elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_NORM, bid):
-            d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
-            d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
-            n_group = self.hparams.get("n_groups", 1)
-            data_torch = data_torch.reshape((n_group, d_inner // n_group))
+            data_torch = data_torch.reshape((self.n_group, self.d_inner // self.n_group))
 
         if name.endswith(".A_log"):
             logger.debug("A_log --> A ==> " + new_name)
@@ -4383,6 +4386,107 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         yield (new_name, data_torch)
 
 
+@ModelBase.register("BambaForCausalLM")
+class BambaModel(Mamba2Model):
+    """Bamba is a hybrid SSM + Attention model that uses Mamba2 SSM layers"""
+    model_arch = gguf.MODEL_ARCH.BAMBA
+    undo_permute = True
+
+    def __init__(self, *args, **kwargs):
+
+        # Hybrid mamba models use a prefix for the mamba-specific params.
+        # TODO: Extend this if the prefix(es) need to be configurable
+        self.hparam_prefixes = ["mamba"]
+
+        super().__init__(*args, **kwargs)
+
+        # Use Llama conversion for attention
+        self._transformer_model_class: type[TextModel] = LlamaModel
+
+        # Lists of which layers use ssm vs attention
+        self._attn_layers = self.hparams.get("attn_layer_indices", [])
+        if not self._attn_layers:
+            attn_period = self.hparams.get("attn_layer_period")
+            assert attn_period, "Didn't find attn_layer_indices or attn_layer_period"
+            attn_offset = self.hparams.get("attn_layer_offset")
+            assert attn_offset is not None, "No attention layer offset set with attn_layer_period"
+            self._attn_layers = [
+                i for i in range(self.block_count)
+                if i % attn_period == attn_offset
+            ]
+        self._ssm_layers = [
+            i for i in range(self.block_count)
+            if i not in self._attn_layers
+        ]
+
+        # n_group and d_inner are used during reshape_tensors for mamba2
+        self.d_model = self.find_hparam(["hidden_size", "d_model"])
+        self.n_group = self.find_hparam(["n_groups"])
+        self.d_inner = self.find_hparam(["expand"]) * self.d_model
+
+    def find_hparam(self, keys: Iterable[str], *args, **kwargs) -> Any:
+        prefixed = []
+        for pfx in self.hparam_prefixes:
+            prefixed.extend(
+                "_".join([pfx, k])
+                for k in keys
+            )
+        keys = list(keys) + prefixed
+        return super().find_hparam(keys, *args, **kwargs)
+
+    def set_gguf_parameters(self):
+
+        ## General Params ##
+        self.gguf_writer.add_embedding_length(self.d_model)
+        self.gguf_writer.add_block_count(self.block_count)
+        self.gguf_writer.add_context_length(self.hparams.get("max_position_embeddings", 0))
+        self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
+        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
+
+        ## Mamba mixer params ##
+        self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"]))
+        self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state"]))
+        self.gguf_writer.add_ssm_group_count(self.n_group)
+        self.gguf_writer.add_ssm_inner_size(self.d_inner)
+        # NOTE: The mamba_dt_rank is _not_ the right field for how this is used
+        #   in llama.cpp
+        self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads"]))
+
+        ## Attention params ##
+        self.gguf_writer.add_attn_layer_indices(self._attn_layers)
+        self.gguf_writer.add_rope_dimension_count(self.hparams["attn_rotary_emb"])
+        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
+        self.gguf_writer.add_head_count_kv(self.find_hparam(["num_key_value_heads", "n_head_kv"]))
+
+        ## Feed Forward Params ##
+        self.gguf_writer.add_layer_norm_rms_eps(
+            self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
+        )
+
+        ## Validation ##
+        d_head = self.find_hparam(["d_head"], optional=True) or 64
+        assert self.hparams.get("hidden_act") in [None, "silu"], "Only SILU activation supported"
+        assert self.d_inner % d_head == 0, f"SSM inner size {self.d_inner} not a multiple of head dim {d_head}"
+
+    def modify_tensors(
+        self, data_torch: Tensor, name: str, bid: int | None
+    ) -> Iterable[tuple[str, Tensor]]:
+
+        # Determine whether this is a mamba layer or an attention layer
+        if bid in self._ssm_layers:
+            for mamba_new_name, data_torch in super().modify_tensors(
+                data_torch, name, bid
+            ):
+                yield mamba_new_name, data_torch
+        elif bid in self._attn_layers:
+            for llama_new_name, data_torch in self._transformer_model_class.modify_tensors(
+                self, data_torch, name, bid
+            ):
+                yield llama_new_name, data_torch
+        else:
+            yield self.map_tensor_name(name), data_torch
+
+
 @ModelBase.register("CohereForCausalLM")
 class CommandR2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.COMMAND_R
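
The prefixed-hparam lookup is the part of this change most likely to confuse readers: BambaModel.find_hparam appends "mamba_"-prefixed variants of the requested keys before delegating to the base lookup, so a config.json that stores, say, mamba_d_conv still answers a query for d_conv. The following is a minimal standalone sketch of that resolution order, not part of the commit; the config dict and its values are made up for illustration.

from typing import Any, Iterable

# Hypothetical hybrid-SSM config.json contents (illustrative values only)
config = {"hidden_size": 4096, "mamba_d_conv": 4, "mamba_n_groups": 1, "mamba_expand": 2}

def find_hparam(config: dict, keys: Iterable[str], prefixes=("mamba",)) -> Any:
    # Plain keys are tried first, then prefixed variants ("d_conv" -> "mamba_d_conv"),
    # mirroring keys = list(keys) + prefixed in BambaModel.find_hparam above.
    candidates = list(keys) + ["_".join([p, k]) for p in prefixes for k in keys]
    for key in candidates:
        if key in config:
            return config[key]
    raise KeyError(f"none of {candidates} found")

print(find_hparam(config, ["conv_kernel", "d_conv"]))   # -> 4 (via "mamba_d_conv")
print(find_hparam(config, ["hidden_size", "d_model"]))  # -> 4096 (unprefixed key matches first)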