
Commit 0edcd41

max-krasnyansky authored and liyang committed
cpu: introduce chunking for flash attention (ggml-org#16829)
Factor out the core FA loop into flash_atten_f16_one_chunk and add an outer loop on top that handles the chunks.
1 parent d2d931f commit 0edcd41
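
The refactor described in the commit message follows a common pattern: pull the per-chunk work into its own routine and drive it from an outer loop over chunks. A minimal conceptual sketch in Python (not the ggml C++ code; the chunk granularity over query rows, the chunk size, and the helper names are illustrative assumptions):

import numpy as np

def flash_attn_one_chunk(q_rows, k, v, scale):
    # full attention for one chunk of query rows, softmax taken per row
    s = (q_rows @ k.T) * scale                      # (rows, n_kv)
    s -= s.max(axis=-1, keepdims=True)              # subtract row max for numerical stability
    p = np.exp(s)
    return (p @ v) / p.sum(axis=-1, keepdims=True)  # (rows, d_v)

def flash_attn_chunked(q, k, v, scale, chunk_size=32):
    out = np.empty((q.shape[0], v.shape[1]), dtype=q.dtype)
    for i0 in range(0, q.shape[0], chunk_size):     # outer loop over chunks
        i1 = min(i0 + chunk_size, q.shape[0])
        out[i0:i1] = flash_attn_one_chunk(q[i0:i1], k, v, scale)
    return out

Because each chunk is computed independently, the outer loop is a natural place to split the work across threads.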

11 files changed: +972 -72 lines changed

common/arg.cpp

Lines changed: 2 additions & 2 deletions

@@ -3245,14 +3245,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params, int value) {
             params.embd_normalize = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_EMBEDDING}));
+    ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_MTMD}));
     add_opt(common_arg(
         {"--embd-output-format"}, "FORMAT",
         "empty = default, \"array\" = [[],[]...], \"json\" = openai style, \"json+\" = same \"json\" + cosine similarity matrix, \"raw\" = plain whitespace-delimited output (one embedding per line)",
         [](common_params & params, const std::string & value) {
             params.embd_out = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_EMBEDDING}));
+    ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_MTMD}));
     add_opt(common_arg(
         {"--embd-separator"}, "STRING",
         "separator of embeddings (default \\n) for example \"<#sep#>\"",

convert_hf_to_gguf.py

Lines changed: 261 additions & 2 deletions

@@ -5114,7 +5114,7 @@ def _xlmroberta_set_vocab(self) -> None:
         with open(tokenizer_config_path, "r", encoding="utf-8") as fp:
             tokenizer_config_json = json.load(fp)

-        add_prefix = tokenizer.add_prefix_space
+        add_prefix = getattr(tokenizer, "add_prefix_space", False)
         remove_whitespaces = tokenizer.clean_up_tokenization_spaces
         precompiled_charsmap = b64decode(tokenizer_json["normalizer"]["precompiled_charsmap"])

@@ -5400,7 +5400,18 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,

         if lora_names := hparams.get("lora_adaptations"):
             self._lora_names = lora_names
-            self.model_arch = gguf.MODEL_ARCH.JINA_BERT_V3
+
+        try:
+            text_cfg = hparams.get("text_config", {}) if isinstance(hparams.get("text_config", {}), dict) else {}
+            pe_type = (text_cfg.get("position_embedding_type") or hparams.get("position_embedding_type") or "").lower()
+            rope_base = text_cfg.get("rotary_emb_base", hparams.get("rotary_emb_base"))
+            name_path = (hparams.get("_name_or_path") or "").lower()
+            is_vx = ("jina" in name_path and ("v2" in name_path or "v3" in name_path))
+            is_v3 = (pe_type == "rotary" or rope_base is not None) and is_vx
+            if (is_v3) or self._lora_names:
+                self.model_arch = gguf.MODEL_ARCH.JINA_BERT_V3
+        except Exception:
+            pass

         super().__init__(dir_model, ftype, fname_out, hparams=hparams, **kwargs)
         self._xlmroberta_tokenizer_init()

@@ -6622,6 +6633,254 @@ def set_vocab(self):
             raise NotImplementedError(f'Tokenizer {tokenizer_class} is not supported for JinaBertModel')


+@ModelBase.register("JinaCLIPVisionModel", "JinaCLIPModel")
+class JinaCLIPVisionModel(MmprojModel):
+    """JinaCLIP v2 Vision Encoder Model - handles vision component only"""
+    model_arch = gguf.MODEL_ARCH.MMPROJ
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # Load config for vision encoder
+        config_path = self.dir_model / "config.json"
+        if not config_path.exists():
+            raise FileNotFoundError(
+                f"JinaCLIPVisionModel: missing config.json in {self.dir_model}. "
+                "Please ensure the original model config is present; default hyperparameter fallbacks are not used."
+            )
+        with open(config_path, encoding="utf-8") as f:
+            self.vision_config = json.load(f)
+
+    def set_vocab(self):
+        # Vision encoder doesn't need vocabulary
+        pass
+
+    def set_gguf_parameters(self):
+        cfg = self.vision_config
+
+        try:
+            width = int(cfg["width"])            # channel dim
+            head_width = int(cfg["head_width"])  # per-head dim
+            layers = int(cfg["layers"])          # block count
+            image_size = int(cfg["image_size"])  # input image size
+            patch_size = int(cfg["patch_size"])  # patch size
+        except KeyError as e:
+            raise KeyError(f"JinaCLIPVisionModel: missing key in config.json: {e}")
+
+        if width % head_width != 0:
+            raise ValueError(
+                f"JinaCLIPVisionModel: width ({width}) not divisible by head_width ({head_width})"
+            )
+        n_head = width // head_width
+
+        if "mlp_ratio" in cfg:
+            n_ff = int(width * float(cfg["mlp_ratio"]))
+        elif bool(cfg.get("naive_swiglu", False)):
+            n_ff = int((width * 8) // 3)
+        else:
+            raise ValueError("JinaCLIPVisionModel: unable to infer FFN size; please provide 'mlp_ratio' or set 'naive_swiglu' in config.json")
+
+        self.gguf_writer.add_clip_has_vision_encoder(True)
+        proj_dim = int(cfg.get("projection_dim", width))
+        self.gguf_writer.add_vision_projection_dim(proj_dim)
+
+        self.gguf_writer.add_vision_image_size(image_size)
+        self.gguf_writer.add_vision_patch_size(patch_size)
+        self.gguf_writer.add_vision_embedding_length(width)
+        self.gguf_writer.add_vision_block_count(layers)
+        self.gguf_writer.add_vision_head_count(n_head)
+        self.gguf_writer.add_vision_feed_forward_length(n_ff)
+
+        self.gguf_writer.add_vision_attention_layernorm_eps(float(cfg.get("layer_norm_eps", 1e-5)))
+
+        mean = self.preprocessor_config.get("image_mean", self.preprocessor_config.get("mean"))
+        std = self.preprocessor_config.get("image_std", self.preprocessor_config.get("std"))
+        if mean is None or std is None:
+            raise KeyError(
+                "JinaCLIPVisionModel: preprocessor_config missing image mean/std (expected keys: 'image_mean'/'image_std' or 'mean'/'std')"
+            )
+        self.gguf_writer.add_vision_image_mean(mean)
+        self.gguf_writer.add_vision_image_std(std)
+
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.JINACLIP2)
+        self.gguf_writer.add_vision_use_silu(True)
+
+    def _strip_vm_prefix(self, name: str) -> str:
+        return name[len('vision_model.'):] if name.startswith('vision_model.') else name
+
+    def _map_block_tensor(self, layer: int, rest: str, data_torch: Tensor, name: str) -> list[tuple[str, Tensor]] | None:
+        parts = rest.split('.')
+        # layer norms
+        if rest.startswith('norm1.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.ln1.{suffix}', data_torch)]
+        if rest.startswith('norm2.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.ln2.{suffix}', data_torch)]
+        if rest.startswith('attn.inner_attn_ln.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.attn_ln.{suffix}', data_torch)]
+
+        # fused qkv
+        if rest == 'attn.qkv.weight':
+            w = data_torch
+            wdim = w.shape[0]
+            if wdim % 3 != 0:
+                logger.warning('mmproj(jinaclip): unexpected qkv weight shape %s for %s', tuple(w.shape), name)
+            d = wdim // 3
+            q, k, v = w[0:d, :], w[d:2 * d, :], w[2 * d:, :]
+            return [
+                (f'v.blk.{layer}.attn_q.weight', q),
+                (f'v.blk.{layer}.attn_k.weight', k),
+                (f'v.blk.{layer}.attn_v.weight', v),
+            ]
+        if rest == 'attn.qkv.bias':
+            b = data_torch
+            bdim = b.shape[0]
+            if bdim % 3 != 0:
+                logger.warning('mmproj(jinaclip): unexpected qkv bias shape %s for %s', tuple(b.shape), name)
+            d = bdim // 3
+            qb, kb, vb = b[0:d], b[d:2 * d], b[2 * d:]
+            return [
+                (f'v.blk.{layer}.attn_q.bias', qb),
+                (f'v.blk.{layer}.attn_k.bias', kb),
+                (f'v.blk.{layer}.attn_v.bias', vb),
+            ]
+        # separate q/v bias (some checkpoints)
+        if rest == 'attn.q_bias':
+            return [(f'v.blk.{layer}.attn_q.bias', data_torch)]
+        if rest == 'attn.v_bias':
+            return [(f'v.blk.{layer}.attn_v.bias', data_torch)]
+
+        # separate projections
+        if rest.startswith('attn.q_proj.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.attn_q.{suffix}', data_torch)]
+        if rest.startswith('attn.k_proj.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.attn_k.{suffix}', data_torch)]
+        if rest.startswith('attn.v_proj.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.attn_v.{suffix}', data_torch)]
+        if rest.startswith('attn.proj.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.attn_out.{suffix}', data_torch)]
+
+        # MLP
+        if rest.startswith('mlp.w1.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.ffn_gate.{suffix}', data_torch)]
+        if rest.startswith('mlp.w2.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.ffn_up.{suffix}', data_torch)]
+        if rest.startswith('mlp.w3.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.ffn_down.{suffix}', data_torch)]
+        if rest.startswith('mlp.ffn_ln.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.ffn_norm.{suffix}', data_torch)]
+        if rest.startswith('mlp.fc1.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.ffn_up.{suffix}', data_torch)]
+        if rest.startswith('mlp.fc2.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.ffn_down.{suffix}', data_torch)]
+        return None
+
+    def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias")) -> str:
+        """Prefer base table-driven mapping; keep Jina-specific targets if already mapped; fallback to legacy mapper."""
+        # Already a GGUF target name (e.g., "v.*" or "mm.*"): return as-is
+        if name.startswith('v.') or name.startswith('mm.'):
+            return name
+        # Try the base mapping first
+        try:
+            return super().map_tensor_name(name, try_suffixes=try_suffixes)
+        except Exception:
+            # Fallback to legacy Jina-specific mapper for any remaining edge keys
+            if hasattr(self, "_map_jinaclip_tensor_name"):
+                mapped = self._map_jinaclip_tensor_name(name)  # type: ignore[attr-defined]
+                if mapped:
+                    return mapped
+            return name
+
+    def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
+        yielded_any = False
+        try:
+            for name, tensor in super().get_tensors():
+                yielded_any = True
+                yield name, tensor
+        except Exception as e:
+            logger.warning("mmproj(jinaclip): base get_tensors failed, falling back: %s", e)
+        if yielded_any:
+            return
+
+        candidates = [
+            self.dir_model / "pytorch_model.bin",
+            self.dir_model / "vision_model_weights.bin",
+        ]
+        model_path = next((p for p in candidates if p.exists()), None)
+        if model_path is None:
+            raise FileNotFoundError(f"mmproj(jinaclip): no model weights found in {self.dir_model}")
+        try:
+            state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
+        except TypeError:
+            state_dict = torch.load(model_path, map_location="cpu")
+
+        for name, tensor in state_dict.items():
+            yield name, tensor
+
+    def _should_be_f32(self, gguf_name: str) -> bool:
+        patterns = (
+            ".ln1.weight", ".ln1.bias",
+            ".ln2.weight", ".ln2.bias",
+            ".attn_ln.weight", ".attn_ln.bias",
+            ".ffn_norm.weight", ".ffn_norm.bias",
+            "v.patch_embd.proj.bias",
+        )
+        return any(p in gguf_name for p in patterns)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+
+        src = name
+        if src.startswith('v.') or src.startswith('mm.'):
+            return [(src, data_torch)]
+
+        # Drop 'vision_model.' prefix if present
+        src_no_vm = self._strip_vm_prefix(src)
+
+        # Top-level direct mappings: use gguf constants directly for canonical names
+        if src_no_vm == 'cls_token':
+            base = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_CLS]
+            return [(base, data_torch)]
+        if src_no_vm.startswith('patch_embed.proj.'):
+            suffix = src_no_vm.split('.')[-1]
+            base = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH]
+            return [(f'{base}.{suffix}', data_torch)]
+        if src_no_vm == 'pos_embed':
+            pos_name = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_POS] + '.weight'
+            return [(pos_name, data_torch)]
+        if src_no_vm.startswith('norm.'):
+            suffix = src_no_vm.split('.')[-1]
+            base = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_POST_NORM]
+            return [(f'{base}.{suffix}', data_torch)]
+
+        if src_no_vm.startswith('blocks.'):
+            parts = src_no_vm.split('.')
+            if len(parts) >= 3 and parts[1].isdigit():
+                layer = int(parts[1])
+                rest = '.'.join(parts[2:])
+                mapped = self._map_block_tensor(layer, rest, data_torch, name)
+                if mapped is not None:
+                    return mapped
+
+        try:
+            return [(self.map_tensor_name(name), data_torch)]
+        except Exception:
+            logger.debug("mmproj(jinaclip): skip unmapped tensor %s", name)
+            return []
+
+
 @ModelBase.register("OpenELMForCausalLM")
 class OpenELMModel(TextModel):
     model_arch = gguf.MODEL_ARCH.OPENELM
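
The _map_block_tensor helper above splits a fused attn.qkv tensor into equal Q/K/V slices along the first dimension and emits them under the v.blk.{N}.attn_{q,k,v} names. A small standalone check of that slicing (toy shapes, for illustration only):

import torch

width = 8                                   # toy embedding width
w_qkv = torch.arange(3 * width * width, dtype=torch.float32).reshape(3 * width, width)

d = w_qkv.shape[0] // 3                     # per-projection output dim
q, k, v = w_qkv[0:d, :], w_qkv[d:2 * d, :], w_qkv[2 * d:, :]

# re-concatenating the slices must reproduce the fused tensor
assert torch.equal(torch.cat([q, k, v], dim=0), w_qkv)
print(q.shape, k.shape, v.shape)            # torch.Size([8, 8]) for each

The same indexing is applied to the fused attn.qkv.bias vector, so each attention projection ends up with a matching weight/bias pair.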

gguf-py/gguf/constants.py

Lines changed: 1 addition & 0 deletions

@@ -3159,6 +3159,7 @@ class VisionProjectorType:
     QWEN3VL = "qwen3vl_merger"
     ULTRAVOX = "ultravox"
     INTERNVL = "internvl"
+    JINACLIP2 = "jinaclip2"
     QWEN2A = "qwen2a" # audio
     QWEN25O = "qwen2.5o" # omni
     VOXTRAL = "voxtral"

tools/mtmd/clip-impl.h

Lines changed: 28 additions & 24 deletions

@@ -40,6 +40,7 @@
 #define KEY_PROJ_SCALE_FACTOR   "clip.vision.projector.scale_factor"
 #define KEY_SPATIAL_MERGE_SIZE  "clip.vision.spatial_merge_size"
 #define KEY_IS_DEEPSTACK_LAYERS "clip.vision.is_deepstack_layers"
+#define KEY_VISION_ROPE_THETA   "clip.vision.rope_theta"

 #define KEY_MM_PATCH_MERGE_TYPE  "clip.vision.mm_patch_merge_type"
 #define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints"

@@ -69,14 +70,15 @@
 #define TN_ATTN_Q          "%s.blk.%d.attn_q.%s"
 #define TN_ATTN_V          "%s.blk.%d.attn_v.%s"
 #define TN_ATTN_OUTPUT     "%s.blk.%d.attn_out.%s"
+#define TN_ATTN_LN         "%s.blk.%d.attn_ln.%s" // inner attention LayerNorm
 #define TN_ATTN_K_NORM     "%s.blk.%d.attn_k_norm.%s"
 #define TN_ATTN_Q_NORM     "%s.blk.%d.attn_q_norm.%s"
 #define TN_FFN_DOWN        "%s.blk.%d.ffn_down.%s"
 #define TN_FFN_GATE        "%s.blk.%d.ffn_gate.%s"
 #define TN_FFN_UP          "%s.blk.%d.ffn_up.%s"
-#define TN_FFN_GATE        "%s.blk.%d.ffn_gate.%s"
-#define TN_LN_1            "%s.blk.%d.ln1.%s" // layer norm
-#define TN_LN_2            "%s.blk.%d.ln2.%s" // layer norm
+#define TN_FFN_NORM        "%s.blk.%d.ffn_norm.%s"
+#define TN_LN_1            "%s.blk.%d.ln1.%s" // layer norm
+#define TN_LN_2            "%s.blk.%d.ln2.%s" // layer norm
 #define TN_LS_1            "%s.blk.%d.ls1.%s" // layer scale
 #define TN_LS_2            "%s.blk.%d.ls2.%s" // layer scale
 #define TN_LN_PRE          "%s.pre_ln.%s"

@@ -151,6 +153,7 @@ enum projector_type {
     PROJECTOR_TYPE_QWEN2A,
     PROJECTOR_TYPE_QWEN25O, // will be replaced by QWEN2A or QWEN25VL depending on clip_ctx
     PROJECTOR_TYPE_VOXTRAL,
+    PROJECTOR_TYPE_JINACLIP2, // JinaCLIP v2
     PROJECTOR_TYPE_LFM2,
     PROJECTOR_TYPE_KIMIVL,
     PROJECTOR_TYPE_LIGHTONOCR,

@@ -159,27 +162,28 @@ enum projector_type {
 };

 static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
-    { PROJECTOR_TYPE_MLP,       "mlp" },
-    { PROJECTOR_TYPE_LDP,       "ldp" },
-    { PROJECTOR_TYPE_LDPV2,     "ldpv2"},
-    { PROJECTOR_TYPE_MINICPMV,  "resampler"},
-    { PROJECTOR_TYPE_GLM_EDGE,  "adapter"},
-    { PROJECTOR_TYPE_QWEN2VL,   "qwen2vl_merger"},
-    { PROJECTOR_TYPE_QWEN25VL,  "qwen2.5vl_merger"},
-    { PROJECTOR_TYPE_QWEN3VL,   "qwen3vl_merger"},
-    { PROJECTOR_TYPE_GEMMA3,    "gemma3"},
-    { PROJECTOR_TYPE_IDEFICS3,  "idefics3"},
-    { PROJECTOR_TYPE_PIXTRAL,   "pixtral"},
-    { PROJECTOR_TYPE_ULTRAVOX,  "ultravox"},
-    { PROJECTOR_TYPE_INTERNVL,  "internvl"},
-    { PROJECTOR_TYPE_LLAMA4,    "llama4"},
-    { PROJECTOR_TYPE_QWEN2A,    "qwen2a"},
-    { PROJECTOR_TYPE_QWEN25O,   "qwen2.5o"},
-    { PROJECTOR_TYPE_VOXTRAL,   "voxtral"},
-    { PROJECTOR_TYPE_LFM2,      "lfm2"},
-    { PROJECTOR_TYPE_KIMIVL,    "kimivl"},
-    { PROJECTOR_TYPE_LIGHTONOCR,"lightonocr"},
-    { PROJECTOR_TYPE_COGVLM,    "cogvlm"},
+    { PROJECTOR_TYPE_MLP,        "mlp" },
+    { PROJECTOR_TYPE_LDP,        "ldp" },
+    { PROJECTOR_TYPE_LDPV2,      "ldpv2" },
+    { PROJECTOR_TYPE_MINICPMV,   "resampler" },
+    { PROJECTOR_TYPE_GLM_EDGE,   "adapter" },
+    { PROJECTOR_TYPE_QWEN2VL,    "qwen2vl_merger" },
+    { PROJECTOR_TYPE_QWEN25VL,   "qwen2.5vl_merger" },
+    { PROJECTOR_TYPE_QWEN3VL,    "qwen3vl_merger" },
+    { PROJECTOR_TYPE_GEMMA3,     "gemma3" },
+    { PROJECTOR_TYPE_IDEFICS3,   "idefics3" },
+    { PROJECTOR_TYPE_PIXTRAL,    "pixtral" },
+    { PROJECTOR_TYPE_ULTRAVOX,   "ultravox" },
+    { PROJECTOR_TYPE_INTERNVL,   "internvl" },
+    { PROJECTOR_TYPE_LLAMA4,     "llama4" },
+    { PROJECTOR_TYPE_QWEN2A,     "qwen2a" },
+    { PROJECTOR_TYPE_QWEN25O,    "qwen2.5o" },
+    { PROJECTOR_TYPE_VOXTRAL,    "voxtral" },
+    { PROJECTOR_TYPE_LFM2,       "lfm2" },
+    { PROJECTOR_TYPE_KIMIVL,     "kimivl" },
+    { PROJECTOR_TYPE_LIGHTONOCR, "lightonocr" },
+    { PROJECTOR_TYPE_JINACLIP2,  "jinaclip2" },
+    { PROJECTOR_TYPE_COGVLM,     "cogvlm" },
 };

 static projector_type clip_projector_type_from_string(const std::string & str) {
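
For reference, the TN_ATTN_LN template added above resolves to the same GGUF tensor names that the converter emits for the inner attention LayerNorm. A quick illustration of the expansion, using plain Python string formatting to stand in for the C++ formatting helper:

prefix, layer, suffix = "v", 0, "weight"      # vision prefix, first block
tn_attn_ln = "%s.blk.%d.attn_ln.%s"           # TN_ATTN_LN from clip-impl.h
print(tn_attn_ln % (prefix, layer, suffix))   # -> v.blk.0.attn_ln.weight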
