 from hashlib import sha256
 from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Literal, Sequence, TypeVar, cast
 from itertools import chain
+from transformers import AutoConfig

 import math
 import numpy as np
@@ -66,8 +67,6 @@ class ModelBase:
     part_names: list[str]
     is_safetensors: bool
     hparams: dict[str, Any]
-    block_count: int
-    tensor_map: gguf.TensorNameMap
     tensor_names: set[str] | None
     gguf_writer: gguf.GGUFWriter
     model_name: str | None
@@ -78,6 +77,10 @@ class ModelBase:
     # subclasses should define this!
     model_arch: gguf.MODEL_ARCH

+    # subclasses should initialize this!
+    block_count: int
+    tensor_map: gguf.TensorNameMap
+
     def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, *, is_big_endian: bool = False,
                  use_temp_file: bool = False, eager: bool = False,
                  metadata_override: Path | None = None, model_name: str | None = None,
@@ -113,8 +116,6 @@ def get_remote_tensors() -> Iterator[tuple[str, Tensor]]:
         if not self.is_safetensors:
            self.part_names = ModelBase.get_model_part_names(self.dir_model, "pytorch_model", ".bin")
        self.hparams = ModelBase.load_hparams(self.dir_model) if hparams is None else hparams
-        self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
-        self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
         self.tensor_names = None
         self.metadata_override = metadata_override
         self.model_name = model_name
@@ -417,15 +418,13 @@ def get_model_part_names(dir_model: Path, prefix: str, suffix: str) -> list[str]

     @staticmethod
     def load_hparams(dir_model: Path):
-        with open(dir_model / "config.json", "r", encoding="utf-8") as f:
-            hparams = json.load(f)
-        architectures = hparams.get("architectures")
-        if "text_config" in hparams:
-            hparams = {**hparams, **hparams["text_config"]}
-        if architectures is not None:
-            # preserve "architectures" from root level config
-            hparams["architectures"] = architectures
-        return hparams
+        try:
+            return AutoConfig.from_pretrained(dir_model).to_dict()
+        except Exception as e:
+            logger.warning(f"Failed to load model config from {dir_model}: {e}")
+            logger.warning("Trying to load config.json instead")
+            with open(dir_model / "config.json", "r", encoding="utf-8") as f:
+                return json.load(f)

     @classmethod
     def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]:
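
For orientation, a minimal usage sketch of the new loader (the checkpoint directory and printed values are illustrative, not part of the patch): AutoConfig.to_dict() keeps nested sub-configs such as "text_config" and "vision_config" available as plain dicts, and the config.json fallback covers checkpoints whose config class the installed transformers version cannot instantiate.

    from pathlib import Path

    # hypothetical local checkpoint directory
    hparams = ModelBase.load_hparams(Path("models/SmolVLM2-2.2B-Instruct"))
    print(hparams["architectures"])   # e.g. ['SmolVLMForConditionalGeneration']
    print("text_config" in hparams)   # nested sub-configs remain available as dicts
    print("vision_config" in hparams)
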
@@ -454,6 +453,23 @@ def from_model_architecture(cls, arch: str, model_type = ModelType.TEXT) -> type


 class TextModel(ModelBase):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        if "text_config" in self.hparams:
+            # move the text_config to the root level
+            self.hparams = {**self.hparams, **self.hparams["text_config"]}
+
+        self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers"])
+        self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
+
+    @classmethod
+    def __init_subclass__(cls):
+        # can't use an abstract property, because overriding it without type errors
+        # would require using decorated functions instead of simply defining the property
+        if "model_arch" not in cls.__dict__:
+            raise TypeError(f"Missing property 'model_arch' for {cls.__name__!r}")
+
     def set_vocab(self):
         self._set_vocab_gpt2()

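
Two behaviours here are worth a small illustrative sketch (the values and the subclass below are hypothetical, not from the patch): the dict merge lets "text_config" entries override duplicate root-level keys, and __init_subclass__ runs at class-definition time, so a TextModel subclass that forgets to declare model_arch fails immediately rather than during conversion.

    # merge order: keys from text_config win over root-level duplicates
    hparams = {"hidden_size": 0, "text_config": {"hidden_size": 4096}}
    merged = {**hparams, **hparams["text_config"]}
    assert merged["hidden_size"] == 4096

    # the guard fires as soon as the subclass is defined
    try:
        class BrokenModel(TextModel):  # hypothetical subclass with no model_arch
            pass
    except TypeError as e:
        print(e)  # Missing property 'model_arch' for 'BrokenModel'

Because the check looks at cls.__dict__ rather than inherited attributes, every subclass must declare model_arch itself, which is presumably why the NomicBertModel hunk further down adds an explicit model_arch even though its parent BertModel already defines one.
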
@@ -1070,9 +1086,9 @@ def __init__(self, *args, **kwargs):
         if self.model_arch != gguf.MODEL_ARCH.CLIP_VISION:
             raise TypeError("VisionModel must be subclassed with model_arch = gguf.MODEL_ARCH.CLIP_VISION")

-        # small hack to correct the number of layers
-        self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.CLIP_VISION, 128)
-        self.n_embd_text = self.find_hparam(["hidden_size", "n_embd"])
+        # get n_embd of the text model
+        text_config = {**self.hparams, **self.hparams["text_config"]}
+        self.n_embd_text = text_config.get("hidden_size", text_config.get("n_embd", 0))
         assert self.n_embd_text > 0, "n_embd not found in hparams"

         if "vision_config" not in self.hparams:
@@ -1081,6 +1097,9 @@ def __init__(self, *args, **kwargs):
         self.global_config = self.hparams
         self.hparams = self.hparams["vision_config"]

+        self.block_count = self.find_hparam(["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth"])
+        self.tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.CLIP_VISION, self.block_count)
+
         # load preprocessor config
         with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f:
             self.preprocessor_config = json.load(f)
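
A rough illustration of the new lookup order (all config values below are made up): the text embedding width is read from the merged root/text_config dict before self.hparams is narrowed to the vision_config, and the vision block count now also accepts "depth", a key some vision towers use for their layer count.

    # hypothetical composite config
    hparams = {
        "hidden_size": 4096,  # width of the text model, at the root level
        "text_config": {},
        "vision_config": {"depth": 32, "hidden_size": 1280},
    }

    text_config = {**hparams, **hparams["text_config"]}
    n_embd_text = text_config.get("hidden_size", text_config.get("n_embd", 0))
    assert n_embd_text == 4096

    vision = hparams["vision_config"]
    block_count = next(vision[k] for k in
                       ("n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth")
                       if k in vision)
    assert block_count == 32
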
@@ -1098,7 +1117,7 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_vision_patch_size(self.find_hparam(["patch_size"]))
         self.gguf_writer.add_vision_embedding_length(self.find_hparam(["hidden_size"]))
         self.gguf_writer.add_vision_feed_forward_length(self.find_hparam(["intermediate_size"]))
-        self.gguf_writer.add_vision_block_count(self.find_hparam(["num_hidden_layers"]))
+        self.gguf_writer.add_vision_block_count(self.block_count)
         self.gguf_writer.add_vision_head_count(self.find_hparam(["num_attention_heads"]))

         # preprocessor config
@@ -1719,23 +1738,12 @@ def prepare_tensors(self):
     "LlamaForCausalLM",
     "MistralForCausalLM",
     "MixtralForCausalLM",
-    "Idefics3ForConditionalGeneration",
-    "SmolVLMForConditionalGeneration",
+    "VLlama3ForCausalLM",
     "LlavaForConditionalGeneration")
 class LlamaModel(TextModel):
     model_arch = gguf.MODEL_ARCH.LLAMA
     undo_permute = True

-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        # fix for SmolVLM2, missing `num_attention_heads` in config.json
-        if self.hparams["architectures"][0] == "SmolVLMForConditionalGeneration":
-            self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
-        # fix for Pixtral, missing `num_attention_heads` in config.json
-        if self.hparams["architectures"][0] == "LlavaForConditionalGeneration" \
-                and self.hparams.get("model_type") == "mistral":
-            self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
-
     def set_vocab(self):
         try:
             self._set_vocab_sentencepiece()
@@ -1898,11 +1906,7 @@ class LlavaVisionModel(VisionModel):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         if self.hparams["model_type"] == "pixtral":
-            # fix missing config.json values
-            self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 16)
-            self.hparams["num_hidden_layers"] = self.hparams.get("num_hidden_layers", 24)
-            self.hparams["intermediate_size"] = self.hparams.get("intermediate_size", 4096)
-            self.hparams["hidden_size"] = self.hparams.get("hidden_size", 1024)
+            # layer_norm_eps is not in config.json, it is hard-coded in modeling_pixtral.py
             self.hparams["layer_norm_eps"] = self.hparams.get("layer_norm_eps", 1e-5)
             self.img_break_tok_id = 12  # see tokenizer_config.json
         else:
@@ -1913,7 +1917,6 @@ def set_gguf_parameters(self):
         hparams = self.hparams
         if hparams["model_type"] == "pixtral":
             self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.PIXTRAL)
-            # default values below are taken from HF tranformers code
             self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"])
             self.gguf_writer.add_vision_use_silu(True)

@@ -1944,13 +1947,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
 class SmolVLMModel(VisionModel):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        # fix for SmolVLM2, missing some keys in config.json
-        # default values are taken from transformers code
         if self.hparams["model_type"] == "smolvlm_vision":
+            # fix for SmolVLM2, missing some keys in config.json
+            # default values are taken from transformers code
             self.hparams["hidden_size"] = self.hparams.get("hidden_size", 1152)
             self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 16)
             self.hparams["intermediate_size"] = self.hparams.get("intermediate_size", 3072)
-            self.hparams["num_hidden_layers"] = self.hparams.get("num_hidden_layers", 12)

     def set_gguf_parameters(self):
         super().set_gguf_parameters()
@@ -3505,6 +3507,8 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter

 @ModelBase.register("NomicBertModel")
 class NomicBertModel(BertModel):
+    model_arch = gguf.MODEL_ARCH.BERT
+
     def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, **kwargs: Any):
         hparams = kwargs.pop("hparams", None)
         if hparams is None:
@@ -5849,6 +5853,19 @@ def split_str_to_n_bytes(split_str: str) -> int:
     return n


+def get_model_architecture(dir_model: Path, model_type: ModelType, hparams: Any = None) -> str:
+    hparams = ModelBase.load_hparams(dir_model) if hparams is None else hparams
+    text_config = hparams.get("text_config", {})
+    vision_config = hparams.get("vision_config", {})
+    arch = hparams["architectures"][0]
+    # if "architectures" is found in the sub-config, use that instead
+    if model_type == ModelType.TEXT and text_config.get("architectures") is not None:
+        arch = text_config["architectures"][0]
+    elif model_type == ModelType.VISION and vision_config.get("architectures") is not None:
+        arch = vision_config["architectures"][0]
+    return arch
+
+
 def main() -> None:
     args = parse_args()

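
A quick illustration of how the helper picks the architecture per modality (the config values are hypothetical; when hparams is passed in, the dir_model argument is not read):

    hparams = {
        "architectures": ["LlavaForConditionalGeneration"],
        "text_config": {"architectures": ["MistralForCausalLM"]},
        "vision_config": {},
    }
    assert get_model_architecture(Path("."), ModelType.TEXT, hparams) == "MistralForCausalLM"
    assert get_model_architecture(Path("."), ModelType.VISION, hparams) == "LlavaForConditionalGeneration"
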
@@ -5901,16 +5918,15 @@ def main() -> None:

     logger.info(f"Loading model: {dir_model.name}")

-    hparams = ModelBase.load_hparams(dir_model)
-
     if args.mmproj:
         if "mmproj" not in fname_out.name:
             fname_out = ModelBase.add_prefix_to_filename(fname_out, "mmproj-")

     with torch.inference_mode():
         output_type = ftype_map[args.outtype]
-        model_architecture = hparams["architectures"][0]
         model_type = ModelType.VISION if args.mmproj else ModelType.TEXT
+        model_architecture = get_model_architecture(dir_model, model_type)
+        logger.info(f"Model architecture: {model_architecture}")
         try:
             model_class = ModelBase.from_model_architecture(model_architecture, model_type=model_type)
         except NotImplementedError: