@@ -3476,7 +3476,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
34763476 yield from super ().modify_tensors (data_torch , name , bid )
34773477
34783478
3479- @ModelBase .register ("Ernie4_5_ForCausalLM" )
3479+ @ModelBase .register ("Ernie4_5_ForCausalLM" , "Ernie4_5ForCausalLM" )
34803480class Ernie4_5Model (TextModel ):
34813481 model_arch = gguf .MODEL_ARCH .ERNIE4_5
34823482
@@ -6566,9 +6566,11 @@ def prepare_tensors(self):
65666566 raise ValueError (f"Unprocessed experts: { experts } " )
65676567
65686568
6569- @ModelBase .register ("DeepseekV2ForCausalLM" )
6570- @ModelBase .register ("DeepseekV3ForCausalLM" )
6571- @ModelBase .register ("KimiVLForConditionalGeneration" )
6569+ @ModelBase .register (
6570+ "DeepseekV2ForCausalLM" ,
6571+ "DeepseekV3ForCausalLM" ,
6572+ "KimiVLForConditionalGeneration" ,
6573+ )
65726574class DeepseekV2Model (TextModel ):
65736575 model_arch = gguf .MODEL_ARCH .DEEPSEEK2
65746576
@@ -8813,6 +8815,43 @@ def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", "
88138815 return "mm.2.weight"
88148816 return super ().map_tensor_name (name , try_suffixes )
88158817
8818+
@ModelBase.register("KimiVLForConditionalGeneration")
class KimiVLModel(MmprojModel):
    """Mmproj converter registered for ``KimiVLForConditionalGeneration``.

    Only tensors whose names contain ``vision_tower`` or
    ``multi_modal_projector`` are emitted; every other tensor is skipped.
    """

    def __init__(self, *args, **kwargs):
        """Initialise the base converter and pin the vision ``image_size``."""
        super().__init__(*args, **kwargs)
        assert self.hparams_vision is not None
        # for compatibility
        self.hparams_vision["image_size"] = 64 * 14

    def set_gguf_parameters(self):
        """Write the base parameters, then the KimiVL-specific vision keys."""
        super().set_gguf_parameters()
        writer = self.gguf_writer
        writer.add_clip_projector_type(gguf.VisionProjectorType.KIMIVL)
        writer.add_vision_use_gelu(True)
        writer.add_vision_projector_scale_factor(2)
        assert self.hparams_vision is not None
        # eps is the same as pytorch's default value
        layer_norm_eps = self.hparams_vision.get("layer_norm_eps", 1e-5)
        writer.add_vision_attention_layernorm_eps(layer_norm_eps)

    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        """Map one checkpoint tensor to zero or more ``(name, tensor)`` outputs.

        Args:
            data_torch: Tensor data loaded from the checkpoint.
            name: Tensor name as it appears in the checkpoint.
            bid: Block index; not used by this converter.

        Returns:
            An empty list for tensors that are neither vision-tower nor
            projector tensors; three entries (``wq``/``wk``/``wv``) for a
            fused ``wqkv`` tensor; otherwise a single mapped entry.
        """
        del bid  # unused

        if "vision_tower" not in name and "multi_modal_projector" not in name:
            return []  # skip other tensors

        if "pos_emb.weight" in name:
            # Merge the first two dimensions into one.
            shape = data_torch.shape
            data_torch = data_torch.view(shape[0] * shape[1], shape[2])
        elif "wqkv" in name:
            # Names containing "weight" are split along dim 0, anything else along the last dim.
            axis = -1 if "weight" not in name else 0
            q_part, k_part, v_part = data_torch.chunk(3, dim=axis)
            pieces = {"wq": q_part, "wk": k_part, "wv": v_part}
            return [
                (self.map_tensor_name(name.replace("wqkv", key)), piece)
                for key, piece in pieces.items()
            ]

        return [(self.map_tensor_name(name), data_torch)]

8854+
88168855###### CONVERSION LOGIC ######
88178856
88188857
0 commit comments