@@ -4897,7 +4897,7 @@ def _xlmroberta_set_vocab(self) -> None:
         with open(tokenizer_config_path, "r", encoding="utf-8") as fp:
             tokenizer_config_json = json.load(fp)
 
-        add_prefix = tokenizer.add_prefix_space
+        add_prefix = getattr(tokenizer, "add_prefix_space", False)
         remove_whitespaces = tokenizer.clean_up_tokenization_spaces
         precompiled_charsmap = b64decode(tokenizer_json["normalizer"]["precompiled_charsmap"])
 
@@ -5183,7 +5183,18 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
 
         if lora_names := hparams.get("lora_adaptations"):
             self._lora_names = lora_names
-            self.model_arch = gguf.MODEL_ARCH.JINA_BERT_V3
+
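+        # Best-effort arch detection: jina v2/v3 checkpoints advertise rotary
+        # position embeddings via "position_embedding_type" == "rotary" or a
+        # "rotary_emb_base" entry, either top-level or under "text_config".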
+        try:
+            text_cfg = hparams.get("text_config") if isinstance(hparams.get("text_config"), dict) else {}
+            pe_type = (text_cfg.get("position_embedding_type") or hparams.get("position_embedding_type") or "").lower()
+            rope_base = text_cfg.get("rotary_emb_base", hparams.get("rotary_emb_base"))
+            name_path = (hparams.get("_name_or_path") or "").lower()
+            is_vx = "jina" in name_path and ("v2" in name_path or "v3" in name_path)
+            is_v3 = (pe_type == "rotary" or rope_base is not None) and is_vx
+            if is_v3 or self._lora_names:
+                self.model_arch = gguf.MODEL_ARCH.JINA_BERT_V3
+        except Exception:
+            pass
 
         super().__init__(dir_model, ftype, fname_out, hparams=hparams, **kwargs)
         self._xlmroberta_tokenizer_init()
@@ -6405,6 +6416,254 @@ def set_vocab(self):
             raise NotImplementedError(f'Tokenizer {tokenizer_class} is not supported for JinaBertModel')
 
 
+@ModelBase.register("JinaCLIPVisionModel", "JinaCLIPModel")
+class JinaCLIPVisionModel(MmprojModel):
+    """JinaCLIP v2 Vision Encoder Model (handles the vision component only)"""
+    model_arch = gguf.MODEL_ARCH.MMPROJ
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # Load config for vision encoder
+        config_path = self.dir_model / "config.json"
+        if not config_path.exists():
+            raise FileNotFoundError(
+                f"JinaCLIPVisionModel: missing config.json in {self.dir_model}. "
+                "Please ensure the original model config is present; default hyperparameter fallbacks are not used."
+            )
+        with open(config_path, encoding="utf-8") as f:
+            self.vision_config = json.load(f)
+
+    def set_vocab(self):
+        # Vision encoder doesn't need vocabulary
+        pass
+
+    def set_gguf_parameters(self):
+        cfg = self.vision_config
+
+        try:
+            width = int(cfg["width"])            # channel dim
+            head_width = int(cfg["head_width"])  # per-head dim
+            layers = int(cfg["layers"])          # block count
+            image_size = int(cfg["image_size"])  # input image size
+            patch_size = int(cfg["patch_size"])  # patch size
+        except KeyError as e:
+            raise KeyError(f"JinaCLIPVisionModel: missing key in config.json: {e}")
+
+        if width % head_width != 0:
+            raise ValueError(
+                f"JinaCLIPVisionModel: width ({width}) not divisible by head_width ({head_width})"
+            )
+        n_head = width // head_width
+
+        if "mlp_ratio" in cfg:
+            n_ff = int(width * float(cfg["mlp_ratio"]))
+        elif bool(cfg.get("naive_swiglu", False)):
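+            # SwiGLU convention: hidden size is 2/3 of the usual 4*width, i.e. 8*width/3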
+            n_ff = int((width * 8) // 3)
+        else:
+            raise ValueError("JinaCLIPVisionModel: unable to infer FFN size; please provide 'mlp_ratio' or set 'naive_swiglu' in config.json")
+
+        self.gguf_writer.add_clip_has_vision_encoder(True)
+        proj_dim = int(cfg.get("projection_dim", width))
+        self.gguf_writer.add_vision_projection_dim(proj_dim)
+
+        self.gguf_writer.add_vision_image_size(image_size)
+        self.gguf_writer.add_vision_patch_size(patch_size)
+        self.gguf_writer.add_vision_embedding_length(width)
+        self.gguf_writer.add_vision_block_count(layers)
+        self.gguf_writer.add_vision_head_count(n_head)
+        self.gguf_writer.add_vision_feed_forward_length(n_ff)
+
+        self.gguf_writer.add_vision_attention_layernorm_eps(float(cfg.get("layer_norm_eps", 1e-5)))
+
+        mean = self.preprocessor_config.get("image_mean", self.preprocessor_config.get("mean"))
+        std = self.preprocessor_config.get("image_std", self.preprocessor_config.get("std"))
+        if mean is None or std is None:
+            raise KeyError(
+                "JinaCLIPVisionModel: preprocessor_config missing image mean/std (expected keys: 'image_mean'/'image_std' or 'mean'/'std')"
+            )
+        self.gguf_writer.add_vision_image_mean(mean)
+        self.gguf_writer.add_vision_image_std(std)
+
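+        # JinaCLIP v2 vision blocks use SwiGLU (SiLU-gated) MLPs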
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.JINACLIP2)
+        self.gguf_writer.add_vision_use_silu(True)
+
+    def _strip_vm_prefix(self, name: str) -> str:
+        return name[len('vision_model.'):] if name.startswith('vision_model.') else name
+
+    def _map_block_tensor(self, layer: int, rest: str, data_torch: Tensor, name: str) -> list[tuple[str, Tensor]] | None:
+        parts = rest.split('.')
+        # layer norms
+        if rest.startswith('norm1.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.ln1.{suffix}', data_torch)]
+        if rest.startswith('norm2.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.ln2.{suffix}', data_torch)]
+        if rest.startswith('attn.inner_attn_ln.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.attn_ln.{suffix}', data_torch)]
+
+        # fused qkv
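+        # The fused projection is assumed to be laid out row-wise as [q; k; v]
+        # (EVA/timm style), so a (3*width, width) weight splits into three equal
+        # chunks; e.g. width=1024 gives rows 0:1024 (q), 1024:2048 (k), 2048:3072 (v).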
+        if rest == 'attn.qkv.weight':
+            w = data_torch
+            wdim = w.shape[0]
+            if wdim % 3 != 0:
+                logger.warning('mmproj(jinaclip): unexpected qkv weight shape %s for %s', tuple(w.shape), name)
+            d = wdim // 3
+            q, k, v = w[0:d, :], w[d:2 * d, :], w[2 * d:, :]
+            return [
+                (f'v.blk.{layer}.attn_q.weight', q),
+                (f'v.blk.{layer}.attn_k.weight', k),
+                (f'v.blk.{layer}.attn_v.weight', v),
+            ]
+        if rest == 'attn.qkv.bias':
+            b = data_torch
+            bdim = b.shape[0]
+            if bdim % 3 != 0:
+                logger.warning('mmproj(jinaclip): unexpected qkv bias shape %s for %s', tuple(b.shape), name)
+            d = bdim // 3
+            qb, kb, vb = b[0:d], b[d:2 * d], b[2 * d:]
+            return [
+                (f'v.blk.{layer}.attn_q.bias', qb),
+                (f'v.blk.{layer}.attn_k.bias', kb),
+                (f'v.blk.{layer}.attn_v.bias', vb),
+            ]
+        # separate q/v bias (some checkpoints)
+        if rest == 'attn.q_bias':
+            return [(f'v.blk.{layer}.attn_q.bias', data_torch)]
+        if rest == 'attn.v_bias':
+            return [(f'v.blk.{layer}.attn_v.bias', data_torch)]
+
+        # separate projections
+        if rest.startswith('attn.q_proj.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.attn_q.{suffix}', data_torch)]
+        if rest.startswith('attn.k_proj.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.attn_k.{suffix}', data_torch)]
+        if rest.startswith('attn.v_proj.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.attn_v.{suffix}', data_torch)]
+        if rest.startswith('attn.proj.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.attn_out.{suffix}', data_torch)]
+
+        # MLP
+        if rest.startswith('mlp.w1.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.ffn_gate.{suffix}', data_torch)]
+        if rest.startswith('mlp.w2.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.ffn_up.{suffix}', data_torch)]
+        if rest.startswith('mlp.w3.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.ffn_down.{suffix}', data_torch)]
+        if rest.startswith('mlp.ffn_ln.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.ffn_norm.{suffix}', data_torch)]
+        if rest.startswith('mlp.fc1.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.ffn_up.{suffix}', data_torch)]
+        if rest.startswith('mlp.fc2.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.ffn_down.{suffix}', data_torch)]
+        return None
+
+    def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias")) -> str:
+        """Prefer the base table-driven mapping; keep Jina-specific targets if already mapped; fall back to the legacy mapper."""
+        # already a GGUF target name (e.g. "v.*" or "mm.*"): return as-is
+        if name.startswith('v.') or name.startswith('mm.'):
+            return name
+        # try the base mapping first
+        try:
+            return super().map_tensor_name(name, try_suffixes=try_suffixes)
+        except Exception:
+            # fall back to the legacy Jina-specific mapper for any remaining edge keys
+            if hasattr(self, "_map_jinaclip_tensor_name"):
+                mapped = self._map_jinaclip_tensor_name(name)  # type: ignore[attr-defined]
+                if mapped:
+                    return mapped
+            # re-raise so modify_tensors() can log and skip unmapped tensors
+            raise
+
+    def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
+        yielded_any = False
+        try:
+            for name, tensor in super().get_tensors():
+                yielded_any = True
+                yield name, tensor
+        except Exception as e:
+            logger.warning("mmproj(jinaclip): base get_tensors failed, falling back: %s", e)
+        if yielded_any:
+            return
+
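+        # fallback for checkpoints that ship a raw torch state dict instead of
+        # safetensors; weights_only is not accepted by older torch versions,
+        # hence the TypeError fallback below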
+        candidates = [
+            self.dir_model / "pytorch_model.bin",
+            self.dir_model / "vision_model_weights.bin",
+        ]
+        model_path = next((p for p in candidates if p.exists()), None)
+        if model_path is None:
+            raise FileNotFoundError(f"mmproj(jinaclip): no model weights found in {self.dir_model}")
+        try:
+            state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
+        except TypeError:
+            state_dict = torch.load(model_path, map_location="cpu")
+
+        for name, tensor in state_dict.items():
+            yield name, tensor
+
+    def _should_be_f32(self, gguf_name: str) -> bool:
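+        # norm weights/biases are tiny but numerically sensitive, so keep them in f32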
+        patterns = (
+            ".ln1.weight", ".ln1.bias",
+            ".ln2.weight", ".ln2.bias",
+            ".attn_ln.weight", ".attn_ln.bias",
+            ".ffn_norm.weight", ".ffn_norm.bias",
+            "v.patch_embd.proj.bias",
+        )
+        return any(p in gguf_name for p in patterns)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+
+        src = name
+        if src.startswith('v.') or src.startswith('mm.'):
+            return [(src, data_torch)]
+
+        # drop the 'vision_model.' prefix if present
+        src_no_vm = self._strip_vm_prefix(src)
+
+        # top-level direct mappings: use gguf constants directly for canonical names
+        if src_no_vm == 'cls_token':
+            base = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_CLS]
+            return [(base, data_torch)]
+        if src_no_vm.startswith('patch_embed.proj.'):
+            suffix = src_no_vm.split('.')[-1]
+            base = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH]
+            return [(f'{base}.{suffix}', data_torch)]
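+        # 'pos_embed' is stored as a bare parameter; GGUF expects a '.weight'
+        # suffix on the position-embedding tensor, so append it explicitly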
+        if src_no_vm == 'pos_embed':
+            pos_name = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_POS] + '.weight'
+            return [(pos_name, data_torch)]
+        if src_no_vm.startswith('norm.'):
+            suffix = src_no_vm.split('.')[-1]
+            base = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_POST_NORM]
+            return [(f'{base}.{suffix}', data_torch)]
+
+        if src_no_vm.startswith('blocks.'):
+            parts = src_no_vm.split('.')
+            if len(parts) >= 3 and parts[1].isdigit():
+                layer = int(parts[1])
+                rest = '.'.join(parts[2:])
+                mapped = self._map_block_tensor(layer, rest, data_torch, name)
+                if mapped is not None:
+                    return mapped
+
+        try:
+            return [(self.map_tensor_name(name), data_torch)]
+        except Exception:
+            logger.debug("mmproj(jinaclip): skip unmapped tensor %s", name)
+            return []
+
+
 @ModelBase.register("OpenELMForCausalLM")
 class OpenELMModel(TextModel):
     model_arch = gguf.MODEL_ARCH.OPENELM