
Commit 28a2723 (merge of 2 parents: 8f1edcb + 5630406)

merged pixtral support, not fully working

24 files changed: +1273 −506 lines

convert_hf_to_gguf.py

Lines changed: 80 additions & 3 deletions
@@ -776,6 +776,9 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2":
             # ref: https://huggingface.co/THUDM/glm-4-9b-hf
             res = "glm4"
+        if chkhsh == "0e9433cbbb161f89e264eb32e8e64bfe69e834973ffca5d41d3948a604a3e2a3":
+            # ref: https://huggingface.co/mistral-community/pixtral-12b
+            res = "pixtral"
 
         if res is None:
             logger.warning("\n")
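
For context: `chkhsh` fingerprints the model's pre-tokenizer behavior. As far as I can tell from the script, it is the SHA-256 digest of the token-ID list the tokenizer produces for a fixed probe string, so two models share a branch only if they tokenize identically. A minimal sketch (the real probe string is a long multilingual text defined in the script; the short stand-in below will not reproduce the hash above):

```python
from hashlib import sha256
from transformers import AutoTokenizer

chktxt = "Hello world"  # stand-in; the converter uses its own long probe text

tokenizer = AutoTokenizer.from_pretrained("mistral-community/pixtral-12b")
chktok = tokenizer.encode(chktxt)                  # token IDs for the probe
chkhsh = sha256(str(chktok).encode()).hexdigest()  # pre-tokenizer fingerprint
print(chkhsh)  # matches the "pixtral" branch only with the real probe text
```
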
@@ -1724,7 +1727,8 @@ def prepare_tensors(self):
     "MistralForCausalLM",
     "MixtralForCausalLM",
     "Idefics3ForConditionalGeneration",
-    "SmolVLMForConditionalGeneration")
+    "SmolVLMForConditionalGeneration",
+    "LlavaForConditionalGeneration")
 class LlamaModel(TextModel):
     model_arch = gguf.MODEL_ARCH.LLAMA
     undo_permute = True
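
For readers new to the converter: `ModelBase.register` maps the `architectures[0]` string from a checkpoint's config.json to the class that converts it, so adding "LlavaForConditionalGeneration" here routes Pixtral's text weights through `LlamaModel`. A minimal sketch of that registry pattern, inferred from usage rather than copied from the source:

```python
# Sketch only: the real ModelBase keeps more state and validation.
class ModelBase:
    _model_classes: dict[str, type] = {}

    @classmethod
    def register(cls, *names: str):
        def wrapper(model_cls: type) -> type:
            for name in names:
                cls._model_classes[name] = model_cls  # one class, many arch names
            return model_cls
        return wrapper
```
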
@@ -1734,6 +1738,10 @@ def __init__(self, *args, **kwargs):
         # fix for SmolVLM2, missing `num_attention_heads` in config.json
         if self.hparams["architectures"][0] == "SmolVLMForConditionalGeneration":
             self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
+        # fix for Pixtral, missing `num_attention_heads` in config.json
+        if self.hparams["architectures"][0] == "LlavaForConditionalGeneration" \
+                and self.hparams.get("model_type") == "mistral":
+            self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
 
     def set_vocab(self):
         try:
@@ -1797,12 +1805,17 @@ def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         n_head = self.hparams["num_attention_heads"]
         n_kv_head = self.hparams.get("num_key_value_heads")
-        is_vision_tensor = "vision_tower" in name or "vision_model" in name or "model.connector" in name
+        is_vision_tensor = "vision_tower" in name \
+            or "vision_model" in name \
+            or "model.connector" in name \
+            or "multi_modal_projector" in name
 
         if is_vision_tensor:
             return [] # skip vision tensors
         elif name.startswith("model.text_model"):
             name = name.replace("text_model.", "") # for SmolVLM
+        elif name.startswith("language_model."):
+            name = name.replace("language_model.", "") # for the rest
 
         if self.undo_permute:
             if name.endswith(("q_proj.weight", "q_proj.bias")):
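
To illustrate the new prefix stripping (the tensor name below is a representative example, not taken from a specific checkpoint): vision tensors are skipped here and handled by the vision converter, while text tensors lose the `language_model.` prefix so they match the plain-Llama tensor map:

```python
name = "language_model.model.layers.0.self_attn.q_proj.weight"  # example name
if name.startswith("language_model."):
    name = name.replace("language_model.", "")
print(name)  # -> model.layers.0.self_attn.q_proj.weight
```
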
@@ -1885,6 +1898,55 @@ def prepare_tensors(self):
                 raise ValueError(f"Unprocessed experts: {experts}")
 
 
+@ModelBase.register("LlavaForConditionalGeneration")
+class LlavaVisionModel(VisionModel):
+    img_break_tok_id = -1
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        if self.hparams["model_type"] == "pixtral":
+            # fix missing config.json values
+            self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 16)
+            self.hparams["num_hidden_layers"] = self.hparams.get("num_hidden_layers", 24)
+            self.hparams["intermediate_size"] = self.hparams.get("intermediate_size", 4096)
+            self.hparams["hidden_size"] = self.hparams.get("hidden_size", 1024)
+            self.hparams["layer_norm_eps"] = self.hparams.get("layer_norm_eps", 1e-5)
+            self.img_break_tok_id = 12 # see tokenizer_config.json
+        else:
+            raise ValueError(f"Unsupported model type: {self.hparams['model_type']}")
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        hparams = self.hparams
+        if hparams["model_type"] == "pixtral":
+            self.gguf_writer.add_vision_projector_type(gguf.VisionProjectorType.PIXTRAL)
+            # default values below are taken from HF transformers code
+            self.gguf_writer.add_vision_attention_layernorm_eps(hparams["layer_norm_eps"])
+            self.gguf_writer.add_vision_use_silu(True)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid # unused
+        n_head = self.hparams["num_attention_heads"]
+        n_kv_head = n_head
+
+        if name.startswith("multi_modal_projector.") or name.startswith("vision_tower."):
+            # process vision tensors
+            if name.endswith(("q_proj.weight", "q_proj.bias")):
+                data_torch = LlamaModel.permute(data_torch, n_head, n_head)
+            if name.endswith(("k_proj.weight", "k_proj.bias")):
+                data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)
+            return [(self.map_tensor_name(name), data_torch)]
+
+        if self.img_break_tok_id > 0 and "embed_tokens.weight" in name:
+            logger.info(f"Extracting [IMG_BREAK] token embedding from {name}")
+            # for pixtral model, we need to extract the [IMG_BREAK] token embedding
+            img_break_embd = data_torch[self.img_break_tok_id]
+            name = gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.V_TOK_EMBD_IMG_BREAK]
+            return [(self.map_tensor_name(name), img_break_embd)]
+
+        return [] # skip other tensors
+
+
 @ModelBase.register("Idefics3ForConditionalGeneration", "SmolVLMForConditionalGeneration")
 class SmolVLMModel(VisionModel):
     def __init__(self, *args, **kwargs):
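
The new class leans on `LlamaModel.permute` to undo HF's interleaved Q/K row layout before writing GGUF. For reference, that helper already exists in this file; it is reproduced below from my reading of the converter, so treat the exact body as an approximation:

```python
from torch import Tensor

# Defined as a @staticmethod on LlamaModel in the converter.
def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
    # regroup the rows that HF interleaves per head so RoPE can be
    # applied to contiguous halves in the GGUF layout
    if n_head_kv is not None and n_head != n_head_kv:
        n_head = n_head_kv
    return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
            .swapaxes(1, 2)
            .reshape(weights.shape))
```
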
@@ -5079,10 +5141,25 @@ class Glm4Model(TextModel):
     model_arch = gguf.MODEL_ARCH.GLM4
 
     def set_vocab(self):
-        self._set_vocab_gpt2()
+        from transformers import AutoTokenizer
+        tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
+        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
+        tokens, toktypes, tokpre = self.get_vocab_base()
+        self.gguf_writer.add_tokenizer_model("gpt2")
+        self.gguf_writer.add_tokenizer_pre(tokpre)
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_types(toktypes)
+        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
+        special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"])
+        special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"])
+        special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"])
+        special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["[gMASK]"])
+        special_vocab.add_to_gguf(self.gguf_writer)
 
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
+        rope_dim = self.hparams["head_dim"]
+        self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5)))
         if self.hparams.get("rope_scaling") is not None and "factor" in self.hparams["rope_scaling"]:
             if self.hparams["rope_scaling"].get("type") == "yarn":
                 self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
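
A quick worked example of the added rope-dimension logic (numbers are illustrative, GLM-4-9B-style, not read from a real config):

```python
# GLM-4 applies RoPE to only part of each head ("partial rotary").
head_dim = 128               # hypothetical value from config.json
partial_rotary_factor = 0.5  # converter default when the key is absent
rope_dim = int(head_dim * partial_rotary_factor)
print(rope_dim)  # 64, written to GGUF as the rope dimension count
```
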

convert_hf_to_gguf_update.py

Lines changed: 1 addition & 0 deletions
@@ -115,6 +115,7 @@ class TOKENIZER_TYPE(IntEnum):
     {"name": "bailingmoe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inclusionAI/Ling-lite", },
     {"name": "llama4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct", },
     {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", },
+    {"name": "pixtral", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mistral-community/pixtral-12b", },
 ]
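
Registering the repo here is what regenerates the `if chkhsh == ...` branch shown earlier: as I understand the update script, it downloads each listed tokenizer, hashes the probe encoding, and emits the corresponding branch. A simplified sketch of that loop (names and output format are approximate):

```python
from hashlib import sha256
from transformers import AutoTokenizer

chktxt = "..."  # the long probe string defined in the update script

models = [
    {"name": "pixtral", "repo": "https://huggingface.co/mistral-community/pixtral-12b"},
]

for model in models:
    repo_id = model["repo"].removeprefix("https://huggingface.co/")
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    chkhsh = sha256(str(tokenizer.encode(chktxt)).encode()).hexdigest()
    print(f'if chkhsh == "{chkhsh}":')
    print(f'    # ref: {model["repo"]}')
    print(f'    res = "{model["name"]}"')
```
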

docs/multimodal/gemma3.md

Lines changed: 6 additions & 6 deletions
@@ -11,15 +11,15 @@ You can use pre-quantized model from [ggml-org](https://huggingface.co/ggml-org)
 ```bash
 # build
 cmake -B build
-cmake --build build --target llama-gemma3-cli
+cmake --build build --target llama-mtmd-cli
 
 # alternatively, install from brew (MacOS)
 brew install llama.cpp
 
 # run it
-llama-gemma3-cli -hf ggml-org/gemma-3-4b-it-GGUF
-llama-gemma3-cli -hf ggml-org/gemma-3-12b-it-GGUF
-llama-gemma3-cli -hf ggml-org/gemma-3-27b-it-GGUF
+llama-mtmd-cli -hf ggml-org/gemma-3-4b-it-GGUF
+llama-mtmd-cli -hf ggml-org/gemma-3-12b-it-GGUF
+llama-mtmd-cli -hf ggml-org/gemma-3-27b-it-GGUF
 
 # note: 1B model does not support vision
 ```
@@ -44,8 +44,8 @@ What you need:
 ```bash
 # build
 cmake -B build
-cmake --build build --target llama-gemma3-cli
+cmake --build build --target llama-mtmd-cli
 
 # run it
-./build/bin/llama-gemma3-cli -m {text_model}.gguf --mmproj mmproj.gguf --image your_image.jpg
+./build/bin/llama-mtmd-cli -m {text_model}.gguf --mmproj mmproj.gguf --image your_image.jpg
 ```

examples/llava/clip-impl.h

Lines changed: 4 additions & 0 deletions
@@ -64,6 +64,7 @@
 #define TN_ATTN_V          "%s.blk.%d.attn_v.%s"
 #define TN_ATTN_OUTPUT     "%s.blk.%d.attn_out.%s"
 #define TN_FFN_DOWN        "%s.blk.%d.ffn_down.%s"
+#define TN_FFN_GATE        "%s.blk.%d.ffn_gate.%s"
 #define TN_FFN_UP          "%s.blk.%d.ffn_up.%s"
 #define TN_FFN_GATE        "%s.blk.%d.ffn_gate.%s"
 #define TN_LN_1            "%s.blk.%d.ln1.%s"
@@ -78,6 +79,7 @@
 #define TN_MM_INP_PROJ     "mm.input_projection.weight" // gemma3
 #define TN_MM_SOFT_EMB_N   "mm.soft_emb_norm.weight" // gemma3
 #define TN_MM_PROJECTOR    "mm.model.fc.weight" // idefics3
+#define TN_TOK_IMG_BREAK   "v.token_embd.img_break" // pixtral
 
 // minicpmv
 #define TN_MINICPMV_POS_EMBD_K "resampler.pos_embed_k"
@@ -106,6 +108,7 @@ enum projector_type {
     PROJECTOR_TYPE_MERGER,
     PROJECTOR_TYPE_GEMMA3,
     PROJECTOR_TYPE_IDEFICS3,
+    PROJECTOR_TYPE_PIXTRAL,
     PROJECTOR_TYPE_UNKNOWN,
 };

@@ -118,6 +121,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_MERGER,   "qwen2vl_merger"},
     { PROJECTOR_TYPE_GEMMA3,   "gemma3"},
     { PROJECTOR_TYPE_IDEFICS3, "idefics3"},
+    { PROJECTOR_TYPE_PIXTRAL,  "pixtral"},
 };
 
 static projector_type clip_projector_type_from_string(const std::string & str) {
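
The TN_* macros above are printf-style patterns that get expanded into concrete GGUF tensor names at load time. A quick demo of the expansion (Python used only for illustration; the project does this with C format strings):

```python
TN_FFN_GATE = "%s.blk.%d.ffn_gate.%s"        # prefix, block index, suffix
TN_TOK_IMG_BREAK = "v.token_embd.img_break"  # fixed name, no placeholders

print(TN_FFN_GATE % ("v", 0, "weight"))  # -> v.blk.0.ffn_gate.weight
print(TN_TOK_IMG_BREAK)
```
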
