@@ -4127,6 +4127,14 @@ def set_gguf_parameters(self):
 class MambaModel(TextModel):
     model_arch = gguf.MODEL_ARCH.MAMBA
 
+    def __init__(self, dir_model: Path, *args, **kwargs):
+        # Avoid using AutoConfig for hparams
+        hparams = kwargs.pop("hparams", None)
+        if hparams is None:
+            with open(dir_model / "config.json", "r", encoding="utf-8") as f:
+                hparams = json.load(f)
+        super().__init__(dir_model, *args, hparams=hparams, **kwargs)
+
     def set_vocab(self):
         vocab_size = self.hparams["vocab_size"]
         # Round vocab size to next multiple of 8
@@ -4205,6 +4213,15 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
 class Mamba2Model(TextModel):
     model_arch = gguf.MODEL_ARCH.MAMBA2
 
+    def __init__(self, dir_model: Path, *args, **kwargs):
+        # Avoid using AutoConfig for hparams
+        # It wrongly assumes all Mamba2 models are Mamba-Codestral-7B-v0.1
+        hparams = kwargs.pop("hparams", None)
+        if hparams is None:
+            with open(dir_model / "config.json", "r", encoding="utf-8") as f:
+                hparams = json.load(f)
+        super().__init__(dir_model, *args, hparams=hparams, **kwargs)
+
     def set_vocab(self):
         vocab_size = self.hparams["vocab_size"]
         # Round vocab size to next multiple of 16
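
Both constructors follow the same pattern: when the caller does not pass hparams explicitly, config.json is read straight from the checkpoint directory instead of going through transformers' AutoConfig, which for Mamba2 checkpoints wrongly resolves everything to Mamba-Codestral-7B-v0.1. Below is a minimal standalone sketch of that fallback; load_local_hparams is a hypothetical helper name introduced here for illustration, not part of the conversion script.

```python
# Standalone sketch of the config-loading fallback used by both __init__ overrides.
# "load_local_hparams" is a hypothetical name, not part of the converter.
import json
from pathlib import Path


def load_local_hparams(dir_model: Path, hparams: dict | None = None) -> dict:
    # Only read config.json when hparams were not provided by the caller
    if hparams is None:
        with open(dir_model / "config.json", "r", encoding="utf-8") as f:
            hparams = json.load(f)
    return hparams
```
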
@@ -5968,12 +5985,20 @@ def get_model_architecture(dir_model: Path, model_type: ModelType, hparams: Any
     hparams = ModelBase.load_hparams(dir_model) if hparams is None else hparams
     text_config = hparams.get("text_config", {})
     vision_config = hparams.get("vision_config", {})
-    arch = hparams["architectures"][0]
+    arch = None
+    if (arches := hparams.get("architectures")) is not None and len(arches) > 0:
+        arch = arches[0]
+    elif "ssm_cfg" in hparams:
+        # For non-hf Mamba and Mamba2 models
+        arch = hparams["ssm_cfg"].get("layer", "Mamba") + "ForCausalLM"
+
     # if "architectures" is found in the sub-config, use that instead
     if model_type == ModelType.TEXT and text_config.get("architectures") is not None:
         arch = text_config["architectures"][0]
     elif model_type == ModelType.VISION and vision_config.get("architectures") is not None:
         arch = vision_config["architectures"][0]
+    if arch is None:
+        raise ValueError("Failed to detect model architecture")
     return arch


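In get_model_architecture the HF path (architectures[0]) is kept, but original non-HF Mamba and Mamba2 checkpoints, whose config.json carries an ssm_cfg block and no architectures list, now get an arch name derived from ssm_cfg["layer"]; a still-undetected architecture raises a ValueError instead of a KeyError. Below is a rough, simplified sketch of that decision (sub-config handling for text/vision models is omitted, and detect_arch is a hypothetical function name, not the converter's actual function).

```python
# Simplified sketch of the architecture fallback; not the converter's real function.
def detect_arch(hparams: dict) -> str:
    arches = hparams.get("architectures")
    if arches:
        return arches[0]
    if "ssm_cfg" in hparams:
        # Non-HF Mamba configs: the default of "Mamba" in the diff implies that
        # original Mamba-1 configs omit "layer", while Mamba2 configs set it.
        return hparams["ssm_cfg"].get("layer", "Mamba") + "ForCausalLM"
    raise ValueError("Failed to detect model architecture")


# e.g. {"ssm_cfg": {}}                  -> "MambaForCausalLM"
#      {"ssm_cfg": {"layer": "Mamba2"}} -> "Mamba2ForCausalLM"
```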