@@ -3159,7 +3159,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
31593159 yield from super ().modify_tensors (data_torch , name , bid )
31603160
31613161
3162- @ModelBase .register ("Ernie4_5_ForCausalLM" )
3162+ @ModelBase .register ("Ernie4_5_ForCausalLM" , "Ernie4_5ForCausalLM" )
31633163class Ernie4_5Model (TextModel ):
31643164 model_arch = gguf .MODEL_ARCH .ERNIE4_5
31653165
@@ -6254,9 +6254,11 @@ def prepare_tensors(self):
62546254 raise ValueError (f"Unprocessed experts: { experts } " )
62556255
62566256
6257- @ModelBase .register ("DeepseekV2ForCausalLM" )
6258- @ModelBase .register ("DeepseekV3ForCausalLM" )
6259- @ModelBase .register ("KimiVLForConditionalGeneration" )
6257+ @ModelBase .register (
6258+ "DeepseekV2ForCausalLM" ,
6259+ "DeepseekV3ForCausalLM" ,
6260+ "KimiVLForConditionalGeneration" ,
6261+ )
62606262class DeepseekV2Model (TextModel ):
62616263 model_arch = gguf .MODEL_ARCH .DEEPSEEK2
62626264
@@ -8507,6 +8509,43 @@ def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", "
85078509 return "mm.2.weight"
85088510 return super ().map_tensor_name (name , try_suffixes )
85098511
8512+
8513+ @ModelBase .register ("KimiVLForConditionalGeneration" )
8514+ class KimiVLModel (MmprojModel ):
8515+ def __init__ (self , * args , ** kwargs ):
8516+ super ().__init__ (* args , ** kwargs )
8517+ assert self .hparams_vision is not None
8518+ self .hparams_vision ["image_size" ] = 64 * 14 # for compatibility
8519+
8520+ def set_gguf_parameters (self ):
8521+ super ().set_gguf_parameters ()
8522+ self .gguf_writer .add_clip_projector_type (gguf .VisionProjectorType .KIMIVL )
8523+ self .gguf_writer .add_vision_use_gelu (True )
8524+ self .gguf_writer .add_vision_projector_scale_factor (2 )
8525+ # eps is the same as pytorch's default value
8526+ assert self .hparams_vision is not None
8527+ self .gguf_writer .add_vision_attention_layernorm_eps (self .hparams_vision .get ("layer_norm_eps" , 1e-5 ))
8528+
8529+ def modify_tensors (self , data_torch : Tensor , name : str , bid : int | None ) -> Iterable [tuple [str , Tensor ]]:
8530+ del bid # unused
8531+ is_vision_tensor = "vision_tower" in name or "multi_modal_projector" in name
8532+
8533+ if is_vision_tensor :
8534+ if "pos_emb.weight" in name :
8535+ data_torch = data_torch .view (data_torch .shape [0 ] * data_torch .shape [1 ], data_torch .shape [2 ])
8536+ elif "wqkv" in name :
8537+ split_dim = 0 if "weight" in name else - 1
8538+ wq , wk , wv = data_torch .chunk (3 , dim = split_dim )
8539+ return [
8540+ (self .map_tensor_name (name .replace ("wqkv" , "wq" )), wq ),
8541+ (self .map_tensor_name (name .replace ("wqkv" , "wk" )), wk ),
8542+ (self .map_tensor_name (name .replace ("wqkv" , "wv" )), wv )
8543+ ]
8544+
8545+ return [(self .map_tensor_name (name ), data_torch )]
8546+
8547+ return [] # skip other tensors
8548+
85108549###### CONVERSION LOGIC ######
85118550
85128551
0 commit comments