@@ -6488,22 +6488,19 @@ def __init__(self, *args, **kwargs):
 
         super().__init__(*args, **kwargs)
 
-        # Use Granite conversion for attention
-        self._transformer_model_class: type[TextModel] = GraniteModel
-
         # Lists of which layers use ssm vs attention
-        self._attn_layers = self.get_attn_layres()
+        self._attn_layers = self.get_attn_layers()
         self._ssm_layers = [
             i for i in range(self.block_count)
             if i not in self._attn_layers
         ]
 
-        # n_group and d_inner are used during reshape_tensors for mamaba2
+        # n_group and d_inner are used during reshape_tensors for mamba2
         self.d_model = self.find_hparam(["hidden_size", "d_model"])
         self.n_group = self.find_hparam(["n_groups"])
         self.d_inner = self.find_hparam(["expand"]) * self.d_model
 
-    def get_attn_layres(self):
+    def get_attn_layers(self):
         # Explicit list of layer type names
         if layer_types := self.hparams.get("layer_types"):
             return [
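Note on the hunk above: besides fixing the `get_attn_layres` typo, it keeps the attention/ssm layer split intact. A minimal standalone sketch of that split (the `"attention"` type string and the `split_layers` helper are illustrative assumptions; the converter reads `layer_types` from `self.hparams` and the layer count from `self.block_count`):

```python
# Hypothetical standalone sketch of the attention/ssm layer split.
def split_layers(layer_types: list[str]) -> tuple[list[int], list[int]]:
    # Layers explicitly typed as attention; every other index is ssm.
    attn = [i for i, t in enumerate(layer_types) if t == "attention"]
    ssm = [i for i in range(len(layer_types)) if i not in attn]
    return attn, ssm

# Example: attention at layers 1 and 3, mamba everywhere else
attn, ssm = split_layers(["mamba", "attention", "mamba", "attention"])
assert attn == [1, 3] and ssm == [0, 2]
```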
@@ -6532,7 +6529,7 @@ def find_hparam(self, keys: Iterable[str], *args, **kwargs) -> Any:
             for k in keys
         )
         keys = list(keys) + prefixed
-        return super().find_hparam(keys, *args, **kwargs)
+        return Mamba2Model.find_hparam(self, keys, *args, **kwargs)
 
     def modify_tensors(
         self, data_torch: Tensor, name: str, bid: int | None
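Note on the `find_hparam` hunk: the override appends a prefixed fallback variant for each requested key before delegating, and the delegation is now pinned to `Mamba2Model.find_hparam` rather than `super()`. A minimal sketch of the fallback pattern, assuming a `"mamba_"` prefix and a plain dict lookup in place of the real delegation (`find_hparam_with_fallback` is a hypothetical helper):

```python
from typing import Any, Iterable

def find_hparam_with_fallback(hparams: dict, keys: Iterable[str]) -> Any:
    # Try each key as given, then its prefixed variant, mirroring
    # `keys = list(keys) + prefixed` in the hunk above.
    keys = list(keys)
    prefixed = ["mamba_" + k for k in keys]  # prefix is an assumption
    for k in keys + prefixed:
        if k in hparams:
            return hparams[k]
    raise KeyError(f"could not find any of {keys!r}")

# Example: a config that only defines the namespaced form
assert find_hparam_with_fallback({"mamba_n_groups": 4}, ["n_groups"]) == 4
```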
@@ -6543,11 +6540,11 @@ def modify_tensors(
         ):
             return GraniteMoeModel.modify_tensors(self, data_torch, name, bid)
 
-        # Determine whether this is a mamaba layer or an attention layer
+        # Determine whether this is a mamba layer or an attention layer
        if bid in self._ssm_layers:
-            return super().modify_tensors(data_torch, name, bid)
+            return Mamba2Model.modify_tensors(self, data_torch, name, bid)
         elif bid in self._attn_layers:
-            return self._transformer_model_class.modify_tensors(self, data_torch, name, bid)
+            return GraniteMoeModel.modify_tensors(self, data_torch, name, bid)
         return [(self.map_tensor_name(name), data_torch)]
 
     def set_gguf_parameters(self):
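Note on the `modify_tensors` hunk: with the stored `_transformer_model_class` gone, each call is pinned to a concrete base (`Mamba2Model` for ssm layers, `GraniteMoeModel` for attention layers) instead of going through `super()`, which under multiple inheritance would always follow the MRO down one fixed path. A toy sketch of the difference (class names here are illustrative, not the converter's real hierarchy):

```python
class SsmBase:
    def modify_tensors(self, name: str, bid: int) -> str:
        return f"ssm path: {name} (layer {bid})"

class AttnBase:
    def modify_tensors(self, name: str, bid: int) -> str:
        return f"attn path: {name} (layer {bid})"

class HybridModel(SsmBase, AttnBase):
    def __init__(self, attn_layers: list[int]):
        self._attn_layers = set(attn_layers)

    def modify_tensors(self, name: str, bid: int) -> str:
        # super() would always pick SsmBase (first in the MRO); explicit
        # Base.method(self, ...) routes per layer instead.
        if bid in self._attn_layers:
            return AttnBase.modify_tensors(self, name, bid)
        return SsmBase.modify_tensors(self, name, bid)

m = HybridModel(attn_layers=[1, 3])
assert m.modify_tensors("w", 0).startswith("ssm")
assert m.modify_tensors("w", 1).startswith("attn")
```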
@@ -6595,7 +6592,7 @@ def set_gguf_parameters(self):
 
     def set_vocab(self):
         self.hparams["pad_vocab_size_multiple"] = 8
-        super().set_vocab()
+        Mamba2Model.set_vocab(self)
 
 
 @ModelBase.register("BailingMoeForCausalLM")
@@ -6821,7 +6818,7 @@ def __init__(self, *args, **kwargs):
         # Use Llama conversion for attention
         self._transformer_model_class = LlamaModel
 
-        # n_group and d_inner are used during reshape_tensors for mamaba2
+        # n_group and d_inner are used during reshape_tensors for mamba2
         self.n_group = self.find_hparam(["n_groups"])
         self.d_inner = self.find_hparam(["mamba_d_ssm"])
         self.d_head = self.find_hparam(["d_head"])