Skip to content

Commit 7ea1e82

Browse files
author
liyang
committed
address #16574; fold CLI into mtmd-cli; use ggml_rope_ext + bicubic;switch to 'jinaclip2'; fix converter constants
1 parent 9b17d74 commit 7ea1e82

File tree

11 files changed

+939
-24
lines changed

11 files changed

+939
-24
lines changed

common/arg.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2290,14 +2290,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
22902290
[](common_params & params, int value) {
22912291
params.embd_normalize = value;
22922292
}
2293-
).set_examples({LLAMA_EXAMPLE_EMBEDDING}));
2293+
).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_MTMD}));
22942294
add_opt(common_arg(
22952295
{"--embd-output-format"}, "FORMAT",
22962296
"empty = default, \"array\" = [[],[]...], \"json\" = openai style, \"json+\" = same \"json\" + cosine similarity matrix, \"raw\" = plain whitespace-delimited output (one embedding per line)",
22972297
[](common_params & params, const std::string & value) {
22982298
params.embd_out = value;
22992299
}
2300-
).set_examples({LLAMA_EXAMPLE_EMBEDDING}));
2300+
).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_MTMD}));
23012301
add_opt(common_arg(
23022302
{"--embd-separator"}, "STRING",
23032303
"separator of embeddings (default \\n) for example \"<#sep#>\"",

convert_hf_to_gguf.py

Lines changed: 260 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5557,7 +5557,18 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
55575557

55585558
if lora_names := hparams.get("lora_adaptations"):
55595559
self._lora_names = lora_names
5560-
self.model_arch = gguf.MODEL_ARCH.JINA_BERT_V3
5560+
5561+
try:
5562+
text_cfg = hparams.get("text_config", {}) if isinstance(hparams.get("text_config", {}), dict) else {}
5563+
pe_type = (text_cfg.get("position_embedding_type") or hparams.get("position_embedding_type") or "").lower()
5564+
rope_base = text_cfg.get("rotary_emb_base", hparams.get("rotary_emb_base"))
5565+
name_path = (hparams.get("_name_or_path") or "").lower()
5566+
is_vx = ("jina" in name_path and ("v2" in name_path or "v3" in name_path))
5567+
is_v3 = (pe_type == "rotary" or rope_base is not None) and is_vx
5568+
if (is_v3) or self._lora_names:
5569+
self.model_arch = gguf.MODEL_ARCH.JINA_BERT_V3
5570+
except Exception:
5571+
pass
55615572

55625573
super().__init__(dir_model, ftype, fname_out, hparams=hparams, **kwargs)
55635574
self._xlmroberta_tokenizer_init()
@@ -6779,6 +6790,254 @@ def set_vocab(self):
67796790
raise NotImplementedError(f'Tokenizer {tokenizer_class} is not supported for JinaBertModel')
67806791

67816792

6793+
@ModelBase.register("JinaCLIPVisionModel", "JinaCLIPModel")
6794+
class JinaCLIPVisionModel(MmprojModel):
6795+
"""JinaCLIP v2 Vision Encoder Model - handles vision component only"""
6796+
model_arch = gguf.MODEL_ARCH.MMPROJ
6797+
6798+
def __init__(self, *args, **kwargs):
6799+
super().__init__(*args, **kwargs)
6800+
6801+
# Load config for vision encoder
6802+
config_path = self.dir_model / "config.json"
6803+
if not config_path.exists():
6804+
raise FileNotFoundError(
6805+
f"JinaCLIPVisionModel: missing config.json in {self.dir_model}. "
6806+
"Please ensure the original model config is present; default hyperparameter fallbacks are not used."
6807+
)
6808+
with open(config_path, encoding="utf-8") as f:
6809+
self.vision_config = json.load(f)
6810+
6811+
def set_vocab(self):
6812+
# Vision encoder doesn't need vocabulary
6813+
pass
6814+
6815+
def set_gguf_parameters(self):
6816+
cfg = self.vision_config
6817+
6818+
try:
6819+
width = int(cfg["width"]) # channel dim
6820+
head_width = int(cfg["head_width"]) # per-head dim
6821+
layers = int(cfg["layers"]) # block count
6822+
image_size = int(cfg["image_size"]) # input image size
6823+
patch_size = int(cfg["patch_size"]) # patch size
6824+
except KeyError as e:
6825+
raise KeyError(f"JinaCLIPVisionModel: missing key in config.json: {e}")
6826+
6827+
if width % head_width != 0:
6828+
raise ValueError(
6829+
f"JinaCLIPVisionModel: width ({width}) not divisible by head_width ({head_width})"
6830+
)
6831+
n_head = width // head_width
6832+
6833+
if "mlp_ratio" in cfg:
6834+
n_ff = int(width * float(cfg["mlp_ratio"]))
6835+
elif bool(cfg.get("naive_swiglu", False)):
6836+
n_ff = int((width * 8) // 3)
6837+
else:
6838+
raise ValueError("JinaCLIPVisionModel: unable to infer FFN size; please provide 'mlp_ratio' or set 'naive_swiglu' in config.json")
6839+
6840+
self.gguf_writer.add_clip_has_vision_encoder(True)
6841+
proj_dim = int(cfg.get("projection_dim", width))
6842+
self.gguf_writer.add_vision_projection_dim(proj_dim)
6843+
6844+
self.gguf_writer.add_vision_image_size(image_size)
6845+
self.gguf_writer.add_vision_patch_size(patch_size)
6846+
self.gguf_writer.add_vision_embedding_length(width)
6847+
self.gguf_writer.add_vision_block_count(layers)
6848+
self.gguf_writer.add_vision_head_count(n_head)
6849+
self.gguf_writer.add_vision_feed_forward_length(n_ff)
6850+
6851+
self.gguf_writer.add_vision_attention_layernorm_eps(float(cfg.get("layer_norm_eps", 1e-5)))
6852+
6853+
mean = self.preprocessor_config.get("image_mean", self.preprocessor_config.get("mean"))
6854+
std = self.preprocessor_config.get("image_std", self.preprocessor_config.get("std"))
6855+
if mean is None or std is None:
6856+
raise KeyError(
6857+
"JinaCLIPVisionModel: preprocessor_config missing image mean/std (expected keys: 'image_mean'/'image_std' or 'mean'/'std')"
6858+
)
6859+
self.gguf_writer.add_vision_image_mean(mean)
6860+
self.gguf_writer.add_vision_image_std(std)
6861+
6862+
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.JINACLIP2)
6863+
self.gguf_writer.add_vision_use_silu(True)
6864+
6865+
def _strip_vm_prefix(self, name: str) -> str:
6866+
return name[len('vision_model.'):] if name.startswith('vision_model.') else name
6867+
6868+
def _map_block_tensor(self, layer: int, rest: str, data_torch: Tensor, name: str) -> list[tuple[str, Tensor]] | None:
6869+
parts = rest.split('.')
6870+
# layer norms
6871+
if rest.startswith('norm1.'):
6872+
suffix = parts[-1]
6873+
return [(f'v.blk.{layer}.ln1.{suffix}', data_torch)]
6874+
if rest.startswith('norm2.'):
6875+
suffix = parts[-1]
6876+
return [(f'v.blk.{layer}.ln2.{suffix}', data_torch)]
6877+
if rest.startswith('attn.inner_attn_ln.'):
6878+
suffix = parts[-1]
6879+
return [(f'v.blk.{layer}.attn_ln.{suffix}', data_torch)]
6880+
6881+
# fused qkv
6882+
if rest == 'attn.qkv.weight':
6883+
w = data_torch
6884+
wdim = w.shape[0]
6885+
if wdim % 3 != 0:
6886+
logger.warning('mmproj(jinaclip): unexpected qkv weight shape %s for %s', tuple(w.shape), name)
6887+
d = wdim // 3
6888+
q, k, v = w[0:d, :], w[d:2 * d, :], w[2 * d:, :]
6889+
return [
6890+
(f'v.blk.{layer}.attn_q.weight', q),
6891+
(f'v.blk.{layer}.attn_k.weight', k),
6892+
(f'v.blk.{layer}.attn_v.weight', v),
6893+
]
6894+
if rest == 'attn.qkv.bias':
6895+
b = data_torch
6896+
bdim = b.shape[0]
6897+
if bdim % 3 != 0:
6898+
logger.warning('mmproj(jinaclip): unexpected qkv bias shape %s for %s', tuple(b.shape), name)
6899+
d = bdim // 3
6900+
qb, kb, vb = b[0:d], b[d:2 * d], b[2 * d:]
6901+
return [
6902+
(f'v.blk.{layer}.attn_q.bias', qb),
6903+
(f'v.blk.{layer}.attn_k.bias', kb),
6904+
(f'v.blk.{layer}.attn_v.bias', vb),
6905+
]
6906+
# separate q/v bias (some checkpoints)
6907+
if rest == 'attn.q_bias':
6908+
return [(f'v.blk.{layer}.attn_q.bias', data_torch)]
6909+
if rest == 'attn.v_bias':
6910+
return [(f'v.blk.{layer}.attn_v.bias', data_torch)]
6911+
6912+
# separate projections
6913+
if rest.startswith('attn.q_proj.'):
6914+
suffix = parts[-1]
6915+
return [(f'v.blk.{layer}.attn_q.{suffix}', data_torch)]
6916+
if rest.startswith('attn.k_proj.'):
6917+
suffix = parts[-1]
6918+
return [(f'v.blk.{layer}.attn_k.{suffix}', data_torch)]
6919+
if rest.startswith('attn.v_proj.'):
6920+
suffix = parts[-1]
6921+
return [(f'v.blk.{layer}.attn_v.{suffix}', data_torch)]
6922+
if rest.startswith('attn.proj.'):
6923+
suffix = parts[-1]
6924+
return [(f'v.blk.{layer}.attn_out.{suffix}', data_torch)]
6925+
6926+
# MLP
6927+
if rest.startswith('mlp.w1.'):
6928+
suffix = parts[-1]
6929+
return [(f'v.blk.{layer}.ffn_gate.{suffix}', data_torch)]
6930+
if rest.startswith('mlp.w2.'):
6931+
suffix = parts[-1]
6932+
return [(f'v.blk.{layer}.ffn_up.{suffix}', data_torch)]
6933+
if rest.startswith('mlp.w3.'):
6934+
suffix = parts[-1]
6935+
return [(f'v.blk.{layer}.ffn_down.{suffix}', data_torch)]
6936+
if rest.startswith('mlp.ffn_ln.'):
6937+
suffix = parts[-1]
6938+
return [(f'v.blk.{layer}.ffn_norm.{suffix}', data_torch)]
6939+
if rest.startswith('mlp.fc1.'):
6940+
suffix = parts[-1]
6941+
return [(f'v.blk.{layer}.ffn_up.{suffix}', data_torch)]
6942+
if rest.startswith('mlp.fc2.'):
6943+
suffix = parts[-1]
6944+
return [(f'v.blk.{layer}.ffn_down.{suffix}', data_torch)]
6945+
return None
6946+
6947+
def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias")) -> str:
6948+
"""Prefer base table-driven mapping; keep Jina-specific targets if already mapped; fallback to legacy mapper."""
6949+
# Already a GGUF target name (e.g., "v.*" or "mm.*"): return as-is
6950+
if name.startswith('v.') or name.startswith('mm.'):
6951+
return name
6952+
# Try the base mapping first
6953+
try:
6954+
return super().map_tensor_name(name, try_suffixes=try_suffixes)
6955+
except Exception:
6956+
# Fallback to legacy Jina-specific mapper for any remaining edge keys
6957+
if hasattr(self, "_map_jinaclip_tensor_name"):
6958+
mapped = self._map_jinaclip_tensor_name(name) # type: ignore[attr-defined]
6959+
if mapped:
6960+
return mapped
6961+
return name
6962+
6963+
def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
6964+
yielded_any = False
6965+
try:
6966+
for name, tensor in super().get_tensors():
6967+
yielded_any = True
6968+
yield name, tensor
6969+
except Exception as e:
6970+
logger.warning("mmproj(jinaclip): base get_tensors failed, falling back: %s", e)
6971+
if yielded_any:
6972+
return
6973+
6974+
candidates = [
6975+
self.dir_model / "pytorch_model.bin",
6976+
self.dir_model / "vision_model_weights.bin",
6977+
]
6978+
model_path = next((p for p in candidates if p.exists()), None)
6979+
if model_path is None:
6980+
raise FileNotFoundError(f"mmproj(jinaclip): no model weights found in {self.dir_model}")
6981+
try:
6982+
state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
6983+
except TypeError:
6984+
state_dict = torch.load(model_path, map_location="cpu")
6985+
6986+
for name, tensor in state_dict.items():
6987+
yield name, tensor
6988+
6989+
def _should_be_f32(self, gguf_name: str) -> bool:
6990+
patterns = (
6991+
".ln1.weight", ".ln1.bias",
6992+
".ln2.weight", ".ln2.bias",
6993+
".attn_ln.weight", ".attn_ln.bias",
6994+
".ffn_norm.weight", ".ffn_norm.bias",
6995+
"v.patch_embd.proj.bias",
6996+
)
6997+
return any(p in gguf_name for p in patterns)
6998+
6999+
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
7000+
del bid # unused
7001+
7002+
src = name
7003+
if src.startswith('v.') or src.startswith('mm.'):
7004+
return [(src, data_torch)]
7005+
7006+
# Drop 'vision_model.' prefix if present
7007+
src_no_vm = self._strip_vm_prefix(src)
7008+
7009+
# Top-level direct mappings — use gguf constants directly for canonical names
7010+
if src_no_vm == 'cls_token':
7011+
base = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_CLS]
7012+
return [(base, data_torch)]
7013+
if src_no_vm.startswith('patch_embed.proj.'):
7014+
suffix = src_no_vm.split('.')[-1]
7015+
base = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH]
7016+
return [(f'{base}.{suffix}', data_torch)]
7017+
if src_no_vm == 'pos_embed':
7018+
pos_name = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_POS] + '.weight'
7019+
return [(pos_name, data_torch)]
7020+
if src_no_vm.startswith('norm.'):
7021+
suffix = src_no_vm.split('.')[-1]
7022+
base = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_POST_NORM]
7023+
return [(f'{base}.{suffix}', data_torch)]
7024+
7025+
if src_no_vm.startswith('blocks.'):
7026+
parts = src_no_vm.split('.')
7027+
if len(parts) >= 3 and parts[1].isdigit():
7028+
layer = int(parts[1])
7029+
rest = '.'.join(parts[2:])
7030+
mapped = self._map_block_tensor(layer, rest, data_torch, name)
7031+
if mapped is not None:
7032+
return mapped
7033+
7034+
try:
7035+
return [(self.map_tensor_name(name), data_torch)]
7036+
except Exception:
7037+
logger.debug("mmproj(jinaclip): skip unmapped tensor %s", name)
7038+
return []
7039+
7040+
67827041
@ModelBase.register("OpenELMForCausalLM")
67837042
class OpenELMModel(TextModel):
67847043
model_arch = gguf.MODEL_ARCH.OPENELM

