295 changes: 294 additions & 1 deletion convert_hf_to_gguf.py
@@ -5026,7 +5026,18 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,

if lora_names := hparams.get("lora_adaptations"):
self._lora_names = lora_names
self.model_arch = gguf.MODEL_ARCH.JINA_BERT_V3

# Jina v3 (RoPE) without LoRA should export as jina-bert-v3 to avoid expecting absolute position embeddings
try:
text_cfg = hparams.get("text_config", {}) if isinstance(hparams.get("text_config", {}), dict) else {}
pe_type = (text_cfg.get("position_embedding_type") or hparams.get("position_embedding_type") or "").lower()
rope_base = text_cfg.get("rotary_emb_base", hparams.get("rotary_emb_base"))
name_path = (hparams.get("_name_or_path") or "").lower()
is_v3 = (pe_type == "rotary" or rope_base is not None) and ("jina" in name_path and "v3" in name_path)
if is_v3 and not self._lora_names:
self.model_arch = gguf.MODEL_ARCH.JINA_BERT_V3
Comment on lines -5029 to +5038 (Collaborator):
Please explain this. First, it breaks jina-embeddings-v3 conversion. Second, jina-clip-v2 looks like it loads jina-embeddings-v3 and uses the retrieval.query LoRA/prompt, but load_trained_adapters being set to false suggests the adapter is not applied?
https://huggingface.co/jinaai/jina-clip-v2/blob/main/config.json#L15-L38
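For reference, a minimal sketch for inspecting the fields in question in a local copy of the linked config.json (`load_trained_adapters` and the `retrieval.query` task come from the comment above; the exact nesting and the `default_instruction_task` key name are assumptions to verify against the file):

```python
import json

# Minimal sketch: dump the text-tower fields relevant to the question above.
# Key names are partly assumptions; check them against jina-clip-v2's config.json.
with open("config.json", encoding="utf-8") as f:
    cfg = json.load(f)

text_cfg = cfg.get("text_config", {})
print(text_cfg.get("default_instruction_task"))  # expected "retrieval.query" (assumed key name)
print(text_cfg.get("load_trained_adapters"))     # reportedly false per the review comment
```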

except Exception:
pass

super().__init__(dir_model, ftype, fname_out, hparams=hparams, **kwargs)
self._xlmroberta_tokenizer_init()
@@ -6248,6 +6259,288 @@ def set_vocab(self):
raise NotImplementedError(f'Tokenizer {tokenizer_class} is not supported for JinaBertModel')


@ModelBase.register("JinaCLIPVisionModel", "JinaCLIPModel")
class JinaCLIPVisionModel(MmprojModel):
"""JinaCLIP v2 Vision Encoder Model - handles vision component only"""
model_arch = gguf.MODEL_ARCH.MMPROJ

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

# Load config for vision encoder
config_path = self.dir_model / "config.json"
if config_path.exists():
with open(config_path, encoding="utf-8") as f:
self.vision_config = json.load(f)
else:
# Default JinaCLIP v2 vision configuration
self.vision_config = {
"image_size": 448,
"patch_size": 14,
"hidden_size": 1024,
"num_hidden_layers": 24,
"num_attention_heads": 16,
"intermediate_size": 2731,
"layer_norm_eps": 1e-5,
"projection_dim": 1024
}

def set_vocab(self):
# Vision encoder doesn't need vocabulary
pass

def set_gguf_parameters(self):
# Identification (arch/name is set by writer); mark vision encoder presence
self.gguf_writer.add_bool("clip.has_vision_encoder", True)

# Vision parameters
config = self.vision_config
img_sz = int(config.get("image_size", 448))
patch_sz = int(config.get("patch_size", 14))
n_embd = int(config.get("hidden_size", 1024))
n_layer = int(config.get("num_hidden_layers", 24))
n_head = int(config.get("num_attention_heads", 16))
n_ff = int(config.get("intermediate_size", 2731))
proj_dim = int(config.get("projection_dim", 1024))

self.gguf_writer.add_uint32("clip.vision.image_size", img_sz)
self.gguf_writer.add_uint32("clip.vision.patch_size", patch_sz)
self.gguf_writer.add_uint32("clip.vision.embedding_length", n_embd)
self.gguf_writer.add_uint32("clip.vision.block_count", n_layer)
self.gguf_writer.add_uint32("clip.vision.projection_dim", proj_dim)
self.gguf_writer.add_uint32("clip.vision.feed_forward_length", n_ff)
self.gguf_writer.add_uint32("clip.vision.attention.head_count", n_head)
Comment on lines +6306 to +6312 (Collaborator):
We had specific functions and constants to add these metadata keys. Use them instead
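A sketch of the helper-based form the reviewer is asking for, as a drop-in replacement for the raw key writes above. It assumes the typed setters in gguf-py's GGUFWriter (method names taken from recent gguf-py; verify against gguf_writer.py before relying on them):

```python
# Hedged sketch: replace the raw add_uint32/add_bool calls above with
# gguf-py's typed helpers, which also pin the canonical key constants.
self.gguf_writer.add_clip_has_vision_encoder(True)
self.gguf_writer.add_vision_image_size(img_sz)
self.gguf_writer.add_vision_patch_size(patch_sz)
self.gguf_writer.add_vision_embedding_length(n_embd)
self.gguf_writer.add_vision_block_count(n_layer)
self.gguf_writer.add_vision_projection_dim(proj_dim)
self.gguf_writer.add_vision_feed_forward_length(n_ff)
self.gguf_writer.add_vision_head_count(n_head)
self.gguf_writer.add_vision_attention_layernorm_eps(eps_attn)
```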

# LayerNorm epsilon comes from config (fallback 1e-5)
eps_attn = float(config.get("layer_norm_eps", 1e-5))
self.gguf_writer.add_float32("clip.vision.attention.layer_norm_epsilon", eps_attn)

# Preprocessing defaults
self.gguf_writer.add_array("clip.vision.image_mean", [0.48145466, 0.4578275, 0.40821073])
self.gguf_writer.add_array("clip.vision.image_std", [0.26862954, 0.26130258, 0.27577711])

# Projector type and activation
self.gguf_writer.add_string("clip.projector_type", "jinaclip")
self.gguf_writer.add_vision_use_silu(True)

# RoPE parameter used by vision encoder (prefer config override)
try:
rt = config.get("rope_theta", None)
rope_theta = float(rt) if rt is not None else 10000.0
except Exception:
rope_theta = 10000.0
self.gguf_writer.add_float32("clip.vision.rope_theta", rope_theta)

# Compatibility (mmproj)
self.gguf_writer.add_uint32("mmproj.embedding_length", n_embd)
self.gguf_writer.add_uint32("mmproj.block_count", n_layer)

logger.info(
"mmproj(jinaclip): image_size=%d patch_size=%d n_embd=%d n_layer=%d n_head=%d n_ff=%d proj_dim=%d",
img_sz, patch_sz, n_embd, n_layer, n_head, n_ff, proj_dim
)

# helpers to keep modify_tensors compact and consistent with other models
def _strip_vm_prefix(self, name: str) -> str:
return name[len('vision_model.'):] if name.startswith('vision_model.') else name

def _map_block_tensor(self, layer: int, rest: str, data_torch: Tensor, name: str) -> list[tuple[str, Tensor]] | None:
parts = rest.split('.')
# layer norms
if rest.startswith('norm1.'):
suffix = parts[-1]
return [(f'v.blk.{layer}.ln_1.{suffix}', data_torch)]
if rest.startswith('norm2.'):
suffix = parts[-1]
return [(f'v.blk.{layer}.ln_2.{suffix}', data_torch)]
if rest.startswith('attn.inner_attn_ln.'):
suffix = parts[-1]
return [(f'v.blk.{layer}.attn_ln.{suffix}', data_torch)]

# fused qkv
if rest == 'attn.qkv.weight':
w = data_torch
wdim = w.shape[0]
if wdim % 3 != 0:
logger.warning('mmproj(jinaclip): unexpected qkv weight shape %s for %s', tuple(w.shape), name)
d = wdim // 3
q, k, v = w[0:d, :], w[d:2 * d, :], w[2 * d:, :]
return [
(f'v.blk.{layer}.attn_q.weight', q),
(f'v.blk.{layer}.attn_k.weight', k),
(f'v.blk.{layer}.attn_v.weight', v),
]
if rest == 'attn.qkv.bias':
b = data_torch
bdim = b.shape[0]
if bdim % 3 != 0:
logger.warning('mmproj(jinaclip): unexpected qkv bias shape %s for %s', tuple(b.shape), name)
d = bdim // 3
qb, kb, vb = b[0:d], b[d:2 * d], b[2 * d:]
return [
(f'v.blk.{layer}.attn_q.bias', qb),
(f'v.blk.{layer}.attn_k.bias', kb),
(f'v.blk.{layer}.attn_v.bias', vb),
]
# separate q/v bias (some checkpoints)
if rest == 'attn.q_bias':
return [(f'v.blk.{layer}.attn_q.bias', data_torch)]
if rest == 'attn.v_bias':
return [(f'v.blk.{layer}.attn_v.bias', data_torch)]

# separate projections
if rest.startswith('attn.q_proj.'):
suffix = parts[-1]
return [(f'v.blk.{layer}.attn_q.{suffix}', data_torch)]
if rest.startswith('attn.k_proj.'):
suffix = parts[-1]
return [(f'v.blk.{layer}.attn_k.{suffix}', data_torch)]
if rest.startswith('attn.v_proj.'):
suffix = parts[-1]
return [(f'v.blk.{layer}.attn_v.{suffix}', data_torch)]
if rest.startswith('attn.proj.'):
suffix = parts[-1]
return [(f'v.blk.{layer}.attn_out.{suffix}', data_torch)]

# MLP
if rest.startswith('mlp.w1.'):
suffix = parts[-1]
return [(f'v.blk.{layer}.ffn_gate.{suffix}', data_torch)]
if rest.startswith('mlp.w2.'):
suffix = parts[-1]
return [(f'v.blk.{layer}.ffn_up.{suffix}', data_torch)]
if rest.startswith('mlp.w3.'):
suffix = parts[-1]
return [(f'v.blk.{layer}.ffn_down.{suffix}', data_torch)]
if rest.startswith('mlp.ffn_ln.'):
suffix = parts[-1]
return [(f'v.blk.{layer}.ffn_norm.{suffix}', data_torch)]
if rest.startswith('mlp.fc1.'):
suffix = parts[-1]
return [(f'v.blk.{layer}.ffn_up.{suffix}', data_torch)]
if rest.startswith('mlp.fc2.'):
suffix = parts[-1]
return [(f'v.blk.{layer}.ffn_down.{suffix}', data_torch)]
return None
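A standalone sanity check of the fused-QKV split performed above, using dummy tensors (not part of the converter):

```python
import torch

# The checkpoint stores QKV fused as [3*d, d] (weight) and [3*d] (bias);
# the converter slices them row-wise into three [d, d] / [d] tensors.
n_embd = 1024
qkv_w = torch.randn(3 * n_embd, n_embd)
qkv_b = torch.randn(3 * n_embd)

d = qkv_w.shape[0] // 3
q, k, v = qkv_w[:d], qkv_w[d:2 * d], qkv_w[2 * d:]
qb, kb, vb = qkv_b[:d], qkv_b[d:2 * d], qkv_b[2 * d:]

assert q.shape == k.shape == v.shape == (n_embd, n_embd)
assert qb.shape == kb.shape == vb.shape == (n_embd,)
```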

def map_tensor_name(self, name: str) -> str:
"""Prefer base table-driven mapping; keep Jina-specific targets if already mapped; fallback to legacy mapper."""
# Already a GGUF target name (e.g., "v.*" or "mm.*"): return as-is
if name.startswith('v.') or name.startswith('mm.'):
return name
# Try the base mapping first
try:
return super().map_tensor_name(name)
except Exception:
# Fallback to legacy Jina-specific mapper for any remaining edge keys
mapped = self._map_jinaclip_tensor_name(name)
if mapped:
return mapped
raise

def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
"""Yield tensors for the vision encoder.

Prefer the base implementation (supports sharded/indexed weights). If that fails
or no parts are detected, fall back to a direct single-file load.
"""
# Try base path (indexed/sharded)
try:
if getattr(self, "part_names", None):
for name, tensor in super().get_tensors():
yield name, tensor
return
except Exception as e:
logger.warning("mmproj(jinaclip): base get_tensors failed, falling back to direct load: %s", e)

# Fallback: direct single-file load
import torch
candidates = [
self.dir_model / "pytorch_model.bin",
self.dir_model / "model.safetensors",
self.dir_model / "pytorch_model.safetensors",
self.dir_model / "vision_model_weights.bin",
]
model_path = next((p for p in candidates if p.exists()), None)
if model_path is None:
raise FileNotFoundError(f"mmproj(jinaclip): no model weights found in {self.dir_model}")

logger.info("mmproj(jinaclip): loading weights from %s", model_path)
if model_path.suffix == ".bin":
state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
else:
from safetensors.torch import load_file
state_dict = load_file(str(model_path))

count = 0
for name, tensor in state_dict.items():
# yield raw names; modify_tensors will normalize & map
yield name, tensor
count += 1

logger.info("mmproj(jinaclip): yielded %d raw tensors", count)

def _should_be_f32(self, gguf_name: str) -> bool:
"""Return True if tensor should be stored as F32 to avoid type mismatches in C++ runtime.

Keep the list minimal: LayerNorm weights/bias are the common source of
binary-op dtype issues; patch embedding bias is also safer as F32.
"""
patterns = (
".ln_1.weight", ".ln_1.bias",
".ln_2.weight", ".ln_2.bias",
".attn_ln.weight", ".attn_ln.bias",
".ffn_norm.weight", ".ffn_norm.bias",
"v.patch_embd.proj.bias",
)
return any(p in gguf_name for p in patterns)

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
"""Normalize JinaCLIP vision tensor names to base-friendly patterns, with Jina-specific exceptions.

- Emit Jina-specific targets directly for: patch/proj, pos_embed, inner-attn LN, SwiGLU FFN names.
- If fused QKV is encountered, split into Q/K/V.
- For standard pieces (norm1/norm2, q/k/v/out), map to v.blk.{i}.* targets.
"""
del bid # unused

src = name
# Already in target form
if src.startswith('v.') or src.startswith('mm.'):
return [(src, data_torch)]

# Drop 'vision_model.' prefix if present
src_no_vm = self._strip_vm_prefix(src)

# Top-level direct mappings
if src_no_vm == 'cls_token':
return [('v.cls_token', data_torch)]
Collaborator:
Use proper mapping instead
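A sketch of the table-driven alternative: register the source name in gguf-py's tensor mapping so the base map_tensor_name() resolves it, rather than special-casing it here (the enum member and source-name strings are assumptions to check against gguf-py/gguf/tensor_mapping.py):

```python
from gguf import MODEL_TENSOR

# Hedged sketch of the mapping entry one would add to the table in
# gguf-py/gguf/tensor_mapping.py (names are assumptions; verify there):
extra_mappings = {
    MODEL_TENSOR.V_ENC_EMBD_CLS: (
        "cls_token",               # JinaCLIP v2 (assumed source name)
        "vision_model.cls_token",  # same tensor with the vision_model. prefix kept
    ),
}
```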

if src_no_vm.startswith('patch_embed.proj.'):
suffix = src_no_vm.split('.')[-1]
return [(f'v.patch_embd.proj.{suffix}', data_torch)]
if src_no_vm == 'pos_embed':
return [('v.pos_embd', data_torch)]
if src_no_vm.startswith('norm.'):
suffix = src_no_vm.split('.')[-1]
return [(f'v.post_ln.{suffix}', data_torch)]

# Transformer blocks
if src_no_vm.startswith('blocks.'):
parts = src_no_vm.split('.')
if len(parts) >= 3 and parts[1].isdigit():
layer = int(parts[1])
rest = '.'.join(parts[2:])
mapped = self._map_block_tensor(layer, rest, data_torch, name)
if mapped is not None:
return mapped

# Fallback: try base mapping; if fails, drop tensor
try:
return [(self.map_tensor_name(name), data_torch)]
except Exception:
logger.debug("mmproj(jinaclip): skip unmapped tensor %s", name)
return []


@ModelBase.register("OpenELMForCausalLM")
class OpenELMModel(TextModel):
model_arch = gguf.MODEL_ARCH.OPENELM
5 changes: 5 additions & 0 deletions tools/mtmd/CMakeLists.txt
@@ -60,3 +60,8 @@ if(LLAMA_TOOLS_INSTALL)
endif()
target_link_libraries (${TARGET} PRIVATE common mtmd Threads::Threads)
target_compile_features(${TARGET} PRIVATE cxx_std_17)

# JinaCLIP CLI (align style with other targets above)
set(TARGET llama-jinaclip-cli)
add_executable (${TARGET} jinaclip-cli.cpp)
target_link_libraries (${TARGET} PRIVATE common mtmd Threads::Threads)
Comment on lines +63 to +67 (Collaborator):
We should try to merge this with mtmd-cli to avoid the "fragmentation" trap of the old llava-cli binary.

@pockers21 (Contributor, Author) — Oct 15, 2025:
Agreed, this should merge into llama-mtmd-cli rather than add another standalone CLI.
However, Jina embeddings can produce text and vision embeddings from separate individual models, which differs from the existing mtmd-cli workflow that needs the text and vision models at the same time.
I will add a Jina-specific path in mtmd-cli that supports running with only --mmproj + --image.
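Under that plan, the vision-only path would presumably be exercised along these lines (hypothetical invocation; only the --mmproj and --image flags come from the comment above, and the merged behavior does not exist yet):

```sh
# hypothetical: mmproj-only embedding run after the mtmd-cli merge
llama-mtmd-cli --mmproj mmproj-jinaclip.gguf --image input.jpg
```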
