Draft
Commits (32)
2a458d1
wip
ngxson Jan 18, 2025
0a81051
llama : second attempt to refactor vision API
ngxson Jan 18, 2025
6cabdda
add back convert hf to gguf
ngxson Jan 18, 2025
d0068ef
add mobilevlm
ngxson Jan 19, 2025
4a7ab89
wip minicpmv
ngxson Jan 19, 2025
431bb08
change gguf KV from clip to vit
ngxson Jan 21, 2025
bd0714b
reuse LLM_ARCH and LLM_TENSOR
ngxson Jan 21, 2025
ad38e87
rename everywhere
ngxson Jan 21, 2025
32daa38
Merge branch 'master' into xsn/vision_2
ngxson Jan 22, 2025
9716c7b
temporary refactor llama_vision_graph_builder
ngxson Jan 22, 2025
ba489b4
wip minicpmv
ngxson Jan 22, 2025
c0d93dd
minicpmv works but missing uhd slices
ngxson Jan 22, 2025
8586d23
minicpm working without uhd
ngxson Jan 23, 2025
25a97ce
correct positions for siglip
ngxson Jan 23, 2025
c3a654c
add SmolVLM
ngxson Jan 23, 2025
b986af8
py: a bit cleaner
ngxson Jan 23, 2025
b72d755
Merge branch 'master' into xsn/vision_2
ngxson Jan 23, 2025
0959cc1
Merge branch 'master' into xsn/vision_2
ngxson Jan 25, 2025
90eefc2
refactor minicpm-v support
ngxson Jan 25, 2025
e884d3d
Merge branch 'master' into xsn/vision_2
ngxson Feb 2, 2025
ff77b15
Merge branch 'master' into xsn/vision_2
ngxson Feb 6, 2025
fa55281
separate vision ctx and llm ctx
ngxson Feb 6, 2025
0ec6bce
Merge branch 'master' into xsn/vision_2
ngxson Mar 1, 2025
7863232
clarify
ngxson Mar 1, 2025
c4e9231
fix smolVLM conversion
ngxson Mar 1, 2025
21aa2f5
phi-4-mm TEXT-ONLY for now
ngxson Mar 1, 2025
0ead9c4
Revert "fix smolVLM conversion"
ngxson Mar 2, 2025
45bc188
a bit cleaner for llava conversion
ngxson Mar 2, 2025
5283a15
Revert "phi-4-mm TEXT-ONLY for now"
ngxson Mar 2, 2025
424807e
Merge branch 'master' into xsn/vision_2
ngxson Mar 16, 2025
cee80d4
fix merge problem
ngxson Mar 16, 2025
cdff8c5
fix merge (2)
ngxson Mar 16, 2025
2 changes: 1 addition & 1 deletion common/arg.cpp
@@ -1403,7 +1403,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params, const std::string & value) {
params.image.emplace_back(value);
}
).set_examples({LLAMA_EXAMPLE_LLAVA}));
).set_examples({LLAMA_EXAMPLE_LLAVA, LLAMA_EXAMPLE_VISION}));
if (llama_supports_rpc()) {
add_opt(common_arg(
{"--rpc"}, "SERVERS",
1 change: 1 addition & 0 deletions common/common.h
@@ -79,6 +79,7 @@ enum llama_example {
LLAMA_EXAMPLE_LOOKUP,
LLAMA_EXAMPLE_PARALLEL,
LLAMA_EXAMPLE_TTS,
LLAMA_EXAMPLE_VISION,

LLAMA_EXAMPLE_COUNT,
};
182 changes: 169 additions & 13 deletions convert_hf_to_gguf.py
@@ -17,6 +17,7 @@
from typing import TYPE_CHECKING, Any, Callable, ContextManager, Iterable, Iterator, Literal, Sequence, TypeVar, cast
from itertools import chain

from transformers import AutoConfig
import math
import numpy as np
import torch
@@ -66,6 +67,13 @@ class Model:
metadata_override: Path | None
dir_model_card: Path

# for vision model
vision_arch: gguf.MODEL_ARCH | None = None
preprocessor_config: dict[str, Any] | None = None
vparams: dict[str, Any] | None = None
v_tensor_map: gguf.TensorNameMap | None = None
v_tensor_names: set[str] | None

# subclasses should define this!
model_arch: gguf.MODEL_ARCH

@@ -126,6 +134,16 @@ def find_hparam(self, keys: Iterable[str], optional: bool = False) -> Any:
return None
raise KeyError(f"could not find any of: {keys}")

def find_vparams(self, keys: Iterable[str], optional: bool = False) -> Any:
if self.vparams is None:
raise ValueError("vision model parameters not set")
key = next((k for k in keys if k in self.vparams), None)
if key is not None:
return self.vparams[key]
if optional:
return None
raise KeyError(f"(vision) could not find any of: {keys}")

def set_vocab(self):
self._set_vocab_gpt2()

@@ -210,9 +228,13 @@ def match_model_tensor_name(self, name: str, key: gguf.MODEL_TENSOR, bid: int |

def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias")) -> str:
new_name = self.tensor_map.get_name(key=name, try_suffixes=try_suffixes)
if new_name is None:
new_name_vision = self.v_tensor_map.get_name(key=name, try_suffixes=try_suffixes) if self.v_tensor_map is not None else None
if new_name is not None:
return new_name
elif new_name_vision is not None:
return new_name_vision
else:
raise ValueError(f"Can not map tensor {name!r}")
return new_name

def set_gguf_parameters(self):
self.gguf_writer.add_block_count(self.block_count)
@@ -257,6 +279,20 @@ def set_gguf_parameters(self):
self.gguf_writer.add_key_length(head_dim)
self.gguf_writer.add_value_length(head_dim)

# Vision model parameters
if self.vparams is not None and self.preprocessor_config is not None and self.vision_arch is not None:
self.gguf_writer.add_vision_type("clip-vit")
self.gguf_writer.add_vision_image_size(self.vparams["image_size"])
self.gguf_writer.add_vision_patch_size(self.vparams["patch_size"])
self.gguf_writer.add_vision_clip_architecture(gguf.MODEL_ARCH_NAMES[self.vision_arch])
self.gguf_writer.add_vision_clip_block_count(self.vparams["num_hidden_layers"])
self.gguf_writer.add_vision_clip_embedding_length(self.vparams["hidden_size"])
self.gguf_writer.add_vision_clip_feed_forward_length(self.vparams["intermediate_size"])
self.gguf_writer.add_vision_clip_head_count(self.vparams["num_attention_heads"])
self.gguf_writer.add_vision_clip_image_mean(self.preprocessor_config["image_mean"])
self.gguf_writer.add_vision_clip_image_std(self.preprocessor_config["image_std"])
self.gguf_writer.add_vision_clip_select_layer(self.find_hparam(["vision_feature_layer", "mm_vision_select_layer"]))

self.gguf_writer.add_file_type(self.ftype)
logger.info(f"gguf: file type = {self.ftype}")

@@ -466,7 +502,24 @@ def get_model_part_names(dir_model: Path, prefix: str, suffix: str) -> list[str]
@staticmethod
def load_hparams(dir_model: Path):
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
return json.load(f)
hparams = json.load(f)
if "text_config" in hparams:
text_config = hparams["text_config"]
# for example, llava-1.5-7b-hf is missing the language model config, so we need to retrieve it via the model ID
if "_name_or_path" in text_config:
text_config = AutoConfig.from_pretrained(text_config["_name_or_path"]).to_dict()
hparams = {**text_config, **hparams}
return hparams
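Note on the merge order above: because `hparams` is spread last in `{**text_config, **hparams}`, top-level keys of the wrapper config take precedence over anything pulled in from `text_config`. A minimal sketch of that precedence (illustrative values only, not taken from a real checkpoint):

```python
text_config = {"hidden_size": 4096, "vocab_size": 32000}
hparams = {"model_type": "llava", "vocab_size": 32064, "text_config": text_config}

merged = {**text_config, **hparams}
# keys only present in text_config are added; duplicated keys keep the top-level value
assert merged["hidden_size"] == 4096
assert merged["vocab_size"] == 32064
```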

@staticmethod
def load_preprocessor_config(dir_model: Path):
# TODO: this varies vastly among models, need to handle more cases in the future
file_path = dir_model / "preprocessor_config.json"
if os.path.exists(file_path):
with open(file_path, "r", encoding="utf-8") as f:
return json.load(f)
else:
raise Exception(f"Preprocessor config not found at {file_path}")

@classmethod
def register(cls, *names: str) -> Callable[[AnyModel], AnyModel]:
@@ -519,7 +572,9 @@ def get_vocab_base(self) -> tuple[list[str], list[int], str]:
toktypes: list[int] = []

from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
# DEBIAN_FRONTEND=noninteractive means that the script is running in a non-interactive environment (i.e. CI), so we cannot answer Y/N when it asks for user input
is_cli_non_interactive = os.environ.get("DEBIAN_FRONTEND", "") == "noninteractive"
tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=is_cli_non_interactive)
vocab_size = self.hparams.get("vocab_size", len(tokenizer.vocab))
assert max(tokenizer.vocab.values()) < vocab_size

@@ -1557,10 +1612,33 @@ def prepare_tensors(self):
raise ValueError(f"Unprocessed norms: {norms}")


@Model.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM")
@Model.register("LLaMAForCausalLM", "LlamaForCausalLM", "MistralForCausalLM", "MixtralForCausalLM", "LlavaForConditionalGeneration", "MobileLlamaForCausalLM")
class LlamaModel(Model):
model_arch = gguf.MODEL_ARCH.LLAMA

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

model_type = self.hparams.get("model_type", None)
self.vision_arch = None

# only tested with https://huggingface.co/llava-hf/llava-1.5-7b-hf
if "vision_config" in self.hparams and model_type == "llava":
self.vparams = self.hparams["vision_config"]
self.preprocessor_config = self.load_preprocessor_config(self.dir_model)
self.vision_arch = gguf.MODEL_ARCH.VISION_LLAVA

# only tested with https://huggingface.co/mtgv/MobileVLM_V2-1.7B
if "mm_vision_tower" in self.hparams and model_type == "mobilevlm":
from transformers import AutoImageProcessor
vision_model_id = self.hparams["mm_vision_tower"]
self.vparams = AutoConfig.from_pretrained(vision_model_id).to_dict()["vision_config"]
self.preprocessor_config = AutoImageProcessor.from_pretrained(vision_model_id).to_dict()
self.vision_arch = gguf.MODEL_ARCH.VISION_MOBILEVLM

if self.vparams is not None and self.vision_arch is not None:
self.v_tensor_map = gguf.get_tensor_name_map(self.vision_arch, self.vparams["num_hidden_layers"])

def set_vocab(self):
try:
self._set_vocab_sentencepiece()
@@ -1610,6 +1688,18 @@ def set_gguf_parameters(self):
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR)
self.gguf_writer.add_rope_scaling_factor(self.hparams["rope_scaling"]["factor"])

# For vision model
if self.vparams is not None:
self.gguf_writer.add_vision_clip_patch_merge_type(gguf.CLIPPatchMergeType.FLAT)
# TODO: should not hardcode these, but they are currently missing from config.json
if self.vision_arch == gguf.MODEL_ARCH.VISION_LLAVA:
self.gguf_writer.add_vision_clip_projector_type(gguf.constants.CLIPProjectorType.MLP)
if self.vision_arch == gguf.MODEL_ARCH.VISION_MOBILEVLM:
self.gguf_writer.add_vision_clip_projector_type(gguf.constants.CLIPProjectorType.LDPV2)
self.gguf_writer.add_vision_clip_layer_norm_epsilon(1e-05)
max_pos_embd = (self.vparams["image_size"] // self.vparams["patch_size"])**2 + 1
self.gguf_writer.add_vision_clip_max_position_embeddings(max_pos_embd)
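The `+ 1` accounts for the CLS token that CLIP-style ViTs prepend to the patch sequence. A quick worked example, assuming the CLIP ViT-L/14 tower at 336px that llava-1.5-7b-hf ships with (values taken from that checkpoint's vision_config):

```python
image_size, patch_size = 336, 14               # CLIP ViT-L/14 @ 336px
n_patches = (image_size // patch_size) ** 2    # 24 * 24 = 576 patches
max_pos_embd = n_patches + 1                   # + CLS token -> 577 position embeddings
```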

@staticmethod
def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
if n_head_kv is not None and n_head != n_head_kv:
@@ -1624,6 +1714,14 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
n_head = self.hparams["num_attention_heads"]
n_kv_head = self.hparams.get("num_key_value_heads")

# For vision model
if name.startswith("language_model"):
name = name.replace("language_model.", "")
else:
name = name.replace("model.vision_tower.", "")
if "post_layernorm" in name:
return [] # skip post_layernorm
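Dropping `post_layernorm` should be safe for LLaVA-style models: the multimodal projector consumes hidden states from an intermediate layer (`vision_feature_layer`, typically -2), so the vision tower's final layer norm is never applied and its weights would only be dead tensors in the GGUF.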

if name.endswith(("q_proj.weight", "q_proj.bias")):
data_torch = LlamaModel.permute(data_torch, n_head, n_head)
if name.endswith(("k_proj.weight", "k_proj.bias")):
@@ -2039,26 +2137,69 @@ def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims:
return n_dims > 1


@Model.register("MiniCPMForCausalLM")
@Model.register("MiniCPMForCausalLM", "MiniCPMV")
class MiniCPMModel(Model):
model_arch = gguf.MODEL_ARCH.MINICPM
proj_type: gguf.constants.CLIPProjectorType | None

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

model_type = self.hparams.get("model_type", None)

# only tested with https://huggingface.co/openbmb/MiniCPM-V-2_6
if "vision_config" in self.hparams and model_type == "minicpmv":
self.vparams = self.hparams["vision_config"]
self.preprocessor_config = self.load_preprocessor_config(self.dir_model)
self.vision_arch = gguf.MODEL_ARCH.VISION_MINICPMV
version = str(self.hparams.get("version", "unknown"))
if version == "2.5":
self.proj_type = gguf.constants.CLIPProjectorType.MINICPMV_2_5
elif version == "2.6":
self.proj_type = gguf.constants.CLIPProjectorType.MINICPMV_2_6
else:
raise ValueError(f"Unsupported MiniCPM-V version: {version}")

if self.vparams is not None and self.vision_arch is not None and self.preprocessor_config is not None:
self.preprocessor_config["image_mean"] = [0.5, 0.5, 0.5]
self.preprocessor_config["image_std"] = [0.5, 0.5, 0.5]
self.hparams["vision_feature_layer"] = 0
self.v_tensor_map = gguf.get_tensor_name_map(self.vision_arch, self.vparams["num_hidden_layers"])

def set_gguf_parameters(self):
super().set_gguf_parameters()
embedding_scale = float(self.hparams["scale_emb"])
# scale_emb
embedding_scale = float(self.hparams.get("scale_emb", 1.0))
self.gguf_writer.add_embedding_scale(embedding_scale)
logger.info(f"gguf: (minicpm) embedding_scale = {embedding_scale}")
residual_scale = self.hparams["scale_depth"] / self.hparams["num_hidden_layers"] ** 0.5
# scale_depth
if "scale_depth" in self.hparams:
residual_scale = self.hparams["scale_depth"] / self.hparams["num_hidden_layers"] ** 0.5
else:
residual_scale = 1.0
self.gguf_writer.add_residual_scale(residual_scale)
logger.info(f"gguf: (minicpm) residual_scale = {residual_scale}")
logit_scale = self.hparams["hidden_size"] / self.hparams["dim_model_base"]
# logit_scale
if "dim_model_base" in self.hparams:
logit_scale = self.hparams["hidden_size"] / self.hparams["dim_model_base"]
else:
logit_scale = 1.0
self.gguf_writer.add_logit_scale(logit_scale)
logger.info(f"gguf: (minicpm) logit_scale = {logit_scale}")
if self.hparams.get("rope_scaling") is not None:
if self.hparams["rope_scaling"].get("type") == "longrope":
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LONGROPE)
logger.info(f"gguf: (minicpm) rope_scaling_type = {gguf.RopeScalingType.LONGROPE}")

# For vision model
if self.vparams is not None and self.proj_type is not None:
self.gguf_writer.add_vision_clip_patch_merge_type(gguf.CLIPPatchMergeType.FLAT)
self.gguf_writer.add_vision_clip_projector_type(self.proj_type)
self.gguf_writer.add_vision_clip_layer_norm_epsilon(1e-06)
max_pos_embd = (self.vparams["image_size"] // self.vparams["patch_size"])**2
self.gguf_writer.add_vision_clip_max_position_embeddings(max_pos_embd)
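Note the missing `+ 1` compared to the LLaVA case above: MiniCPM-V's SigLIP tower does not prepend a CLS token, so the number of position embeddings equals the raw patch count `(image_size // patch_size) ** 2`.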


def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
rope_dims = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]

@@ -2077,18 +2218,33 @@ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT), torch.tensor(short_factors, dtype=torch.float32))

def set_vocab(self):
self._set_vocab_sentencepiece()
if self.vision_arch == gguf.MODEL_ARCH.VISION_MINICPMV:
# not documented anywhere; I only found this thanks to https://huggingface.co/openbmb/MiniCPM-V-2_6-gguf
self._set_vocab_gpt2()
else:
self._set_vocab_sentencepiece()

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
del bid # unused

# For vision model
if name.startswith("llm."):
name = name.replace("llm.", "")
# attention: someone messed up and used an underscore instead of a dot
if name.endswith("in_proj_weight"):
name = name.replace("_weight", ".weight")
if name.endswith("in_proj_bias"):
name = name.replace("_bias", ".bias")
if "post_layernorm" in name:
return [] # skip post_layernorm

n_head = self.hparams["num_attention_heads"]
n_kv_head = self.hparams.get("num_key_value_heads")

# HF models permute some of the tensors, so we need to undo that
if name.endswith(("q_proj.weight")):
if not name.startswith("vpm") and name.endswith(("q_proj.weight")):
data_torch = LlamaModel.permute(data_torch, n_head, n_head)
if name.endswith(("k_proj.weight")):
if not name.startswith("vpm") and name.endswith(("k_proj.weight")):
data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)

return [(self.map_tensor_name(name), data_torch)]
@@ -4974,7 +5130,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None):

def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Convert a huggingface model to a GGML compatible file")
description="Convert a huggingface model to a GGML compatible file\n\nNote: When converting vision models, this script may need an internet connection to download configuration files from Hugging Face.")
parser.add_argument(
"--vocab-only", action="store_true",
help="extract only the vocab",
1 change: 1 addition & 0 deletions examples/CMakeLists.txt
@@ -53,6 +53,7 @@ else()
add_subdirectory(tokenize)
add_subdirectory(tts)
add_subdirectory(gen-docs)
add_subdirectory(vision)
if (NOT GGML_BACKEND_DL)
# these examples use the backends directly and cannot be built with dynamic loading
add_subdirectory(convert-llama2c-to-ggml)
1 change: 1 addition & 0 deletions examples/server/server.cpp
@@ -2949,6 +2949,7 @@ struct server_context {
batch.n_seq_id + i,
batch.seq_id + i,
batch.logits + i,
nullptr,
};

const int ret = llama_decode(ctx, batch_view);
5 changes: 5 additions & 0 deletions examples/vision/CMakeLists.txt
@@ -0,0 +1,5 @@
set(TARGET llama-vision)
add_executable(${TARGET} vision.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_17)
3 changes: 3 additions & 0 deletions examples/vision/README.md
@@ -0,0 +1,3 @@
# llama.cpp/example/vision

Minimal demo for the vision API.
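Usage is not documented yet; a plausible invocation, given that `-m`, `-p` and `--image` are registered for `LLAMA_EXAMPLE_VISION` in `common/arg.cpp`, might look like `llama-vision -m model.gguf --image input.jpg -p "describe this image"` (flags to be confirmed once the example stabilizes).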