
Commit b024b81

Merge branch 'ggml-org:master' into mradermacher
2 parents: eef5963 + 5a63980

24 files changed: +562 -241 lines

README.md

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others)
 ## Hot topics

 - **GGML developer experience survey (organized and reviewed by NVIDIA):** [link](https://forms.gle/Gasw3cRgyhNEnrwK9)
-- A new binary `llama-mtmd-cli` is introduced to replace `llava-cli`, `minicpmv-cli` and `gemma3-cli` https://github.com/ggml-org/llama.cpp/pull/13012, `libllava` will be deprecated
+- A new binary `llama-mtmd-cli` is introduced to replace `llava-cli`, `minicpmv-cli`, `gemma3-cli` ([#13012](https://github.com/ggml-org/llama.cpp/pull/13012)) and `qwen2vl-cli` ([#13141](https://github.com/ggml-org/llama.cpp/pull/13141)); `libllava` will be deprecated
 - VS Code extension for FIM completions: https://github.com/ggml-org/llama.vscode
 - Universal [tool call support](./docs/function-calling.md) in `llama-server` https://github.com/ggml-org/llama.cpp/pull/9639
 - Vim/Neovim plugin for FIM completions: https://github.com/ggml-org/llama.vim

convert_hf_to_gguf.py

Lines changed: 135 additions & 92 deletions
@@ -78,7 +78,7 @@ class ModelBase:
     # subclasses should define this!
     model_arch: gguf.MODEL_ARCH

-    def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, is_big_endian: bool = False,
+    def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, *, is_big_endian: bool = False,
                  use_temp_file: bool = False, eager: bool = False,
                  metadata_override: Path | None = None, model_name: str | None = None,
                  split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False,
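The bare `*` added to `ModelBase.__init__` above makes every parameter after `fname_out` keyword-only. A minimal sketch of the effect, using a hypothetical stand-in class rather than the real `ModelBase`:

    from pathlib import Path

    # Hypothetical stand-in for ModelBase; only the keyword-only boundary matters here.
    class Demo:
        def __init__(self, dir_model: Path, ftype: int, fname_out: Path, *, is_big_endian: bool = False):
            self.is_big_endian = is_big_endian

    Demo(Path("in"), 0, Path("out.gguf"), is_big_endian=True)  # OK: flag passed by keyword
    Demo(Path("in"), 0, Path("out.gguf"), True)                # TypeError: positional flag rejected

Forcing the boolean and tuning flags to be named guards against silent argument-order mistakes as the signature grows.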
@@ -454,13 +454,6 @@ def from_model_architecture(cls, arch: str, model_type = ModelType.TEXT) -> type


 class TextModel(ModelBase):
-    @classmethod
-    def __init_subclass__(cls):
-        # can't use an abstract property, because overriding it without type errors
-        # would require using decorated functions instead of simply defining the property
-        if "model_arch" not in cls.__dict__:
-            raise TypeError(f"Missing property 'model_arch' for {cls.__name__!r}")
-
     def set_vocab(self):
         self._set_vocab_gpt2()
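The removed `__init_subclass__` guard ran at class-creation time, which presumably conflicts with `NomicBertModel` below now choosing `model_arch` per instance inside `__init__`. A minimal sketch of that conflict (hypothetical `Base`/`Dynamic` names, not the real classes):

    # A class-creation-time check like the removed one rejects any subclass that
    # wants to pick model_arch dynamically at construction time.
    class Base:
        def __init_subclass__(cls):
            if "model_arch" not in cls.__dict__:
                raise TypeError(f"Missing property 'model_arch' for {cls.__name__!r}")

    try:
        class Dynamic(Base):      # no class-level model_arch ...
            def __init__(self):
                self.model_arch = "chosen per instance"
    except TypeError as e:
        print(e)                  # ... so the check fires before __init__ can ever run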

@@ -3373,14 +3366,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter

         return [(self.map_tensor_name(name), data_torch)]

-
-@ModelBase.register("RobertaModel")
-class RobertaModel(BertModel):
-    model_arch = gguf.MODEL_ARCH.BERT
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
+    def _xlmroberta_tokenizer_init(self) -> None:
         # we need the pad_token_id to know how to chop down position_embd matrix
         if (pad_token_id := self.hparams.get("pad_token_id")) is not None:
             self._position_offset = 1 + pad_token_id
@@ -3389,82 +3375,7 @@ def __init__(self, *args, **kwargs):
         else:
             self._position_offset = None

-    def set_vocab(self):
-        """Support BPE tokenizers for roberta models"""
-        bpe_tok_path = self.dir_model / "tokenizer.json"
-        if bpe_tok_path.exists():
-            self._set_vocab_gpt2()
-            self.gguf_writer.add_add_bos_token(True)
-            self.gguf_writer.add_add_eos_token(True)
-
-            # we need this to validate the size of the token_type embeddings
-            # though currently we are passing all zeros to the token_type embeddings
-            # "Sequence A" or "Sequence B"
-            self.gguf_writer.add_token_type_count(self.hparams.get("type_vocab_size", 1))
-
-        else:
-            return super().set_vocab()
-
-    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
-        # if name starts with "roberta.", remove the prefix
-        # e.g. https://huggingface.co/BAAI/bge-reranker-v2-m3/tree/main
-        if name.startswith("roberta."):
-            name = name[8:]
-
-        # position embeddings start at pad_token_id + 1, so just chop down the weight tensor
-        if name == "embeddings.position_embeddings.weight":
-            if self._position_offset is not None:
-                data_torch = data_torch[self._position_offset:,:]
-
-        return super().modify_tensors(data_torch, name, bid)
-
-
-@ModelBase.register("NomicBertModel")
-class NomicBertModel(BertModel):
-    model_arch = gguf.MODEL_ARCH.NOMIC_BERT
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-        # the HF config claims n_ctx=8192, but it uses RoPE scaling
-        self.hparams["n_ctx"] = 2048
-
-        # SwigLU activation
-        assert self.hparams["activation_function"] == "swiglu"
-        # this doesn't do anything in the HF version
-        assert self.hparams["causal"] is False
-        # no bias tensors
-        assert self.hparams["qkv_proj_bias"] is False
-        assert self.hparams["mlp_fc1_bias"] is False
-        assert self.hparams["mlp_fc2_bias"] is False
-        # norm at end of layer
-        assert self.hparams["prenorm"] is False
-        # standard RoPE
-        assert self.hparams["rotary_emb_fraction"] == 1.0
-        assert self.hparams["rotary_emb_interleaved"] is False
-        assert self.hparams["rotary_emb_scale_base"] is None
-
-    def set_gguf_parameters(self):
-        super().set_gguf_parameters()
-        self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"])
-
-
-@ModelBase.register("XLMRobertaModel", "XLMRobertaForSequenceClassification")
-class XLMRobertaModel(BertModel):
-    model_arch = gguf.MODEL_ARCH.BERT
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-
-        # we need the pad_token_id to know how to chop down position_embd matrix
-        if (pad_token_id := self.hparams.get("pad_token_id")) is not None:
-            self._position_offset = 1 + pad_token_id
-            if "max_position_embeddings" in self.hparams:
-                self.hparams["max_position_embeddings"] -= self._position_offset
-        else:
-            self._position_offset = None
-
-    def set_vocab(self):
+    def _xlmroberta_set_vocab(self) -> None:
         # to avoid TypeError: Descriptors cannot be created directly
         # exception when importing sentencepiece_model_pb2
         os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
@@ -3546,6 +3457,138 @@ def set_vocab(self):
         self.gguf_writer.add_add_bos_token(True)
         self.gguf_writer.add_add_eos_token(True)

+
+@ModelBase.register("RobertaModel")
+class RobertaModel(BertModel):
+    model_arch = gguf.MODEL_ARCH.BERT
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # we need the pad_token_id to know how to chop down position_embd matrix
+        if (pad_token_id := self.hparams.get("pad_token_id")) is not None:
+            self._position_offset = 1 + pad_token_id
+            if "max_position_embeddings" in self.hparams:
+                self.hparams["max_position_embeddings"] -= self._position_offset
+        else:
+            self._position_offset = None
+
+    def set_vocab(self):
+        """Support BPE tokenizers for roberta models"""
+        bpe_tok_path = self.dir_model / "tokenizer.json"
+        if bpe_tok_path.exists():
+            self._set_vocab_gpt2()
+            self.gguf_writer.add_add_bos_token(True)
+            self.gguf_writer.add_add_eos_token(True)
+
+            # we need this to validate the size of the token_type embeddings
+            # though currently we are passing all zeros to the token_type embeddings
+            # "Sequence A" or "Sequence B"
+            self.gguf_writer.add_token_type_count(self.hparams.get("type_vocab_size", 1))
+
+        else:
+            return super().set_vocab()
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # if name starts with "roberta.", remove the prefix
+        # e.g. https://huggingface.co/BAAI/bge-reranker-v2-m3/tree/main
+        if name.startswith("roberta."):
+            name = name[8:]
+
+        # position embeddings start at pad_token_id + 1, so just chop down the weight tensor
+        if name == "embeddings.position_embeddings.weight":
+            if self._position_offset is not None:
+                data_torch = data_torch[self._position_offset:,:]
+
+        return super().modify_tensors(data_torch, name, bid)
+
+
+@ModelBase.register("NomicBertModel")
+class NomicBertModel(BertModel):
+    def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path, **kwargs: Any):
+        hparams = kwargs.pop("hparams", None)
+        if hparams is None:
+            hparams = ModelBase.load_hparams(dir_model)
+
+        self.is_moe = bool(hparams.get("moe_every_n_layers"))
+        self.model_arch = gguf.MODEL_ARCH.NOMIC_BERT_MOE if self.is_moe else gguf.MODEL_ARCH.NOMIC_BERT
+
+        super().__init__(dir_model, ftype, fname_out, hparams=hparams, **kwargs)
+
+        self._tokenizer_is_xlmroberta = self._is_tokenizer_xlmroberta()
+        if self._tokenizer_is_xlmroberta:
+            self._xlmroberta_tokenizer_init()
+
+        # the HF config claims n_ctx=8192, but it uses RoPE scaling
+        self.hparams["n_ctx"] = 2048
+
+        assert self.hparams["activation_function"] == ("gelu" if self.is_moe else "swiglu")
+
+        # this doesn't do anything in the HF version
+        assert self.hparams["causal"] is False
+        # no bias tensors unless MoE
+        assert self.hparams["qkv_proj_bias"] == self.is_moe
+        assert self.hparams["mlp_fc1_bias"] == self.is_moe
+        assert self.hparams["mlp_fc2_bias"] == self.is_moe
+
+        # norm at end of layer
+        assert self.hparams["prenorm"] is False
+        # standard RoPE
+        assert self.hparams["rotary_emb_fraction"] == 1.0
+        assert self.hparams["rotary_emb_interleaved"] is False
+        assert self.hparams["rotary_emb_scale_base"] is None
+
+    def set_vocab(self) -> None:
+        if self._tokenizer_is_xlmroberta:
+            return self._xlmroberta_set_vocab()
+        return super().set_vocab()
+
+    def modify_tensors(self, data_torch: torch.Tensor, name: str, bid: int | None) -> Iterable[tuple[str, torch.Tensor]]:
+        # skip the experts' bias tensors by returning an empty list
+        if "mlp.experts.bias" in name:
+            return []
+
+        if "mlp.experts.mlp.w1" in name:
+            data_torch = data_torch.view(self.hparams["num_experts"], self.hparams["n_inner"], self.hparams["n_embd"])
+            name += ".weight"
+
+        if "mlp.experts.mlp.w2" in name:
+            data_torch = data_torch.view(self.hparams["num_experts"], self.hparams["n_inner"], self.hparams["n_embd"])
+            data_torch = data_torch.transpose(1, 2)
+            name += ".weight"
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"])
+        if self.is_moe:
+            self.gguf_writer.add_moe_every_n_layers(self.hparams["moe_every_n_layers"])
+            self.gguf_writer.add_expert_count(self.hparams["num_experts"])
+            self.gguf_writer.add_expert_used_count(self.hparams["moe_top_k"])
+
+    def _is_tokenizer_xlmroberta(self) -> bool:
+        with open(self.dir_model / "tokenizer.json") as f:
+            tokenizer_json = json.load(f)
+        toktyp = tokenizer_json["model"]["type"]
+        if toktyp == "Unigram":
+            return True
+        if toktyp == "WordPiece":
+            return False
+        raise ValueError(f"unknown tokenizer: {toktyp}")
+
+
+@ModelBase.register("XLMRobertaModel", "XLMRobertaForSequenceClassification")
+class XLMRobertaModel(BertModel):
+    model_arch = gguf.MODEL_ARCH.BERT
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._xlmroberta_tokenizer_init()
+
+    def set_vocab(self):
+        self._xlmroberta_set_vocab()
+
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         # if name starts with "roberta.", remove the prefix
         # e.g. https://huggingface.co/BAAI/bge-reranker-v2-m3/tree/main
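In `NomicBertModel.modify_tensors` above, the flat expert tensors are reshaped to 3-D and `w2` is additionally transposed. A small shape-check sketch; the sizes (`num_experts=8`, `n_inner=3072`, `n_embd=768`) are illustrative assumptions, not values from a particular checkpoint:

    import torch

    num_experts, n_inner, n_embd = 8, 3072, 768  # assumed sizes, for illustration only

    flat = torch.empty(num_experts * n_inner * n_embd)

    w1 = flat.view(num_experts, n_inner, n_embd)                  # (8, 3072, 768)
    w2 = flat.view(num_experts, n_inner, n_embd).transpose(1, 2)  # (8, 768, 3072)

    print(w1.shape, w2.shape)

The transpose presumably puts the per-expert down-projection in the row/column order the mapped GGUF tensor expects, while `w1` only needs the per-expert split.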

examples/llama-bench/llama-bench.cpp

Lines changed: 0 additions & 2 deletions
@@ -1133,8 +1133,6 @@ struct test {
         "build_commit", "build_number", "cpu_info", "gpu_info", "backends", "model_filename",
         "model_type", "model_size", "model_n_params", "n_batch", "n_ubatch", "n_threads",
         "cpu_mask", "cpu_strict", "poll", "type_k", "type_v", "n_gpu_layers",
-        "split_mode", "main_gpu", "no_kv_offload", "flash_attn", "tensor_split", "use_mmap",
-        "embeddings", "n_prompt", "n_gen", "n_depth", "test_time", "avg_ns",
         "split_mode", "main_gpu", "no_kv_offload", "flash_attn", "tensor_split", "tensor_buft_overrides",
         "use_mmap", "embeddings", "n_prompt", "n_gen", "n_depth", "test_time",
         "avg_ns", "stddev_ns", "avg_ts", "stddev_ts",
examples/llava/CMakeLists.txt

Lines changed: 1 addition & 7 deletions
@@ -64,13 +64,7 @@ endif()
 add_executable(llama-llava-cli deprecation-warning.cpp)
 add_executable(llama-gemma3-cli deprecation-warning.cpp)
 add_executable(llama-minicpmv-cli deprecation-warning.cpp)
-
-set(TARGET llama-qwen2vl-cli)
-add_executable(${TARGET} qwen2vl-cli.cpp)
-set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama-qwen2vl-cli)
-install(TARGETS ${TARGET} RUNTIME)
-target_link_libraries(${TARGET} PRIVATE common llava ${CMAKE_THREAD_LIBS_INIT})
-target_compile_features(${TARGET} PRIVATE cxx_std_17)
+add_executable(llama-qwen2vl-cli deprecation-warning.cpp)

 set(TARGET llama-mtmd-cli)
 add_executable(${TARGET} mtmd-cli.cpp)

examples/llava/clip.cpp

Lines changed: 30 additions & 4 deletions
@@ -2825,15 +2825,18 @@ void clip_free(clip_ctx * ctx) {
     delete ctx;
 }

+// deprecated
 size_t clip_embd_nbytes(const struct clip_ctx * ctx) {
-    return clip_n_patches(ctx) * clip_n_mmproj_embd(ctx) * sizeof(float);
+    const int32_t nx = ctx->vision_model.hparams.image_size;
+    const int32_t ny = ctx->vision_model.hparams.image_size;
+    return clip_embd_nbytes_by_img(ctx, nx, ny);
 }

-size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_h, int img_w) {
+size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_w, int img_h) {
     clip_image_f32 img;
     img.nx = img_w;
     img.ny = img_h;
-    return clip_n_patches_by_img(ctx, &img) * clip_n_mmproj_embd(ctx) * sizeof(float);
+    return clip_n_output_tokens(ctx, &img) * clip_n_mmproj_embd(ctx) * sizeof(float);
 }

 int32_t clip_get_image_size(const struct clip_ctx * ctx) {

@@ -2863,14 +2866,37 @@ size_t get_clip_image_grid_size(const struct clip_ctx * ctx) {
     return ctx->vision_model.hparams.image_grid_pinpoints.size();
 }

+// deprecated
 int clip_n_patches(const struct clip_ctx * ctx) {
     clip_image_f32 img;
     img.nx = ctx->vision_model.hparams.image_size;
     img.ny = ctx->vision_model.hparams.image_size;
-    return clip_n_patches_by_img(ctx, &img);
+    return clip_n_output_tokens(ctx, &img);
 }

+// deprecated
 int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
+    return clip_n_output_tokens(ctx, img);
+}
+
+int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
+    const auto & params = ctx->vision_model.hparams;
+    const int n_total = clip_n_output_tokens(ctx, img);
+    if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) {
+        return img->nx / (params.patch_size * 2) + (int)(img->nx % params.patch_size > 0);
+    }
+    return n_total;
+}
+
+int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
+    const auto & params = ctx->vision_model.hparams;
+    if (ctx->proj_type == PROJECTOR_TYPE_QWEN2VL || ctx->proj_type == PROJECTOR_TYPE_QWEN25VL) {
+        return img->ny / (params.patch_size * 2) + (int)(img->ny % params.patch_size > 0);
+    }
+    return 1;
+}
+
+int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * img) {
     const auto & params = ctx->vision_model.hparams;

     int n_patches = (params.image_size / params.patch_size) * (params.image_size / params.patch_size);
examples/llava/clip.h

Lines changed: 15 additions & 4 deletions
@@ -47,7 +47,7 @@ CLIP_API struct clip_ctx * clip_init(const char * fname, struct clip_context_par
 CLIP_API void clip_free(struct clip_ctx * ctx);

 CLIP_API size_t clip_embd_nbytes(const struct clip_ctx * ctx);
-CLIP_API size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_h, int img_w);
+CLIP_API size_t clip_embd_nbytes_by_img(const struct clip_ctx * ctx, int img_w, int img_h);

 CLIP_API int32_t clip_get_image_size (const struct clip_ctx * ctx);
 CLIP_API int32_t clip_get_patch_size (const struct clip_ctx * ctx);

@@ -59,9 +59,20 @@ CLIP_API const char * clip_patch_merge_type(const struct clip_ctx * ctx);
 CLIP_API const int32_t * clip_image_grid(const struct clip_ctx * ctx);
 CLIP_API size_t get_clip_image_grid_size(const struct clip_ctx * ctx);

-CLIP_API int clip_n_patches        (const struct clip_ctx * ctx);
-CLIP_API int clip_n_patches_by_img (const struct clip_ctx * ctx, struct clip_image_f32 * img);
-CLIP_API int clip_n_mmproj_embd    (const struct clip_ctx * ctx);
+GGML_DEPRECATED(CLIP_API int clip_n_patches(const struct clip_ctx * ctx),
+    "use clip_n_output_tokens instead");
+GGML_DEPRECATED(CLIP_API int clip_n_patches_by_img(const struct clip_ctx * ctx, struct clip_image_f32 * img),
+    "use clip_n_output_tokens instead");
+
+CLIP_API int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * img);
+
+// for M-RoPE, this will be the number of token positions in X and Y directions
+// for other models, X will be the total number of tokens and Y will be 1
+CLIP_API int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 * img);
+CLIP_API int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 * img);
+
+// this should be equal to the embedding dimension of the text model
+CLIP_API int clip_n_mmproj_embd(const struct clip_ctx * ctx);

 CLIP_API int clip_uhd_num_image_embeds_col(struct clip_ctx * ctx_clip);
 CLIP_API void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size);
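Combined with `clip_n_mmproj_embd`, the token count fixes the embedding buffer size that `clip_embd_nbytes_by_img` reports: tokens times embedding dimension times `sizeof(float)`. A back-of-the-envelope sketch continuing the 16x12 grid above; the 4096-dim projection is an assumed value, not taken from a specific model:

    n_tokens = 16 * 12     # clip_n_output_tokens_x * clip_n_output_tokens_y (assumed grid)
    n_mmproj_embd = 4096   # assumed embedding dimension of the text model
    sizeof_float = 4

    nbytes = n_tokens * n_mmproj_embd * sizeof_float
    print(nbytes)          # 3145728 bytes = 3 MiB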
