@@ -7546,9 +7546,13 @@ def __init__(self, *args, **kwargs):
         ]
 
         # n_group and d_inner are used during reshape_tensors for mamba2
-        self.d_model = self.find_hparam(["hidden_size", "d_model"])
-        self.n_group = self.find_hparam(["n_groups"])
-        self.d_inner = self.find_hparam(["expand"]) * self.d_model
+        # NOTE: Explicitly include the hparam prefix for d_model to
+        # disambiguate it from the top-level head_dim
+        # NOTE 2: If needed for future models, this can be isolated in a method
+        # to separate the prefix setting and the keys used
+        self.d_model = self.find_hparam([f"{self.hparam_prefixes[0]}_head_dim", "hidden_size", "d_model"])
+        self.n_group = self.find_hparam(["n_groups", "num_groups"])
+        self.d_inner = self.find_hparam(["expand", "num_heads"]) * self.d_model
 
     def get_attn_layers(self):
         # Explicit list of layer type names
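The hunk above relies on find_hparam taking a list of candidate keys and returning the value for the first one present in the model config. A minimal standalone sketch of that fallback-key pattern (the helper name and example config below are illustrative assumptions, not the converter's actual implementation):

# Sketch of the assumed fallback-key lookup: return the value for the first
# candidate key found in the hparams dict.
def find_hparam_sketch(hparams: dict, keys: list):
    for key in keys:
        if key in hparams:
            return hparams[key]
    raise KeyError(f"none of {keys} found in hparams")

# Example: a NemotronH-style config exposes num_groups/num_heads rather than
# the Granite-style n_groups/expand, so the widened key lists still resolve.
example_hparams = {"hidden_size": 4096, "num_groups": 8, "num_heads": 128}
assert find_hparam_sketch(example_hparams, ["n_groups", "num_groups"]) == 8
assert find_hparam_sketch(example_hparams, ["expand", "num_heads"]) == 128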
@@ -7609,12 +7613,12 @@ def set_gguf_parameters(self):
 
         ## Mamba mixer params ##
         self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"]))
-        self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state"]))
+        self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state", "state_dim", "ssm_state_size"]))
         self.gguf_writer.add_ssm_group_count(self.n_group)
         self.gguf_writer.add_ssm_inner_size(self.d_inner)
         # NOTE: The mamba_dt_rank is _not_ the right field for how this is used
         # in llama.cpp
-        self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads"]))
+        self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads", "num_heads"]))
 
         ## Attention params ##
         head_count_kv = self.find_hparam(["num_key_value_heads", "n_head_kv"])
@@ -7641,6 +7645,55 @@ def set_vocab(self):
         Mamba2Model.set_vocab(self)
 
 
+@ModelBase.register("NemotronHForCausalLM")
+class NemotronHModel(GraniteHybridModel):
+    """Hybrid mamba2/attention model from NVIDIA"""
+    model_arch = gguf.MODEL_ARCH.NEMOTRON_H
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # Save the top-level head_dim for later
+        self.head_dim = self.hparams.get("head_dim", self.hparams.get("attention_head_dim"))
+        assert self.head_dim is not None, "Could not find the attention head dim in config"
+
+        # Don't use expand to calculate d_inner
+        self.d_inner = self.find_hparam(["num_heads"]) * self.d_model
+
+        # Update the ssm / attn / mlp layers
+        # M: Mamba2, *: Attention, -: MLP
+        hybrid_override_pattern = self.hparams["hybrid_override_pattern"]
+        self._ssm_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == "M"]
+        self._mlp_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == "-"]
+
+    def get_attn_layers(self):
+        hybrid_override_pattern = self.hparams["hybrid_override_pattern"]
+        assert len(hybrid_override_pattern) == self.block_count, "Mismatch between hybrid override and num_hidden_layers!"
+        return [i for i, val in enumerate(hybrid_override_pattern) if val == "*"]
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+
+        self.gguf_writer.add_key_length(self.head_dim)
+        self.gguf_writer.add_value_length(self.head_dim)
+
+        # Set feed_forward_length
+        # NOTE: This will trigger an override warning. This is preferable to
+        # duplicating all the parent logic
+        n_ff = self.find_hparam(["intermediate_size", "n_inner", "hidden_dim"])
+        self.gguf_writer.add_feed_forward_length([
+            n_ff if i in self._mlp_layers else 0 for i in range(self.block_count)
+        ])
+
+    def set_vocab(self):
+        super().set_vocab()
+
+        # The tokenizer _does_ add a BOS token (via post_processor type
+        # TemplateProcessing) but does not set add_bos_token to true in the
+        # config, so we need to explicitly override it here.
+        self.gguf_writer.add_add_bos_token(True)
+
+
 @ModelBase.register("BailingMoeForCausalLM")
 class BailingMoeModel(TextModel):
     model_arch = gguf.MODEL_ARCH.BAILINGMOE