gguf-py/gguf/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3230,6 +3230,7 @@ class VisionProjectorType:
32303230
QWEN3VL = "qwen3vl_merger"
32313231
ULTRAVOX = "ultravox"
32323232
INTERNVL = "internvl"
3233+
JINACLIP2 = "jinaclip2"
32333234
QWEN2A = "qwen2a" # audio
32343235
QWEN25O = "qwen2.5o" # omni
32353236
VOXTRAL = "voxtral"

tools/mtmd/clip-impl.h

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
#define KEY_PROJ_SCALE_FACTOR "clip.vision.projector.scale_factor"
4141
#define KEY_SPATIAL_MERGE_SIZE "clip.vision.spatial_merge_size"
4242
#define KEY_IS_DEEPSTACK_LAYERS "clip.vision.is_deepstack_layers"
43+
#define KEY_VISION_ROPE_THETA "clip.vision.rope_theta"
4344

4445
#define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type"
4546
#define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints"
@@ -69,14 +70,15 @@
6970
#define TN_ATTN_Q "%s.blk.%d.attn_q.%s"
7071
#define TN_ATTN_V "%s.blk.%d.attn_v.%s"
7172
#define TN_ATTN_OUTPUT "%s.blk.%d.attn_out.%s"
73+
#define TN_ATTN_LN "%s.blk.%d.attn_ln.%s" // inner attention LayerNorm
7274
#define TN_ATTN_K_NORM "%s.blk.%d.attn_k_norm.%s"
7375
#define TN_ATTN_Q_NORM "%s.blk.%d.attn_q_norm.%s"
7476
#define TN_FFN_DOWN "%s.blk.%d.ffn_down.%s"
7577
#define TN_FFN_GATE "%s.blk.%d.ffn_gate.%s"
7678
#define TN_FFN_UP "%s.blk.%d.ffn_up.%s"
77-
#define TN_FFN_GATE "%s.blk.%d.ffn_gate.%s"
78-
#define TN_LN_1 "%s.blk.%d.ln1.%s" // layer norm
79-
#define TN_LN_2 "%s.blk.%d.ln2.%s" // layer norm
79+
#define TN_FFN_NORM "%s.blk.%d.ffn_norm.%s"
80+
#define TN_LN_1 "%s.blk.%d.ln1.%s"
81+
#define TN_LN_2 "%s.blk.%d.ln2.%s"
8082
#define TN_LS_1 "%s.blk.%d.ls1.%s" // layer scale
8183
#define TN_LS_2 "%s.blk.%d.ls2.%s" // layer scale
8284
#define TN_LN_PRE "%s.pre_ln.%s"
@@ -151,6 +153,7 @@ enum projector_type {
151153
PROJECTOR_TYPE_QWEN2A,
152154
PROJECTOR_TYPE_QWEN25O, // will be replaced by QWEN2A or QWEN25VL depending on clip_ctx
153155
PROJECTOR_TYPE_VOXTRAL,
156+
PROJECTOR_TYPE_JINACLIP2, // JinaCLIP v2
154157
PROJECTOR_TYPE_LFM2,
155158
PROJECTOR_TYPE_KIMIVL,
156159
PROJECTOR_TYPE_LIGHTONOCR,
@@ -180,6 +183,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
180183
{ PROJECTOR_TYPE_LFM2, "lfm2"},
181184
{ PROJECTOR_TYPE_KIMIVL, "kimivl"},
182185
{ PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"},
186+
{ PROJECTOR_TYPE_JINACLIP2, "jinaclip2"},
183187
{ PROJECTOR_TYPE_COGVLM, "cogvlm"},
184188
{ PROJECTOR_TYPE_JANUS_PRO, "janus_pro"},
185189
};

0 commit comments

Comments
 (0)