
Commit 5b3fa94

max-krasnyansky authored and liyang committed
cpu: introduce chunking for flash attention (ggml-org#16829)
Factor out the core FA loop into flash_atten_f16_one_chunk and add an outer loop on top that handles the chunks.
1 parent bacddc0 commit 5b3fa94
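
The ggml-cpu source that implements this change is not among the two files rendered below (only 2 of the 12 changed files are shown). As a rough, hypothetical sketch of the control flow the commit message describes, assuming threads claim chunks from a shared counter: fa_state, fa_one_chunk, fa_forward and rows_per_chunk are placeholder names, not the actual ggml symbols.

#include <algorithm>
#include <atomic>

// Hypothetical sketch of "factor out the per-chunk loop, add an outer chunk loop".
// A shared counter lets threads grab chunks dynamically instead of using a fixed
// per-thread split, so threads that finish early pick up the remaining work.
struct fa_state {
    std::atomic<int> next_chunk{0};  // shared chunk counter across threads
    int n_rows         = 0;          // total output rows (heads x batch x seq)
    int rows_per_chunk = 64;         // placeholder chunk granularity
};

// Core per-chunk loop, factored out of the previously monolithic kernel;
// it processes output rows [row0, row1) exactly as the old inner loop did.
static void fa_one_chunk(fa_state & st, int row0, int row1) {
    (void) st;
    for (int r = row0; r < row1; ++r) {
        // ... flash-attention inner loop for one output row ...
    }
}

// Outer driver run by every worker thread: claim the next chunk, process it,
// repeat until no chunks remain.
static void fa_forward(fa_state & st) {
    const int n_chunks = (st.n_rows + st.rows_per_chunk - 1) / st.rows_per_chunk;
    for (int chunk = st.next_chunk.fetch_add(1); chunk < n_chunks; chunk = st.next_chunk.fetch_add(1)) {
        const int row0 = chunk * st.rows_per_chunk;
        const int row1 = std::min(row0 + st.rows_per_chunk, st.n_rows);
        fa_one_chunk(st, row0, row1);
    }
}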

File tree

12 files changed (+1072, -101 lines)


common/arg.cpp

Lines changed: 2 additions & 2 deletions
@@ -3245,14 +3245,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params, int value) {
             params.embd_normalize = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_EMBEDDING}));
+    ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_MTMD}));
     add_opt(common_arg(
         {"--embd-output-format"}, "FORMAT",
         "empty = default, \"array\" = [[],[]...], \"json\" = openai style, \"json+\" = same \"json\" + cosine similarity matrix, \"raw\" = plain whitespace-delimited output (one embedding per line)",
         [](common_params & params, const std::string & value) {
             params.embd_out = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_EMBEDDING}));
+    ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_MTMD}));
     add_opt(common_arg(
         {"--embd-separator"}, "STRING",
         "separator of embeddings (default \\n) for example \"<#sep#>\"",

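For context on the common/arg.cpp hunk above: command-line options are registered with add_opt(common_arg(...)), and set_examples() lists which example programs expose a given flag, so the change simply adds LLAMA_EXAMPLE_MTMD to the allowed set for the two embedding-output options. A minimal sketch of that registration pattern, written as a fragment in the same style as the hunk; "--my-flag" and params.my_flag are hypothetical placeholders, not part of this commit:

    // Sketch of the add_opt/common_arg pattern; the flag and field are hypothetical.
    add_opt(common_arg(
        {"--my-flag"}, "N",
        "what the flag controls",
        [](common_params & params, int value) {
            params.my_flag = value;  // hypothetical field on common_params
        }
    ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_MTMD}));
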
convert_hf_to_gguf.py

Lines changed: 261 additions & 2 deletions
@@ -4897,7 +4897,7 @@ def _xlmroberta_set_vocab(self) -> None:
         with open(tokenizer_config_path, "r", encoding="utf-8") as fp:
             tokenizer_config_json = json.load(fp)

-        add_prefix = tokenizer.add_prefix_space
+        add_prefix = getattr(tokenizer, "add_prefix_space", False)
         remove_whitespaces = tokenizer.clean_up_tokenization_spaces
         precompiled_charsmap = b64decode(tokenizer_json["normalizer"]["precompiled_charsmap"])

@@ -5183,7 +5183,18 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,

         if lora_names := hparams.get("lora_adaptations"):
             self._lora_names = lora_names
-            self.model_arch = gguf.MODEL_ARCH.JINA_BERT_V3
+
+        try:
+            text_cfg = hparams.get("text_config", {}) if isinstance(hparams.get("text_config", {}), dict) else {}
+            pe_type = (text_cfg.get("position_embedding_type") or hparams.get("position_embedding_type") or "").lower()
+            rope_base = text_cfg.get("rotary_emb_base", hparams.get("rotary_emb_base"))
+            name_path = (hparams.get("_name_or_path") or "").lower()
+            is_vx = ("jina" in name_path and ("v2" in name_path or "v3" in name_path))
+            is_v3 = (pe_type == "rotary" or rope_base is not None) and is_vx
+            if (is_v3) or self._lora_names:
+                self.model_arch = gguf.MODEL_ARCH.JINA_BERT_V3
+        except Exception:
+            pass

         super().__init__(dir_model, ftype, fname_out, hparams=hparams, **kwargs)
         self._xlmroberta_tokenizer_init()
@@ -6405,6 +6416,254 @@ def set_vocab(self):
             raise NotImplementedError(f'Tokenizer {tokenizer_class} is not supported for JinaBertModel')


+@ModelBase.register("JinaCLIPVisionModel", "JinaCLIPModel")
+class JinaCLIPVisionModel(MmprojModel):
+    """JinaCLIP v2 Vision Encoder Model - handles vision component only"""
+    model_arch = gguf.MODEL_ARCH.MMPROJ
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # Load config for vision encoder
+        config_path = self.dir_model / "config.json"
+        if not config_path.exists():
+            raise FileNotFoundError(
+                f"JinaCLIPVisionModel: missing config.json in {self.dir_model}. "
+                "Please ensure the original model config is present; default hyperparameter fallbacks are not used."
+            )
+        with open(config_path, encoding="utf-8") as f:
+            self.vision_config = json.load(f)
+
+    def set_vocab(self):
+        # Vision encoder doesn't need vocabulary
+        pass
+
+    def set_gguf_parameters(self):
+        cfg = self.vision_config
+
+        try:
+            width = int(cfg["width"])  # channel dim
+            head_width = int(cfg["head_width"])  # per-head dim
+            layers = int(cfg["layers"])  # block count
+            image_size = int(cfg["image_size"])  # input image size
+            patch_size = int(cfg["patch_size"])  # patch size
+        except KeyError as e:
+            raise KeyError(f"JinaCLIPVisionModel: missing key in config.json: {e}")
+
+        if width % head_width != 0:
+            raise ValueError(
+                f"JinaCLIPVisionModel: width ({width}) not divisible by head_width ({head_width})"
+            )
+        n_head = width // head_width
+
+        if "mlp_ratio" in cfg:
+            n_ff = int(width * float(cfg["mlp_ratio"]))
+        elif bool(cfg.get("naive_swiglu", False)):
+            n_ff = int((width * 8) // 3)
+        else:
+            raise ValueError("JinaCLIPVisionModel: unable to infer FFN size; please provide 'mlp_ratio' or set 'naive_swiglu' in config.json")
+
+        self.gguf_writer.add_clip_has_vision_encoder(True)
+        proj_dim = int(cfg.get("projection_dim", width))
+        self.gguf_writer.add_vision_projection_dim(proj_dim)
+
+        self.gguf_writer.add_vision_image_size(image_size)
+        self.gguf_writer.add_vision_patch_size(patch_size)
+        self.gguf_writer.add_vision_embedding_length(width)
+        self.gguf_writer.add_vision_block_count(layers)
+        self.gguf_writer.add_vision_head_count(n_head)
+        self.gguf_writer.add_vision_feed_forward_length(n_ff)
+
+        self.gguf_writer.add_vision_attention_layernorm_eps(float(cfg.get("layer_norm_eps", 1e-5)))
+
+        mean = self.preprocessor_config.get("image_mean", self.preprocessor_config.get("mean"))
+        std = self.preprocessor_config.get("image_std", self.preprocessor_config.get("std"))
+        if mean is None or std is None:
+            raise KeyError(
+                "JinaCLIPVisionModel: preprocessor_config missing image mean/std (expected keys: 'image_mean'/'image_std' or 'mean'/'std')"
+            )
+        self.gguf_writer.add_vision_image_mean(mean)
+        self.gguf_writer.add_vision_image_std(std)
+
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.JINACLIP2)
+        self.gguf_writer.add_vision_use_silu(True)
+
+    def _strip_vm_prefix(self, name: str) -> str:
+        return name[len('vision_model.'):] if name.startswith('vision_model.') else name
+
+    def _map_block_tensor(self, layer: int, rest: str, data_torch: Tensor, name: str) -> list[tuple[str, Tensor]] | None:
+        parts = rest.split('.')
+        # layer norms
+        if rest.startswith('norm1.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.ln1.{suffix}', data_torch)]
+        if rest.startswith('norm2.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.ln2.{suffix}', data_torch)]
+        if rest.startswith('attn.inner_attn_ln.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.attn_ln.{suffix}', data_torch)]
+
+        # fused qkv
+        if rest == 'attn.qkv.weight':
+            w = data_torch
+            wdim = w.shape[0]
+            if wdim % 3 != 0:
+                logger.warning('mmproj(jinaclip): unexpected qkv weight shape %s for %s', tuple(w.shape), name)
+            d = wdim // 3
+            q, k, v = w[0:d, :], w[d:2 * d, :], w[2 * d:, :]
+            return [
+                (f'v.blk.{layer}.attn_q.weight', q),
+                (f'v.blk.{layer}.attn_k.weight', k),
+                (f'v.blk.{layer}.attn_v.weight', v),
+            ]
+        if rest == 'attn.qkv.bias':
+            b = data_torch
+            bdim = b.shape[0]
+            if bdim % 3 != 0:
+                logger.warning('mmproj(jinaclip): unexpected qkv bias shape %s for %s', tuple(b.shape), name)
+            d = bdim // 3
+            qb, kb, vb = b[0:d], b[d:2 * d], b[2 * d:]
+            return [
+                (f'v.blk.{layer}.attn_q.bias', qb),
+                (f'v.blk.{layer}.attn_k.bias', kb),
+                (f'v.blk.{layer}.attn_v.bias', vb),
+            ]
+        # separate q/v bias (some checkpoints)
+        if rest == 'attn.q_bias':
+            return [(f'v.blk.{layer}.attn_q.bias', data_torch)]
+        if rest == 'attn.v_bias':
+            return [(f'v.blk.{layer}.attn_v.bias', data_torch)]
+
+        # separate projections
+        if rest.startswith('attn.q_proj.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.attn_q.{suffix}', data_torch)]
+        if rest.startswith('attn.k_proj.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.attn_k.{suffix}', data_torch)]
+        if rest.startswith('attn.v_proj.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.attn_v.{suffix}', data_torch)]
+        if rest.startswith('attn.proj.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.attn_out.{suffix}', data_torch)]
+
+        # MLP
+        if rest.startswith('mlp.w1.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.ffn_gate.{suffix}', data_torch)]
+        if rest.startswith('mlp.w2.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.ffn_up.{suffix}', data_torch)]
+        if rest.startswith('mlp.w3.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.ffn_down.{suffix}', data_torch)]
+        if rest.startswith('mlp.ffn_ln.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.ffn_norm.{suffix}', data_torch)]
+        if rest.startswith('mlp.fc1.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.ffn_up.{suffix}', data_torch)]
+        if rest.startswith('mlp.fc2.'):
+            suffix = parts[-1]
+            return [(f'v.blk.{layer}.ffn_down.{suffix}', data_torch)]
+        return None
+
+    def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias")) -> str:
+        """Prefer base table-driven mapping; keep Jina-specific targets if already mapped; fallback to legacy mapper."""
+        # Already a GGUF target name (e.g., "v.*" or "mm.*"): return as-is
+        if name.startswith('v.') or name.startswith('mm.'):
+            return name
+        # Try the base mapping first
+        try:
+            return super().map_tensor_name(name, try_suffixes=try_suffixes)
+        except Exception:
+            # Fallback to legacy Jina-specific mapper for any remaining edge keys
+            if hasattr(self, "_map_jinaclip_tensor_name"):
+                mapped = self._map_jinaclip_tensor_name(name)  # type: ignore[attr-defined]
+                if mapped:
+                    return mapped
+            return name
+
+    def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
+        yielded_any = False
+        try:
+            for name, tensor in super().get_tensors():
+                yielded_any = True
+                yield name, tensor
+        except Exception as e:
+            logger.warning("mmproj(jinaclip): base get_tensors failed, falling back: %s", e)
+        if yielded_any:
+            return
+
+        candidates = [
+            self.dir_model / "pytorch_model.bin",
+            self.dir_model / "vision_model_weights.bin",
+        ]
+        model_path = next((p for p in candidates if p.exists()), None)
+        if model_path is None:
+            raise FileNotFoundError(f"mmproj(jinaclip): no model weights found in {self.dir_model}")
+        try:
+            state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
+        except TypeError:
+            state_dict = torch.load(model_path, map_location="cpu")
+
+        for name, tensor in state_dict.items():
+            yield name, tensor
+
+    def _should_be_f32(self, gguf_name: str) -> bool:
+        patterns = (
+            ".ln1.weight", ".ln1.bias",
+            ".ln2.weight", ".ln2.bias",
+            ".attn_ln.weight", ".attn_ln.bias",
+            ".ffn_norm.weight", ".ffn_norm.bias",
+            "v.patch_embd.proj.bias",
+        )
+        return any(p in gguf_name for p in patterns)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+
+        src = name
+        if src.startswith('v.') or src.startswith('mm.'):
+            return [(src, data_torch)]
+
+        # Drop 'vision_model.' prefix if present
+        src_no_vm = self._strip_vm_prefix(src)
+
+        # Top-level direct mappings — use gguf constants directly for canonical names
+        if src_no_vm == 'cls_token':
+            base = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_CLS]
+            return [(base, data_torch)]
+        if src_no_vm.startswith('patch_embed.proj.'):
+            suffix = src_no_vm.split('.')[-1]
+            base = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_PATCH]
+            return [(f'{base}.{suffix}', data_torch)]
+        if src_no_vm == 'pos_embed':
+            pos_name = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_ENC_EMBD_POS] + '.weight'
+            return [(pos_name, data_torch)]
+        if src_no_vm.startswith('norm.'):
+            suffix = src_no_vm.split('.')[-1]
+            base = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_POST_NORM]
+            return [(f'{base}.{suffix}', data_torch)]
+
+        if src_no_vm.startswith('blocks.'):
+            parts = src_no_vm.split('.')
+            if len(parts) >= 3 and parts[1].isdigit():
+                layer = int(parts[1])
+                rest = '.'.join(parts[2:])
+                mapped = self._map_block_tensor(layer, rest, data_torch, name)
+                if mapped is not None:
+                    return mapped
+
+        try:
+            return [(self.map_tensor_name(name), data_torch)]
+        except Exception:
+            logger.debug("mmproj(jinaclip): skip unmapped tensor %s", name)
+            return []
+
+
 @ModelBase.register("OpenELMForCausalLM")
 class OpenELMModel(TextModel):
     model_arch = gguf.MODEL_ARCH.OPENELM
