
Commit d0068ef

add mobilevlm

1 parent 6cabdda commit d0068ef

File tree: 9 files changed (+216, -67 lines)

convert_hf_to_gguf.py

Lines changed: 47 additions & 19 deletions
@@ -17,7 +17,7 @@
 from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Literal, Sequence, TypeVar, cast
 from itertools import chain
 
-from transformers import AutoConfig
+from transformers import AutoConfig, AutoImageProcessor
 import math
 import numpy as np
 import torch
@@ -68,9 +68,10 @@ class Model:
     dir_model_card: Path
 
     # for vision model
+    vision_arch: gguf.MODEL_ARCH | None = None
     preprocessor_config: dict[str, Any] | None = None
     vparams: dict[str, Any] | None = None
-    v_tensor_map: gguf.TensorNameMap
+    v_tensor_map: gguf.TensorNameMap | None = None
     v_tensor_names: set[str] | None
 
     # subclasses should define this!
@@ -102,7 +103,6 @@ def __init__(self, dir_model: Path, ftype: gguf.LlamaFileType, fname_out: Path,
         self.metadata_override = metadata_override
         self.model_name = model_name
         self.dir_model_card = dir_model  # overridden in convert_lora_to_gguf.py
-        self.preprocessor_config = self.load_preprocessor_config(self.dir_model)
 
         # Apply heuristics to figure out typical tensor encoding based on first layer tensor encoding type
         if self.ftype == gguf.LlamaFileType.GUESSED:
@@ -218,7 +218,7 @@ def match_model_tensor_name(self, name: str, key: gguf.MODEL_TENSOR, bid: int |
 
     def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias")) -> str:
         new_name = self.tensor_map.get_name(key=name, try_suffixes=try_suffixes)
-        new_name_vision = self.v_tensor_map.get_name(key=name, try_suffixes=try_suffixes)
+        new_name_vision = self.v_tensor_map.get_name(key=name, try_suffixes=try_suffixes) if self.v_tensor_map is not None else None
         if new_name is not None:
             return new_name
         elif new_name_vision is not None:
@@ -488,14 +488,17 @@ def load_hparams(dir_model: Path):
         return hparams
 
     @staticmethod
-    def load_preprocessor_config(dir_model: Path):
+    def load_preprocessor_config(dir_or_model_id: Path | str):
         # TODO: this varies vastly among models, need to handle more cases in the future
-        file_path = dir_model / "preprocessor_config.json"
-        if os.path.exists(file_path):
-            with open(file_path, "r", encoding="utf-8") as f:
-                return json.load(f)
+        if isinstance(dir_or_model_id, Path):
+            file_path = dir_or_model_id / "preprocessor_config.json"
+            if os.path.exists(file_path):
+                with open(file_path, "r", encoding="utf-8") as f:
+                    return json.load(f)
+            else:
+                raise Exception(f"Preprocessor config not found at {file_path}")
         else:
-            return None
+            return AutoImageProcessor.from_pretrained(dir_or_model_id).to_dict()
 
     @classmethod
     def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]:
@@ -1586,16 +1589,31 @@ def prepare_tensors(self):
             raise ValueError(f"Unprocessed norms: {norms}")
 
 
-@Model.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM", "LlavaForConditionalGeneration")
+@Model.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM", "LlavaForConditionalGeneration", "MobileLlamaForCausalLM")
 class LlamaModel(Model):
     model_arch = gguf.MODEL_ARCH.LLAMA
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        if "vision_config" in self.hparams:
+
+        model_type = self.hparams.get("model_type", None)
+        self.vision_arch = None
+
+        # only tested with https://huggingface.co/llava-hf/llava-1.5-7b-hf
+        if "vision_config" in self.hparams and model_type == "llava":
             self.vparams = self.hparams["vision_config"]
-            if self.vparams is not None:
-                self.v_tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.LLAVA_VISION, self.vparams["num_hidden_layers"])
+            self.preprocessor_config = self.load_preprocessor_config(self.dir_model)
+            self.vision_arch = gguf.MODEL_ARCH.VISION_LLAVA
+
+        # only tested with https://huggingface.co/mtgv/MobileVLM_V2-1.7B
+        if "mm_vision_tower" in self.hparams and model_type == "mobilevlm":
+            vision_model_id = self.hparams["mm_vision_tower"]
+            self.vparams = AutoConfig.from_pretrained(vision_model_id).to_dict()["vision_config"]
+            self.preprocessor_config = self.load_preprocessor_config(vision_model_id)
+            self.vision_arch = gguf.MODEL_ARCH.VISION_MOBILEVLM
+
+        if self.vparams is not None and self.vision_arch is not None:
+            self.v_tensor_map = gguf.get_tensor_name_map(self.vision_arch, self.vparams["num_hidden_layers"])
 
     def set_vocab(self):
         try:
@@ -1631,23 +1649,31 @@ def set_vocab(self):
             self.gguf_writer.add_add_bos_token(False)
 
         # For vision model
-        if self.vparams is not None and self.preprocessor_config is not None:
+        if self.vparams is not None and self.preprocessor_config is not None and self.vision_arch is not None:
             self.gguf_writer.add_vision_type("clip-vit")
             self.gguf_writer.add_vision_image_size(self.vparams["image_size"])
             self.gguf_writer.add_vision_patch_size(self.vparams["patch_size"])
-            self.gguf_writer.add_vision_clip_architecture("llava")
+            self.gguf_writer.add_vision_clip_architecture(gguf.MODEL_ARCH_NAMES[self.vision_arch])
             self.gguf_writer.add_vision_clip_block_count(self.vparams["num_hidden_layers"])
             self.gguf_writer.add_vision_clip_embedding_length(self.vparams["hidden_size"])
             self.gguf_writer.add_vision_clip_feed_forward_length(self.vparams["intermediate_size"])
             self.gguf_writer.add_vision_clip_head_count(self.vparams["num_attention_heads"])
             self.gguf_writer.add_vision_clip_image_mean(self.preprocessor_config["image_mean"])
             self.gguf_writer.add_vision_clip_image_std(self.preprocessor_config["image_std"])
-            self.gguf_writer.add_vision_clip_select_layer(self.hparams["vision_feature_layer"])
             self.gguf_writer.add_vision_clip_patch_merge_type(gguf.CLIPPatchMergeType.FLAT)
             max_pos_embd = (self.vparams["image_size"] // self.vparams["patch_size"])**2 + 1
             self.gguf_writer.add_vision_clip_max_position_embeddings(max_pos_embd)
+            if "vision_feature_layer" in self.hparams:
+                self.gguf_writer.add_vision_clip_select_layer(self.hparams["vision_feature_layer"])
+            elif "mm_vision_select_layer" in self.hparams:
+                self.gguf_writer.add_vision_clip_select_layer(self.hparams["mm_vision_select_layer"])
+            else:
+                raise ValueError("gguf: can not find vision_feature_layer parameter.")
             # TODO: should not hardcode these, but they are currently missing from config.json
-            self.gguf_writer.add_vision_clip_projector_type(gguf.constants.CLIPProjectorType.MLP)
+            if self.vision_arch == gguf.MODEL_ARCH.VISION_LLAVA:
+                self.gguf_writer.add_vision_clip_projector_type(gguf.constants.CLIPProjectorType.MLP)
+            if self.vision_arch == gguf.MODEL_ARCH.VISION_MOBILEVLM:
+                self.gguf_writer.add_vision_clip_projector_type(gguf.constants.CLIPProjectorType.LDPV2)
             self.gguf_writer.add_vision_clip_layer_norm_epsilon(1e-05)
 
     def set_gguf_parameters(self):
@@ -1683,6 +1709,8 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
         # For vision model
         if name.startswith("language_model"):
             name = name.replace("language_model.", "")
+        else:
+            name = name.replace("model.vision_tower.", "")
        if "post_layernorm" in name:
             return [] # skip post_layernorm
@@ -2101,7 +2129,7 @@ def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims:
         return n_dims > 1
 
 
-@Model.register("MiniCPMForCausalLM")
+@Model.register("MiniCPMForCausalLM", "MiniCPMV")
 class MiniCPMModel(Model):
     model_arch = gguf.MODEL_ARCH.MINICPM
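
The MobileVLM branch above differs from the LLaVA branch in where the vision metadata lives: MobileVLM's config.json embeds no vision_config, only an "mm_vision_tower" model id. A minimal sketch of that resolution step, using only the calls that appear in the diff (AutoConfig, AutoImageProcessor); the local directory name is illustrative:

import json
from pathlib import Path

from transformers import AutoConfig, AutoImageProcessor

# Read the checkpoint's own hparams from disk, as the converter does.
dir_model = Path("MobileVLM_V2-1.7B")  # illustrative local clone of the HF repo
hparams = json.loads((dir_model / "config.json").read_text(encoding="utf-8"))

# Both the vision hparams and the image preprocessor come from the
# external vision tower (a CLIP ViT checkpoint id), not from dir_model.
vision_model_id = hparams["mm_vision_tower"]
vparams = AutoConfig.from_pretrained(vision_model_id).to_dict()["vision_config"]
preprocessor_config = AutoImageProcessor.from_pretrained(vision_model_id).to_dict()

print(vparams["num_hidden_layers"], preprocessor_config["image_mean"])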

gguf-py/gguf/constants.py

Lines changed: 30 additions & 4 deletions
@@ -308,7 +308,8 @@ class MODEL_ARCH(IntEnum):
     CHAMELEON = auto()
     WAVTOKENIZER_DEC = auto()
     # vision models
-    LLAVA_VISION = auto()
+    VISION_LLAVA = auto()
+    VISION_MOBILEVLM = auto()
 
 
 class MODEL_TENSOR(IntEnum):
@@ -439,6 +440,8 @@ class MODEL_TENSOR(IntEnum):
     POSNET_ATTN_OUT = auto()
     # vision
     V_MMPROJ = auto()
+    V_MMPROJ_MLP = auto()
+    V_MMPROJ_PEG = auto()
     V_ENC_EMBD_CLS = auto()
     V_ENC_EMBD_PATCH = auto()
     V_ENC_EMBD_POS = auto()
@@ -512,6 +515,9 @@ class MODEL_TENSOR(IntEnum):
     MODEL_ARCH.GRANITE_MOE: "granitemoe",
     MODEL_ARCH.CHAMELEON: "chameleon",
     MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec",
+    # vision
+    MODEL_ARCH.VISION_LLAVA: "llava",
+    MODEL_ARCH.VISION_MOBILEVLM: "mobilevlm",
 }
 
 TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
@@ -641,6 +647,8 @@ class MODEL_TENSOR(IntEnum):
     MODEL_TENSOR.POSNET_ATTN_OUT: "posnet.{bid}.attn_output",
     # vision
     MODEL_TENSOR.V_MMPROJ: "v.mmproj_{bid}",
+    MODEL_TENSOR.V_MMPROJ_MLP: "v.mmproj.mlp.{bid}",
+    MODEL_TENSOR.V_MMPROJ_PEG: "v.mmproj.peg.{bid}",
     MODEL_TENSOR.V_ENC_EMBD_CLS: "v.enc.embd.cls",
     MODEL_TENSOR.V_ENC_EMBD_PATCH: "v.enc.embd.patch",
     MODEL_TENSOR.V_ENC_EMBD_POS: "v.enc.embd.pos",
@@ -1595,7 +1603,7 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.POSNET_ATTN_V,
         MODEL_TENSOR.POSNET_ATTN_OUT,
     ],
-    MODEL_ARCH.LLAVA_VISION: [
+    MODEL_ARCH.VISION_LLAVA: [
         MODEL_TENSOR.V_MMPROJ,
         MODEL_TENSOR.V_ENC_EMBD_CLS,
         MODEL_TENSOR.V_ENC_EMBD_PATCH,
@@ -1611,6 +1619,23 @@ class MODEL_TENSOR(IntEnum):
         MODEL_TENSOR.V_PRE_NORM,
         MODEL_TENSOR.V_POST_NORM,
     ],
+    MODEL_ARCH.VISION_MOBILEVLM: [
+        MODEL_TENSOR.V_MMPROJ_MLP,
+        MODEL_TENSOR.V_MMPROJ_PEG,
+        MODEL_TENSOR.V_ENC_EMBD_CLS,
+        MODEL_TENSOR.V_ENC_EMBD_PATCH,
+        MODEL_TENSOR.V_ENC_EMBD_POS,
+        MODEL_TENSOR.V_ENC_ATTN_Q,
+        MODEL_TENSOR.V_ENC_ATTN_K,
+        MODEL_TENSOR.V_ENC_ATTN_V,
+        MODEL_TENSOR.V_ENC_INPUT_NORM,
+        MODEL_TENSOR.V_ENC_OUTPUT,
+        MODEL_TENSOR.V_ENC_OUTPUT_NORM,
+        MODEL_TENSOR.V_ENC_FFN_UP,
+        MODEL_TENSOR.V_ENC_FFN_DOWN,
+        MODEL_TENSOR.V_PRE_NORM,
+        MODEL_TENSOR.V_POST_NORM,
+    ],
     # TODO
 }
 
@@ -1693,11 +1718,12 @@ class PoolingType(IntEnum):
 
 
 class CLIPProjectorType(Enum):
-    MLP = 'mlp'
+    MLP   = 'mlp'
+    LDPV2 = 'ldpv2'
 
 
 class CLIPPatchMergeType(Enum):
-    FLAT = 'flat'
+    FLAT          = 'flat'
     SPATIAL_UNPAD = 'spatial_unpad'
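
The new V_MMPROJ_MLP and V_MMPROJ_PEG templates must render the same tensor names as the printf-style patterns registered in src/llama-arch.cpp further down, otherwise the loader resolves them to "__missing__". A small self-contained check of that invariant (block ids chosen arbitrarily):

# Python-side templates use "{bid}", the C++ side uses "%d"; both must
# produce identical strings for every block id.
PY_TEMPLATES = {
    "V_MMPROJ_MLP": "v.mmproj.mlp.{bid}",
    "V_MMPROJ_PEG": "v.mmproj.peg.{bid}",
}
CPP_PATTERNS = {
    "V_MMPROJ_MLP": "v.mmproj.mlp.%d",
    "V_MMPROJ_PEG": "v.mmproj.peg.%d",
}
for key, template in PY_TEMPLATES.items():
    for bid in range(4):  # arbitrary block ids
        assert template.format(bid=bid) == CPP_PATTERNS[key] % bid
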
gguf-py/gguf/gguf_writer.py

Lines changed: 1 addition & 1 deletion
@@ -876,7 +876,7 @@ def add_remove_extra_whitespaces(self, value: bool) -> None:
 
     def add_precompiled_charsmap(self, charsmap: Sequence[bytes]) -> None:
         self.add_array(Keys.Tokenizer.PRECOMPILED_CHARSMAP, charsmap)
-
+
     def add_vision_type(self, value: str) -> None:
         self.add_string(Keys.Vision.TYPE, value)

(The removed and added lines are both blank: a whitespace-only cleanup.)

gguf-py/gguf/tensor_mapping.py

Lines changed: 8 additions & 0 deletions
@@ -794,6 +794,14 @@ class TensorNameMap:
             "multi_modal_projector.linear_{bid}",
         ),
 
+        MODEL_TENSOR.V_MMPROJ_MLP: (
+            "model.mm_projector.mlp.mlp.{bid}",
+        ),
+
+        MODEL_TENSOR.V_MMPROJ_PEG: (
+            "model.mm_projector.peg.peg.{bid}",
+        ),
+
         MODEL_TENSOR.V_ENC_EMBD_CLS: (
             "vision_tower.vision_model.embeddings.class_embedding",
         ),
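
A hedged sketch of what these entries buy the converter, assuming the gguf-py package from this tree is importable; it mirrors the map_tensor_name path in convert_hf_to_gguf.py, and the block count is illustrative:

import gguf

# Build the HF-name -> GGUF-name map as LlamaModel.__init__ does;
# 24 stands in for vparams["num_hidden_layers"] and only sizes the map.
v_tensor_map = gguf.get_tensor_name_map(gguf.MODEL_ARCH.VISION_MOBILEVLM, 24)

# A MobileVLM projector tensor resolves to its GGUF name: get_name strips
# a known suffix, maps the base name, and re-attaches the suffix.
name = v_tensor_map.get_name("model.mm_projector.mlp.mlp.0.weight",
                             try_suffixes=(".weight", ".bias"))
print(name)  # expected: v.mmproj.mlp.0.weight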

src/llama-arch.cpp

Lines changed: 30 additions & 1 deletion
@@ -67,6 +67,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
 
 static const std::map<vision_arch, const char *> VISION_ARCH_NAMES = {
     { VISION_ARCH_LLAVA, "llava" },
+    { VISION_ARCH_MOBILEVLM, "mobilevlm" },
     { VISION_ARCH_UNKNOWN, "(unknown)" },
 };
 
@@ -1345,7 +1346,27 @@ static const std::map<vision_arch, std::map<vision_tensor, const char *>> VISION_TENSOR_NAMES = {
             { VISION_TENSOR_PRE_NORM, "v.pre_norm" },
             { VISION_TENSOR_POST_NORM, "v.post_norm" },
         }
-    }
+    },
+    {
+        VISION_ARCH_MOBILEVLM,
+        {
+            { VISION_TENSOR_MMPROJ_MLP, "v.mmproj.mlp.%d" },
+            { VISION_TENSOR_MMPROJ_PEG, "v.mmproj.peg.%d" },
+            { VISION_TENSOR_ENC_EMBD_CLS, "v.enc.embd.cls" },
+            { VISION_TENSOR_ENC_EMBD_PATCH, "v.enc.embd.patch" },
+            { VISION_TENSOR_ENC_EMBD_POS, "v.enc.embd.pos" },
+            { VISION_TENSOR_ENC_ATTN_Q, "v.enc.blk.%d.attn_q" },
+            { VISION_TENSOR_ENC_ATTN_K, "v.enc.blk.%d.attn_k" },
+            { VISION_TENSOR_ENC_ATTN_V, "v.enc.blk.%d.attn_v" },
+            { VISION_TENSOR_ENC_INPUT_NORM, "v.enc.blk.%d.input_norm" },
+            { VISION_TENSOR_ENC_OUTPUT, "v.enc.blk.%d.output" },
+            { VISION_TENSOR_ENC_OUTPUT_NORM, "v.enc.blk.%d.output_norm" },
+            { VISION_TENSOR_ENC_FFN_UP, "v.enc.blk.%d.ffn_up" },
+            { VISION_TENSOR_ENC_FFN_DOWN, "v.enc.blk.%d.ffn_down" },
+            { VISION_TENSOR_PRE_NORM, "v.pre_norm" },
+            { VISION_TENSOR_POST_NORM, "v.post_norm" },
+        }
+    },
 };
 
 static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
@@ -1499,6 +1520,10 @@ std::string LLM_KV::operator()(llm_kv kv) const {
 
 template<>
 std::string BASE_TN_IMPL<llm_arch, llm_tensor>::str() const {
+    if (LLM_TENSOR_NAMES.find(arch) == LLM_TENSOR_NAMES.end()) {
+        throw std::runtime_error(format("Cannot find tensor name mapping for arch %d", arch));
+    }
+
     if (LLM_TENSOR_NAMES.at(arch).find(tensor) == LLM_TENSOR_NAMES.at(arch).end()) {
         return "__missing__";
     }
@@ -1515,6 +1540,10 @@ std::string BASE_TN_IMPL<llm_arch, llm_tensor>::str() const {
 
 template<>
 std::string BASE_TN_IMPL<vision_arch, vision_tensor>::str() const {
+    if (VISION_TENSOR_NAMES.find(arch) == VISION_TENSOR_NAMES.end()) {
+        throw std::runtime_error(format("Cannot find tensor name mapping for arch %d", arch));
+    }
+
     if (VISION_TENSOR_NAMES.at(arch).find(tensor) == VISION_TENSOR_NAMES.at(arch).end()) {
         return "__missing__";
     }
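
The two guards above turn an opaque std::out_of_range from std::map::at into an explicit error when an arch has no name table at all, while an unknown tensor within a known arch still maps to the "__missing__" sentinel. A Python analogue of that guarded lookup (the names below are local to this sketch, not part of the codebase):

# Behavior illustrated: unknown arch -> loud error; known arch with an
# unknown tensor -> the "__missing__" sentinel string.
VISION_TENSOR_NAMES = {
    "llava":     {"mmproj": "v.mmproj_%d"},
    "mobilevlm": {"mmproj_mlp": "v.mmproj.mlp.%d",
                  "mmproj_peg": "v.mmproj.peg.%d"},
}

def tensor_name(arch: str, tensor: str, bid: int) -> str:
    if arch not in VISION_TENSOR_NAMES:
        raise RuntimeError(f"Cannot find tensor name mapping for arch {arch}")
    fmt = VISION_TENSOR_NAMES[arch].get(tensor)
    return "__missing__" if fmt is None else fmt % bid

assert tensor_name("mobilevlm", "mmproj_peg", 0) == "v.mmproj.peg.0"
assert tensor_name("mobilevlm", "nonexistent", 0) == "__missing__"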

src/llama-arch.h

Lines changed: 3 additions & 0 deletions
@@ -72,6 +72,7 @@ enum llm_arch {
 enum vision_arch {
     VISION_ARCH_UNKNOWN,
     VISION_ARCH_LLAVA,
+    VISION_ARCH_MOBILEVLM,
 };
 
 enum llm_kv {
@@ -356,6 +357,8 @@ enum llm_tensor {
 
 enum vision_tensor {
     VISION_TENSOR_MMPROJ,
+    VISION_TENSOR_MMPROJ_MLP,
+    VISION_TENSOR_MMPROJ_PEG,
     VISION_TENSOR_ENC_EMBD_CLS,
     VISION_TENSOR_ENC_EMBD_PATCH,
     VISION_TENSOR_ENC_EMBD_POS,
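
For orientation, the two new projector tensor groups correspond to the two stages of MobileVLM's LDPv2 projector: an MLP over patch embeddings, then a depthwise-convolution positional-encoding generator (PEG) applied with a residual. The sketch below is a loose PyTorch illustration, not the reference implementation: layer sizes are made up, and the token-downsampling stage of the real LDPv2 is omitted.

import torch
import torch.nn as nn

class LDPv2Sketch(nn.Module):
    """Loose illustration of the mlp/peg split behind V_MMPROJ_MLP and V_MMPROJ_PEG."""

    def __init__(self, n_embd_vision: int, n_embd_llm: int):
        super().__init__()
        # weights here would land in the "v.mmproj.mlp.{bid}" tensors
        self.mlp = nn.Sequential(
            nn.Linear(n_embd_vision, n_embd_llm),
            nn.GELU(),
            nn.Linear(n_embd_llm, n_embd_llm),
        )
        # depthwise conv acting as a positional-encoding generator;
        # weights here would land in the "v.mmproj.peg.{bid}" tensors
        self.peg = nn.Conv2d(n_embd_llm, n_embd_llm, 3, 1, 1, groups=n_embd_llm)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, n_patches, n_embd_vision), n_patches a perfect square
        x = self.mlp(x)
        b, n, c = x.shape
        h = int(n ** 0.5)
        grid = x.transpose(1, 2).reshape(b, c, h, h)
        grid = grid + self.peg(grid)  # residual positional injection
        return grid.flatten(2).transpose(1, 2)

out = LDPv2Sketch(1024, 2048)(torch.randn(1, 576, 1024))  # 576 = 24*24 patches
print(out.shape)  # torch.Size([1, 576, 2048])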
