@@ -455,8 +455,12 @@ def from_model_architecture(cls, arch: str, model_type = ModelType.TEXT) -> type
455455
456456
457457class TextModel (ModelBase ):
458+ model_type = ModelType .TEXT
459+ hf_arch : str
460+
458461 def __init__ (self , * args , ** kwargs ):
459462 super ().__init__ (* args , ** kwargs )
463+ self .hf_arch = get_model_architecture (self .hparams , self .model_type )
460464
461465 if "text_config" in self .hparams :
462466 # move the text_config to the root level
@@ -1075,10 +1079,36 @@ def _set_vocab_builtin(self, model_name: Literal["gpt-neox", "llama-spm"], vocab
10751079 if (field := vocab_reader .get_field (gguf .Keys .Tokenizer .ADD_EOS )) is not None :
10761080 self .gguf_writer .add_add_eos_token (field .parts [- 1 ].tolist ()[0 ])
10771081
1082+ def _try_set_pooling_type (self ) -> None :
1083+ # get pooling path
1084+ pooling_path = None
1085+ module_path = self .dir_model / "modules.json"
1086+ if module_path .is_file ():
1087+ with open (module_path , encoding = "utf-8" ) as f :
1088+ modules = json .load (f )
1089+ for mod in modules :
1090+ if mod ["type" ] == "sentence_transformers.models.Pooling" :
1091+ pooling_path = mod ["path" ]
1092+ break
1093+
1094+ # get pooling type
1095+ if pooling_path is not None :
1096+ with open (self .dir_model / pooling_path / "config.json" , encoding = "utf-8" ) as f :
1097+ pooling = json .load (f )
1098+ if pooling ["pooling_mode_mean_tokens" ]:
1099+ pooling_type = gguf .PoolingType .MEAN
1100+ elif pooling ["pooling_mode_cls_token" ]:
1101+ pooling_type = gguf .PoolingType .CLS
1102+ elif pooling ["pooling_mode_lasttoken" ]:
1103+ pooling_type = gguf .PoolingType .LAST
1104+ else :
1105+ raise NotImplementedError ("Only MEAN, CLS, and LAST pooling types supported" )
1106+ self .gguf_writer .add_pooling_type (pooling_type )
1107+
10781108
10791109class VisionModel (ModelBase ):
1110+ model_type = ModelType .VISION
10801111 model_arch = gguf .MODEL_ARCH .CLIP_VISION
1081- n_text_embd = 0
10821112 preprocessor_config : dict [str , Any ]
10831113 global_config : dict [str , Any ]
10841114
@@ -2542,7 +2572,7 @@ def set_gguf_parameters(self):
25422572 self .gguf_writer .add_file_type (self .ftype )
25432573
25442574
2545- @ModelBase .register ("Qwen2ForCausalLM" )
2575+ @ModelBase .register ("Qwen2Model" , "Qwen2ForCausalLM" )
25462576class Qwen2Model (TextModel ):
25472577 model_arch = gguf .MODEL_ARCH .QWEN2
25482578
@@ -2554,12 +2584,18 @@ def set_vocab(self):
25542584
25552585 def set_gguf_parameters (self ):
25562586 super ().set_gguf_parameters ()
2587+ self ._try_set_pooling_type ()
25572588 if self .hparams .get ("rope_scaling" ) is not None and "factor" in self .hparams ["rope_scaling" ]:
25582589 if self .hparams ["rope_scaling" ].get ("type" ) == "yarn" :
25592590 self .gguf_writer .add_rope_scaling_type (gguf .RopeScalingType .YARN )
25602591 self .gguf_writer .add_rope_scaling_factor (self .hparams ["rope_scaling" ]["factor" ])
25612592 self .gguf_writer .add_rope_scaling_orig_ctx_len (self .hparams ["rope_scaling" ]["original_max_position_embeddings" ])
25622593
2594+ def modify_tensors (self , data_torch : Tensor , name : str , bid : int | None ) -> Iterable [tuple [str , Tensor ]]:
2595+ if self .hf_arch == "Qwen2Model" :
2596+ name = f"model.{ name } " # map to Qwen2ForCausalLM tensors
2597+ yield from super ().modify_tensors (data_torch , name , bid )
2598+
25632599
25642600@ModelBase .register ("Qwen2VLForConditionalGeneration" , "Qwen2_5_VLForConditionalGeneration" )
25652601class Qwen2VLModel (TextModel ):
@@ -3396,29 +3432,7 @@ def __init__(self, *args, **kwargs):
33963432 def set_gguf_parameters (self ):
33973433 super ().set_gguf_parameters ()
33983434 self .gguf_writer .add_causal_attention (False )
3399-
3400- # get pooling path
3401- pooling_path = None
3402- module_path = self .dir_model / "modules.json"
3403- if module_path .is_file ():
3404- with open (module_path , encoding = "utf-8" ) as f :
3405- modules = json .load (f )
3406- for mod in modules :
3407- if mod ["type" ] == "sentence_transformers.models.Pooling" :
3408- pooling_path = mod ["path" ]
3409- break
3410-
3411- # get pooling type
3412- if pooling_path is not None :
3413- with open (self .dir_model / pooling_path / "config.json" , encoding = "utf-8" ) as f :
3414- pooling = json .load (f )
3415- if pooling ["pooling_mode_mean_tokens" ]:
3416- pooling_type = gguf .PoolingType .MEAN
3417- elif pooling ["pooling_mode_cls_token" ]:
3418- pooling_type = gguf .PoolingType .CLS
3419- else :
3420- raise NotImplementedError ("Only MEAN and CLS pooling types supported" )
3421- self .gguf_writer .add_pooling_type (pooling_type )
3435+ self ._try_set_pooling_type ()
34223436
34233437 def set_vocab (self ):
34243438 tokens , toktypes , tokpre = self .get_vocab_base ()
@@ -5962,8 +5976,7 @@ def split_str_to_n_bytes(split_str: str) -> int:
59625976 return n
59635977
59645978
5965- def get_model_architecture (dir_model : Path , model_type : ModelType , hparams : Any = None ) -> str :
5966- hparams = ModelBase .load_hparams (dir_model ) if hparams is None else hparams
5979+ def get_model_architecture (hparams : dict [str , Any ], model_type : ModelType ) -> str :
59675980 text_config = hparams .get ("text_config" , {})
59685981 vision_config = hparams .get ("vision_config" , {})
59695982 arch = hparams ["architectures" ][0 ]
@@ -6034,7 +6047,8 @@ def main() -> None:
60346047 with torch .inference_mode ():
60356048 output_type = ftype_map [args .outtype ]
60366049 model_type = ModelType .VISION if args .mmproj else ModelType .TEXT
6037- model_architecture = get_model_architecture (dir_model , model_type )
6050+ hparams = ModelBase .load_hparams (dir_model )
6051+ model_architecture = get_model_architecture (hparams , model_type )
60386052 logger .info (f"Model architecture: { model_architecture } " )
60396053 try :
60406054 model_class = ModelBase .from_model_architecture (model_architecture , model_type = model_type )
0 commit comments