
Commit 7b1524b

Merge pull request #65 from menloresearch/update-dev-from-master-2025-04-23-00-08
Sync master with upstream release b5170
2 parents 13bc962 + 658987c commit 7b1524b

33 files changed, with 1400 additions and 1016 deletions

SECURITY.md

Lines changed: 2 additions & 1 deletion
@@ -40,7 +40,8 @@ To protect sensitive data from potential leaks or unauthorized access, it is cru
 ### Untrusted environments or networks
 
 If you can't run your models in a secure and isolated environment or if it must be exposed to an untrusted network, make sure to take the following security precautions:
-* Confirm the hash of any downloaded artifact (e.g. pre-trained model weights) matches a known-good value
+* Do not use the RPC backend, [rpc-server](https://github.com/ggml-org/llama.cpp/tree/master/examples/rpc) and [llama-server](https://github.com/ggml-org/llama.cpp/tree/master/examples/server) functionality (see https://github.com/ggml-org/llama.cpp/pull/13061).
+* Confirm the hash of any downloaded artifact (e.g. pre-trained model weights) matches a known-good value.
 * Encrypt your data if sending it over the network.
 
 ### Multi-Tenant environments
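
As a concrete companion to the hash-check bullet above, here is a minimal sketch of verifying a downloaded artifact against a known-good digest (the file name and digest below are placeholders, not real values):

```python
import hashlib

# Hypothetical artifact and a placeholder for the provider's published digest.
MODEL_PATH = "MobileVLM-1.7B/ggml-model-q4_k.gguf"
EXPECTED_SHA256 = "0123456789abcdef..."  # replace with the known-good value

def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    """Stream the file in chunks so large model files don't need to fit in memory."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

if sha256_of(MODEL_PATH) != EXPECTED_SHA256:
    raise SystemExit("hash mismatch: do not load this model")
```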

common/arg.cpp

Lines changed: 1 addition & 2 deletions
@@ -976,14 +976,13 @@ static void common_params_print_completion(common_params_context & ctx_arg) {
         "llama-gritlm",
         "llama-imatrix",
         "llama-infill",
-        "llama-llava-cli",
+        "llama-mtmd-cli",
         "llama-llava-clip-quantize-cli",
         "llama-lookahead",
         "llama-lookup",
         "llama-lookup-create",
         "llama-lookup-merge",
         "llama-lookup-stats",
-        "llama-minicpmv-cli",
         "llama-parallel",
         "llama-passkey",
         "llama-perplexity",

convert_hf_to_gguf.py

Lines changed: 81 additions & 19 deletions
@@ -419,8 +419,12 @@ def get_model_part_names(dir_model: Path, prefix: str, suffix: str) -> list[str]
     def load_hparams(dir_model: Path):
         with open(dir_model / "config.json", "r", encoding="utf-8") as f:
             hparams = json.load(f)
+            architectures = hparams.get("architectures")
             if "text_config" in hparams:
                 hparams = {**hparams, **hparams["text_config"]}
+            if architectures is not None:
+                # preserve "architectures" from root level config
+                hparams["architectures"] = architectures
             return hparams
 
     @classmethod
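
The effect of this `load_hparams` change is easier to see in isolation; the standalone sketch below uses made-up config values to show how the root-level `architectures` survives the `text_config` merge:

```python
# Hypothetical multimodal config: a nested text_config plus a root-level "architectures".
hparams = {
    "architectures": ["SmolVLMForConditionalGeneration"],
    "text_config": {"hidden_size": 2048, "architectures": ["LlamaForCausalLM"]},
}

architectures = hparams.get("architectures")
if "text_config" in hparams:
    # text_config keys are promoted to the top level and may overwrite root keys
    hparams = {**hparams, **hparams["text_config"]}
if architectures is not None:
    # the root-level "architectures" is restored so model registration still sees it
    hparams["architectures"] = architectures

print(hparams["hidden_size"])       # 2048
print(hparams["architectures"][0])  # SmolVLMForConditionalGeneration
```
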
@@ -1061,6 +1065,8 @@ def _set_vocab_builtin(self, model_name: Literal["gpt-neox", "llama-spm"], vocab
 class VisionModel(ModelBase):
     model_arch = gguf.MODEL_ARCH.CLIP_VISION
     n_text_embd = 0
+    preprocessor_config: dict[str, Any]
+    global_config: dict[str, Any]
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -1075,24 +1081,33 @@ def __init__(self, *args, **kwargs):
 
         if "vision_config" not in self.hparams:
             raise ValueError("vision_config not found in hparams")
-        # move vision config to the top level
+        # move vision config to the top level, while preserving the original hparams in global_config
+        self.global_config = self.hparams
         self.hparams = self.hparams["vision_config"]
 
+        # load preprocessor config
+        with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f:
+            self.preprocessor_config = json.load(f)
+
     def set_type(self):
         self.gguf_writer.add_type(gguf.GGUFType.CLIP_VISION)
 
     def set_gguf_parameters(self):
         self.gguf_writer.add_file_type(self.ftype)
-        self.gguf_writer.add_uint32(gguf.Keys.ClipVision.PROJECTION_DIM, self.n_embd_text)
-        self.gguf_writer.add_bool(gguf.Keys.ClipVision.HAS_VISION_ENCODER, True)
+        self.gguf_writer.add_vision_projection_dim(self.n_embd_text)
+        self.gguf_writer.add_vision_has_vision_encoder(True)
 
         # vision config
-        self.gguf_writer.add_uint32(gguf.Keys.ClipVision.IMAGE_SIZE, self.find_hparam(["image_size"]))
-        self.gguf_writer.add_uint32(gguf.Keys.ClipVision.PATCH_SIZE, self.find_hparam(["patch_size"]))
-        self.gguf_writer.add_uint32(gguf.Keys.ClipVision.EMBEDDING_LENGTH, self.find_hparam(["hidden_size"]))
-        self.gguf_writer.add_uint32(gguf.Keys.ClipVision.FEED_FORWARD_LENGTH, self.find_hparam(["intermediate_size"]))
-        self.gguf_writer.add_uint32(gguf.Keys.ClipVision.BLOCK_COUNT, self.find_hparam(["num_hidden_layers"]))
-        self.gguf_writer.add_uint32(gguf.Keys.ClipVision.Attention.HEAD_COUNT, self.find_hparam(["num_attention_heads"]))
+        self.gguf_writer.add_vision_image_size(self.find_hparam(["image_size"]))
+        self.gguf_writer.add_vision_patch_size(self.find_hparam(["patch_size"]))
+        self.gguf_writer.add_vision_embedding_length(self.find_hparam(["hidden_size"]))
+        self.gguf_writer.add_vision_feed_forward_length(self.find_hparam(["intermediate_size"]))
+        self.gguf_writer.add_vision_block_count(self.find_hparam(["num_hidden_layers"]))
+        self.gguf_writer.add_vision_head_count(self.find_hparam(["num_attention_heads"]))
+
+        # preprocessor config
+        self.gguf_writer.add_vision_image_mean(self.preprocessor_config["image_mean"])
+        self.gguf_writer.add_vision_image_std(self.preprocessor_config["image_std"])
 
     def write_vocab(self):
         raise ValueError("VisionModel does not support vocab writing")
@@ -1703,11 +1718,23 @@ def prepare_tensors(self):
                 raise ValueError(f"Unprocessed norms: {norms}")
 
 
-@ModelBase.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM")
+@ModelBase.register(
+    "LLaMAForCausalLM",
+    "LlamaForCausalLM",
+    "MistralForCausalLM",
+    "MixtralForCausalLM",
+    "Idefics3ForConditionalGeneration",
+    "SmolVLMForConditionalGeneration")
 class LlamaModel(TextModel):
     model_arch = gguf.MODEL_ARCH.LLAMA
     undo_permute = True
 
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # fix for SmolVLM2, missing `num_attention_heads` in config.json
+        if self.hparams["architectures"][0] == "SmolVLMForConditionalGeneration":
+            self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
+
     def set_vocab(self):
         try:
             self._set_vocab_sentencepiece()
@@ -1770,6 +1797,12 @@ def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         n_head = self.hparams["num_attention_heads"]
         n_kv_head = self.hparams.get("num_key_value_heads")
+        is_vision_tensor = "vision_tower" in name or "vision_model" in name or "model.connector" in name
+
+        if is_vision_tensor:
+            return [] # skip vision tensors
+        elif name.startswith("model.text_model"):
+            name = name.replace("text_model.", "") # for SmolVLM
 
         if self.undo_permute:
             if name.endswith(("q_proj.weight", "q_proj.bias")):
@@ -1852,6 +1885,41 @@ def prepare_tensors(self):
                 raise ValueError(f"Unprocessed experts: {experts}")
 
 
+@ModelBase.register("Idefics3ForConditionalGeneration", "SmolVLMForConditionalGeneration")
+class SmolVLMModel(VisionModel):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        # fix for SmolVLM2, missing some keys in config.json
+        # default values are taken from transformers code
+        if self.hparams["model_type"] == "smolvlm_vision":
+            self.hparams["hidden_size"] = self.hparams.get("hidden_size", 1152)
+            self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 16)
+            self.hparams["intermediate_size"] = self.hparams.get("intermediate_size", 3072)
+            self.hparams["num_hidden_layers"] = self.hparams.get("num_hidden_layers", 12)
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.IDEFICS3)
+        self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-5))
+        self.gguf_writer.add_vision_projector_scale_factor(self.global_config.get("scale_factor", 2))
+        self.gguf_writer.add_vision_use_gelu(True)
+
+    def tensor_force_quant(self, name, new_name, bid, n_dims):
+        del bid, new_name, n_dims # unused
+        if ".embeddings." in name:
+            return gguf.GGMLQuantizationType.F32
+        return False
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid # unused
+        is_vision_tensor = "vision_tower" in name or "vision_model" in name or "model.connector" in name
+
+        if is_vision_tensor:
+            return [(self.map_tensor_name(name), data_torch)]
+
+        return [] # skip other tensors
+
+
 @ModelBase.register("Llama4ForConditionalGeneration")
 class Llama4Model(LlamaModel):
     model_arch = gguf.MODEL_ARCH.LLAMA4
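
Note that `LlamaModel.modify_tensors` and `SmolVLMModel.modify_tensors` use the same `is_vision_tensor` predicate with opposite outcomes, so a single multimodal checkpoint is split between the text GGUF and the mmproj GGUF; a small illustrative sketch (the tensor names are hypothetical):

```python
# Hypothetical tensor names; the predicate mirrors the one used in both classes above.
def is_vision_tensor(name: str) -> bool:
    return "vision_tower" in name or "vision_model" in name or "model.connector" in name

tensors = [
    "model.text_model.layers.0.self_attn.q_proj.weight",   # kept by LlamaModel (text GGUF)
    "model.vision_model.encoder.layers.0.mlp.fc1.weight",  # kept by SmolVLMModel (mmproj GGUF)
    "model.connector.modality_projection.proj.weight",     # kept by SmolVLMModel (mmproj GGUF)
]

for name in tensors:
    target = "mmproj" if is_vision_tensor(name) else "text model"
    print(f"{name} -> {target}")
```
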
@@ -3591,12 +3659,10 @@ class Gemma3VisionModel(VisionModel):
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
         hparams = self.hparams
-        self.gguf_writer.add_string(gguf.Keys.ClipVision.PROJECTOR_TYPE, "gemma3")
+        self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.GEMMA3)
         # default values below are taken from HF tranformers code
-        self.gguf_writer.add_float32(gguf.Keys.ClipVision.Attention.LAYERNORM_EPS, hparams.get("layer_norm_eps", 1e-6))
-        self.gguf_writer.add_array(gguf.Keys.ClipVision.IMAGE_MEAN, [0.5, 0.5, 0.5])
-        self.gguf_writer.add_array(gguf.Keys.ClipVision.IMAGE_STD, [0.5, 0.5, 0.5])
-        self.gguf_writer.add_bool (gguf.Keys.ClipVision.USE_GELU, True)
+        self.gguf_writer.add_vision_attention_layernorm_eps(hparams.get("layer_norm_eps", 1e-6))
+        self.gguf_writer.add_vision_use_gelu(True)
 
     def tensor_force_quant(self, name, new_name, bid, n_dims):
         del bid, new_name, n_dims # unused
@@ -3614,10 +3680,6 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
             or name.startswith("multimodal_projector.") or name.startswith("vision_model."):
             # process vision tensors
             name = name.replace("_weight", ".weight")
-            if "fc1" in name:
-                name = name.replace("fc1", "fc2")
-            else:
-                name = name.replace("fc2", "fc1")
 
             # correct norm value ; only this "soft_emb_norm" need to be corrected as it's part of Gemma projector
             # the other norm values are part of SigLIP model, and they are already correct

examples/llava/MobileVLM-README.md renamed to docs/multimodal/MobileVLM.md

Lines changed: 13 additions & 13 deletions
@@ -9,15 +9,15 @@ The implementation is based on llava, and is compatible with llava and mobileVLM
 Notice: The overall process of model inference for both **MobileVLM** and **MobileVLM_V2** models is the same, but the process of model conversion is a little different. Therefore, using **MobileVLM-1.7B** as an example, the different conversion step will be shown.
 
 ## Usage
-Build with cmake or run `make llama-llava-cli` to build it.
 
-After building, run: `./llama-llava-cli` to see the usage. For example:
+Build the `llama-mtmd-cli` binary.
+
+After building, run: `./llama-mtmd-cli` to see the usage. For example:
 
 ```sh
-./llama-llava-cli -m MobileVLM-1.7B/ggml-model-q4_k.gguf \
+./llama-mtmd-cli -m MobileVLM-1.7B/ggml-model-q4_k.gguf \
     --mmproj MobileVLM-1.7B/mmproj-model-f16.gguf \
-    --image path/to/an/image.jpg \
-    -p "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: <image>\nWho is the author of this book? Answer the question using a single word or phrase. ASSISTANT:"
+    --chat-template deepseek
 ```
 
 ## Model conversion
@@ -82,7 +82,7 @@ refer to `android/adb_run.sh`, modify resources' `name` and `path`
 ### case 1
 **input**
 ```sh
-/data/local/tmp/llama-llava-cli \
+/data/local/tmp/llama-mtmd-cli \
     -m /data/local/tmp/ggml-model-q4_k.gguf \
     --mmproj /data/local/tmp/mmproj-model-f16.gguf \
     -t 4 \
@@ -102,7 +102,7 @@ llama_print_timings: total time = 34731.93 ms
 ### case 2
 **input**
 ```sh
-/data/local/tmp/llama-llava-cli \
+/data/local/tmp/llama-mtmd-cli \
     -m /data/local/tmp/ggml-model-q4_k.gguf \
     --mmproj /data/local/tmp/mmproj-model-f16.gguf \
    -t 4 \
@@ -123,10 +123,10 @@ llama_print_timings: total time = 34570.79 ms
 
 ## Some result on Android with `Snapdragon 778G` chip
 ### MobileVLM-1.7B case
-#### llava-cli release-b2005
+#### mtmd-cli release-b2005
 **input**
 ```sh
-/data/local/tmp/llama-llava-cli \
+/data/local/tmp/llama-mtmd-cli \
     -m /data/local/tmp/ggml-model-q4_k.gguf \
     --mmproj /data/local/tmp/mmproj-model-f16.gguf \
     -t 4 \
@@ -147,7 +147,7 @@ llama_print_timings: prompt eval time = 8119.49 ms / 191 tokens ( 42.51 m
 llama_print_timings: eval time = 1005.75 ms / 14 runs ( 71.84 ms per token, 13.92 tokens per second)
 llama_print_timings: total time = 28038.34 ms / 205 tokens
 ```
-#### llava-cli latest-version
+#### mtmd-cli latest-version
 **input**
 
 Just the same as above.
@@ -169,7 +169,7 @@ llama_print_timings: eval time = 43894.02 ms / 13 runs ( 3376.46 m
 llama_print_timings: total time = 865441.76 ms / 204 tokens
 ```
 ### MobileVLM_V2-1.7B case
-#### llava-cli release-2005b
+#### mtmd-cli release-2005b
 **input**
 
 Just the same as above.
@@ -200,7 +200,7 @@ make GGML_CUDA=1 CUDA_DOCKER_ARCH=sm_87 GGML_CUDA_F16=1 -j 32
 ### case 1
 **input**
 ```sh
-./llama-llava-cli \
+./llama-mtmd-cli \
     -m /data/local/tmp/ggml-model-q4_k.gguf \
     --mmproj /data/local/tmp/mmproj-model-f16.gguf \
     --image /data/local/tmp/demo.jpeg \
@@ -224,7 +224,7 @@ llama_print_timings: total time = 1352.63 ms / 252 tokens
 ### case 2
 **input**
 ```sh
-./llama-llava-cli \
+./llama-mtmd-cli \
    -m /data/local/tmp/ggml-model-q4_k.gguf \
    --mmproj /data/local/tmp/mmproj-model-f16.gguf \
    -p "A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. USER: <image>\nWhat is in the image? ASSISTANT:" \

examples/llava/README-gemma3.md renamed to docs/multimodal/gemma3.md

Lines changed: 4 additions & 3 deletions
@@ -26,11 +26,12 @@ llama-gemma3-cli -hf ggml-org/gemma-3-27b-it-GGUF
 
 ## How to get mmproj.gguf?
 
+Simply add `--mmproj` when converting the model via `convert_hf_to_gguf.py`:
+
 ```bash
 cd gemma-3-4b-it
-python ../llama.cpp/examples/llava/gemma3_convert_encoder_to_gguf.py .
-
-# output file is mmproj.gguf
+python ../llama.cpp/convert_hf_to_gguf.py --outfile model.gguf --outtype f16 --mmproj .
+# output file: mmproj-model.gguf
 ```
 
 ## How to run it?

examples/llava/README-glmedge.md renamed to docs/multimodal/glmedge.md

Lines changed: 3 additions & 3 deletions
@@ -3,12 +3,12 @@
 Currently this implementation supports [glm-edge-v-2b](https://huggingface.co/THUDM/glm-edge-v-2b) and [glm-edge-v-5b](https://huggingface.co/THUDM/glm-edge-v-5b).
 
 ## Usage
-Build with cmake or run `make llama-llava-cli` to build it.
+Build the `llama-mtmd-cli` binary.
 
-After building, run: `./llama-llava-cli` to see the usage. For example:
+After building, run: `./llama-mtmd-cli` to see the usage. For example:
 
 ```sh
-./llama-llava-cli -m model_path/ggml-model-f16.gguf --mmproj model_path/mmproj-model-f16.gguf --image img_path/image.jpg -p "<|system|>\n system prompt <image><|user|>\n prompt <|assistant|>\n"
+./llama-mtmd-cli -m model_path/ggml-model-f16.gguf --mmproj model_path/mmproj-model-f16.gguf
 ```
 
 **note**: A lower temperature like 0.1 is recommended for better quality. add `--temp 0.1` to the command to do so.

examples/llava/README-granitevision.md renamed to docs/multimodal/granitevision.md

Lines changed: 2 additions & 6 deletions
@@ -176,15 +176,11 @@ Note that currently you cannot quantize the visual encoder because granite visio
 
 
 ### 5. Running the Model in Llama cpp
-Build llama cpp normally; you should have a target binary named `llama-llava-cli`, which you can pass two binaries to. As an example, we pass the the llama.cpp banner.
+Build llama cpp normally; you should have a target binary named `llama-mtmd-cli`, which you can pass two binaries to. As an example, we pass the the llama.cpp banner.
 
 ```bash
-$ ./build/bin/llama-llava-cli -m $LLM_GGUF_PATH \
+$ ./build/bin/llama-mtmd-cli -m $LLM_GGUF_PATH \
     --mmproj $VISUAL_GGUF_PATH \
-    --image ./media/llama0-banner.png \
     -c 16384 \
-    -p "<|system|>\nA chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.\n<|user|>\n\<image>\nWhat does the text in this image say?\n<|assistant|>\n" \
     --temp 0
 ```
-
-Sample output: `The text in the image reads "LLAMA C++ Can it run DOOM Llama?"`
