
Commit 46c3d25

Merge branch 'ggml-org:master' into mradermacher
2 parents 926d4dc + f5e96b3 commit 46c3d25

40 files changed: 28385 additions, 269 deletions

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+name: Update Operations Documentation
+
+on:
+  push:
+    paths:
+      - 'docs/ops/**'
+      - 'scripts/create_ops_docs.py'
+  pull_request:
+    paths:
+      - 'docs/ops/**'
+      - 'scripts/create_ops_docs.py'
+
+jobs:
+  update-ops-docs:
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v4
+
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: '3.x'
+
+      - name: Generate operations documentation to temporary file
+        run: |
+          mkdir -p /tmp/ops_check
+          ./scripts/create_ops_docs.py /tmp/ops_check/ops.md
+
+      - name: Check if docs/ops.md matches generated version
+        run: |
+          if ! diff -q docs/ops.md /tmp/ops_check/ops.md; then
+            echo "Operations documentation (docs/ops.md) is not up to date with the backend CSV files."
+            echo "To fix: run ./scripts/create_ops_docs.py and commit the updated docs/ops.md along with your changes"
+            echo "Differences found:"
+            diff docs/ops.md /tmp/ops_check/ops.md || true
+            exit 1
+          fi
+          echo "Operations documentation is up to date."

README.md

Lines changed: 4 additions & 5 deletions
@@ -6,9 +6,9 @@
 [![Release](https://img.shields.io/github/v/release/ggml-org/llama.cpp)](https://github.com/ggml-org/llama.cpp/releases)
 [![Server](https://github.com/ggml-org/llama.cpp/actions/workflows/server.yml/badge.svg)](https://github.com/ggml-org/llama.cpp/actions/workflows/server.yml)
 
-[Roadmap](https://github.com/users/ggerganov/projects/7) / [Manifesto](https://github.com/ggml-org/llama.cpp/discussions/205) / [ggml](https://github.com/ggml-org/ggml)
+[Manifesto](https://github.com/ggml-org/llama.cpp/discussions/205) / [ggml](https://github.com/ggml-org/ggml) / [ops](https://github.com/ggml-org/llama.cpp/blob/master/docs/ops.md)
 
-Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others) in pure C/C++
+LLM inference in C/C++
 
 ## Recent API changes
 
@@ -17,10 +17,9 @@ Inference of Meta's [LLaMA](https://arxiv.org/abs/2302.13971) model (and others)
 
 ## Hot topics
 
-- 🔥 Multimodal support arrived in `llama-server`: [#12898](https://github.com/ggml-org/llama.cpp/pull/12898) | [documentation](./docs/multimodal.md)
-- A new binary `llama-mtmd-cli` is introduced to replace `llava-cli`, `minicpmv-cli`, `gemma3-cli` ([#13012](https://github.com/ggml-org/llama.cpp/pull/13012)) and `qwen2vl-cli` ([#13141](https://github.com/ggml-org/llama.cpp/pull/13141)), `libllava` will be deprecated
+- Hot PRs: [All](https://github.com/ggml-org/llama.cpp/pulls?q=is%3Apr+label%3Ahot+) | [Open](https://github.com/ggml-org/llama.cpp/pulls?q=is%3Apr+label%3Ahot+is%3Aopen)
+- Multimodal support arrived in `llama-server`: [#12898](https://github.com/ggml-org/llama.cpp/pull/12898) | [documentation](./docs/multimodal.md)
 - VS Code extension for FIM completions: https://github.com/ggml-org/llama.vscode
-- Universal [tool call support](./docs/function-calling.md) in `llama-server` https://github.com/ggml-org/llama.cpp/pull/9639
 - Vim/Neovim plugin for FIM completions: https://github.com/ggml-org/llama.vim
 - Introducing GGUF-my-LoRA https://github.com/ggml-org/llama.cpp/discussions/10123
 - Hugging Face Inference Endpoints now support GGUF out of the box! https://github.com/ggml-org/llama.cpp/discussions/9669

common/CMakeLists.txt

Lines changed: 4 additions & 5 deletions
@@ -86,8 +86,7 @@ if (LLAMA_CURL)
     endif()
     target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_CURL)
     include_directories(${CURL_INCLUDE_DIRS})
-    find_library(CURL_LIBRARY curl REQUIRED)
-    set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} ${CURL_LIBRARY})
+    set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} ${CURL_LIBRARIES})
 endif ()
 
 if (LLAMA_LLGUIDANCE)
@@ -112,13 +111,13 @@ if (LLAMA_LLGUIDANCE)
 
     ExternalProject_Add(llguidance_ext
         GIT_REPOSITORY https://github.com/guidance-ai/llguidance
-        # v0.7.20 (+ fix to build on GCC 15):
-        GIT_TAG b5b8b64dba11c4e4ee6b1d1450d3a3ae279891e8
+        # v1.0.1:
+        GIT_TAG d795912fedc7d393de740177ea9ea761e7905774
         PREFIX ${CMAKE_BINARY_DIR}/llguidance
         SOURCE_DIR ${LLGUIDANCE_SRC}
         BUILD_IN_SOURCE TRUE
         CONFIGURE_COMMAND ""
-        BUILD_COMMAND cargo build --release
+        BUILD_COMMAND cargo build --release --package llguidance
         INSTALL_COMMAND ""
         BUILD_BYPRODUCTS ${LLGUIDANCE_PATH}/${LLGUIDANCE_LIB_NAME} ${LLGUIDANCE_PATH}/llguidance.h
         UPDATE_COMMAND ""

convert_hf_to_gguf.py

Lines changed: 197 additions & 19 deletions
@@ -300,6 +300,7 @@ def prepare_tensors(self):
                     gguf.MODEL_TENSOR.POS_EMBD,
                     gguf.MODEL_TENSOR.TOKEN_TYPES,
                     gguf.MODEL_TENSOR.SSM_CONV1D,
+                    gguf.MODEL_TENSOR.SHORTCONV_CONV,
                     gguf.MODEL_TENSOR.TIME_MIX_FIRST,
                     gguf.MODEL_TENSOR.TIME_MIX_W1,
                     gguf.MODEL_TENSOR.TIME_MIX_W2,
@@ -833,6 +834,12 @@ def get_vocab_base_pre(self, tokenizer) -> str:
         if chkhsh == "48f8e02c0359c0bbdd82f26909171fac1c18a457bb47573ed1fe3bbb2c1cfd4b":
            # ref: https://huggingface.co/tiiuae/Falcon-H1-34B-Base
            res = "falcon-h1"
+        if chkhsh == "f6791d196f87ce6b56a7d234be618e0d58f8cda3549416635b2bebcd22cd95c4":
+            # ref: https://huggingface.co/K-intelligence/Midm-2.0-Base-Instruct
+            res = "midm-2.0"
+        if chkhsh == "169bf0296a13c4d9b7672313f749eb36501d931022de052aad6e36f2bf34dd51":
+            # ref: https://huggingface.co/LiquidAI/LFM2-Tokenizer
+            res = "lfm2"
 
         if res is None:
             logger.warning("\n")
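
For context, these chkhsh values come from convert_hf_to_gguf_update.py (also touched in this commit): the tokenizer encodes a fixed probe text and the resulting token ids are hashed, so any change in pre-tokenization behaviour produces a new hash that needs an entry like the ones added above. A rough sketch of the idea, with an illustrative probe string (the real script uses a much longer fixed chktxt):

from hashlib import sha256
from transformers import AutoTokenizer  # assumes the HF tokenizer files are available locally

def pre_tokenizer_hash(model_dir: str, probe: str = "Hello World 3.14 \n\n example") -> str:
    # Encode a fixed probe string and hash the token ids; an unknown hash means
    # the model needs a new branch in get_vocab_base_pre().
    tok = AutoTokenizer.from_pretrained(model_dir)
    return sha256(str(tok.encode(probe)).encode()).hexdigest()
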
@@ -4890,6 +4897,9 @@ def __init__(self, dir_model: Path, *args, **kwargs):
         with open(dir_model / "config.json", "r", encoding="utf-8") as f:
             hparams = json.load(f)
         super().__init__(dir_model, *args, hparams=hparams, **kwargs)
+        self.d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
+        self.d_inner = self.find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or 2 * self.d_model
+        self.n_group = self.find_hparam(["n_groups"], optional=True) or 1
 
     def set_vocab(self):
         vocab_size = self.hparams["vocab_size"]
@@ -4912,32 +4922,29 @@ def set_vocab(self):
             self._set_vocab_builtin("gpt-neox", vocab_size)
 
     def set_gguf_parameters(self):
-        d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
-        d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
-        d_inner = self.find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or 2 * d_model
-        d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 128
-        head_dim = self.find_hparam(["mamba_d_head", "head_dim"], optional=True) or 64
-        n_group = self.find_hparam(["n_groups"], optional=True) or 1
+        d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
+        d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 128
+        head_dim = self.find_hparam(["mamba_d_head", "head_dim"], optional=True) or 64
 
         rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
 
         # Fail early for models which don't have a block expansion factor of 2
         # TODO: does this really matter?
         # skip the assertion for FalconH1 Model
         if self.model_arch != gguf.MODEL_ARCH.FALCON_H1:
-            assert d_inner == 2 * d_model
-            assert d_inner % head_dim == 0
+            assert self.d_inner == 2 * self.d_model
+            assert self.d_inner % head_dim == 0
 
         self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
-        self.gguf_writer.add_embedding_length(d_model)
+        self.gguf_writer.add_embedding_length(self.d_model)
         self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading
         self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading
         self.gguf_writer.add_block_count(self.block_count)
         self.gguf_writer.add_ssm_conv_kernel(d_conv)
-        self.gguf_writer.add_ssm_inner_size(d_inner)
+        self.gguf_writer.add_ssm_inner_size(self.d_inner)
         self.gguf_writer.add_ssm_state_size(d_state)
-        self.gguf_writer.add_ssm_time_step_rank(d_inner // head_dim)
-        self.gguf_writer.add_ssm_group_count(n_group)
+        self.gguf_writer.add_ssm_time_step_rank(self.d_inner // head_dim)
+        self.gguf_writer.add_ssm_group_count(self.n_group)
         self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps)
         self.gguf_writer.add_file_type(self.ftype)
 
@@ -4962,10 +4969,7 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
             # (D is also unsqueezed, but for more straightforward broadcast internally)
             data_torch = data_torch.reshape((*data_torch.shape, 1))
         elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_NORM, bid):
-            d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
-            d_inner = self.find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or 2 * d_model
-            n_group = self.hparams.get("n_groups", 1)
-            data_torch = data_torch.reshape((n_group, d_inner // n_group))
+            data_torch = data_torch.reshape((self.n_group, self.d_inner // self.n_group))
 
         if name.endswith(".A_log"):
             logger.debug("A_log --> A ==> " + new_name)
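
As a worked example of the SSM_NORM reshape above (the numbers are illustrative, not taken from a particular checkpoint): with d_model = 2048, the default d_inner = 2 * d_model = 4096 and n_group = 8, the flat 4096-element norm weight becomes an (8, 512) tensor, one row per group:

import torch

d_model, n_group = 2048, 8       # illustrative values
d_inner = 2 * d_model            # default block expansion factor of 2
ssm_norm = torch.ones(d_inner)   # stand-in for the flat SSM_NORM weight
assert ssm_norm.reshape((n_group, d_inner // n_group)).shape == (8, 512)
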
@@ -6452,18 +6456,148 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
                 (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_EXP, bid), up),
             ]
 
+        has_experts = bool(self.hparams.get('num_local_experts'))
+
         if name.endswith("shared_mlp.input_linear.weight"):
             ffn_dim = self.hparams["shared_intermediate_size"]
             assert data_torch.shape[-2] == 2 * ffn_dim, "Merged FFN tensor size must be 2 * shared_intermediate_size"
             gate, up = data_torch.split(ffn_dim, dim=-2)
+            if has_experts:
+                return [
+                    (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_SHEXP, bid), gate),
+                    (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_SHEXP, bid), up),
+                ]
+            return [
+                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE, bid), gate),
+                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP, bid), up),
+            ]
+
+        if not has_experts and name.endswith("shared_mlp.output_linear.weight"):
             return [
-                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_SHEXP, bid), gate),
-                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_SHEXP, bid), up),
+                (self.format_tensor_name(gguf.MODEL_TENSOR.FFN_DOWN, bid), data_torch)
             ]
 
         return super().modify_tensors(data_torch, name, bid)
 
 
+@ModelBase.register("GraniteMoeHybridForCausalLM", "BambaForCausalLM")
+class GraniteHybridModel(Mamba2Model, GraniteMoeModel):
+    """GraniteHybrid is a hybrid SSM + Attention model that uses Mamba2 SSM
+    layers and optionally uses MoE w/ a shared expert"""
+    model_arch = gguf.MODEL_ARCH.GRANITE_HYBRID
+    undo_permute = True
+
+    def __init__(self, *args, **kwargs):
+
+        # Hybrid mamba models use a prefix for the mamba-specific params.
+        # TODO: Extend this if the prefix(es) need to be configurable
+        self.hparam_prefixes = ["mamba"]
+
+        super().__init__(*args, **kwargs)
+
+        # Lists of which layers use ssm vs attention
+        self._attn_layers = self.get_attn_layers()
+        self._ssm_layers = [
+            i for i in range(self.block_count)
+            if i not in self._attn_layers
+        ]
+
+        # n_group and d_inner are used during reshape_tensors for mamba2
+        self.d_model = self.find_hparam(["hidden_size", "d_model"])
+        self.n_group = self.find_hparam(["n_groups"])
+        self.d_inner = self.find_hparam(["expand"]) * self.d_model
+
+    def get_attn_layers(self):
+        # Explicit list of layer type names
+        if layer_types := self.hparams.get("layer_types"):
+            return [
+                i for i, typ in enumerate(layer_types)
+                if typ == "attention"
+            ]
+
+        # Layer types indicated by index or period
+        attn_layers = self.hparams.get("attn_layer_indices", [])
+        if not attn_layers:
+            attn_period = self.hparams.get("attn_layer_period")
+            assert attn_period, "Didn't find attn_layer_indices or attn_layer_period"
+            attn_offset = self.hparams.get("attn_layer_offset")
+            assert attn_offset is not None, "No attention layer offset set with attn_layer_period"
+            attn_layers = [
+                i for i in range(self.block_count)
+                if i % attn_period == attn_offset
+            ]
+        return attn_layers
+
+    def find_hparam(self, keys: Iterable[str], *args, **kwargs) -> Any:
+        prefixed = []
+        for pfx in self.hparam_prefixes:
+            prefixed.extend(
+                "_".join([pfx, k])
+                for k in keys
+            )
+        keys = list(keys) + prefixed
+        return Mamba2Model.find_hparam(self, keys, *args, **kwargs)
+
+    def modify_tensors(
+        self, data_torch: Tensor, name: str, bid: int | None
+    ) -> Iterable[tuple[str, Tensor]]:
+        if (
+            name.endswith("block_sparse_moe.input_linear.weight")
+            or "shared_mlp" in name
+        ):
+            return GraniteMoeModel.modify_tensors(self, data_torch, name, bid)
+
+        # Determine whether this is a mamba layer or an attention layer
+        if bid in self._ssm_layers:
+            return Mamba2Model.modify_tensors(self, data_torch, name, bid)
+        elif bid in self._attn_layers:
+            return GraniteMoeModel.modify_tensors(self, data_torch, name, bid)
+        return [(self.map_tensor_name(name), data_torch)]
+
+    def set_gguf_parameters(self):
+        """This method merges params from both parents and some that are
+        specific to this model. The result is some duplication of how the params
+        get set. The following warnings are expected during conversion:
+
+        WARNING:Duplicated key name 'granitehybrid.attention.head_count_kv'
+        WARNING:Duplicated key name 'granitehybrid.context_length'
+        """
+        GraniteMoeModel.set_gguf_parameters(self)
+
+        ## Mamba mixer params ##
+        self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"]))
+        self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state"]))
+        self.gguf_writer.add_ssm_group_count(self.n_group)
+        self.gguf_writer.add_ssm_inner_size(self.d_inner)
+        # NOTE: The mamba_dt_rank is _not_ the right field for how this is used
+        # in llama.cpp
+        self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads"]))
+
+        ## Attention params ##
+        head_count_kv = self.find_hparam(["num_key_value_heads", "n_head_kv"])
+        head_count_kv_vec = [
+            head_count_kv if i in self._attn_layers else 0 for i in range(self.block_count)
+        ]
+        if rope_dim := self.hparams.get("attn_rotary_emb"):
+            self.gguf_writer.add_rope_dimension_count(rope_dim)
+        self.gguf_writer.add_head_count_kv(head_count_kv_vec)
+
+        ## If Bamba, use rope, otherwise don't
+        use_rope = "BambaForCausalLM" in self.hparams["architectures"]
+        self.gguf_writer.add_rope_scaling_finetuned(use_rope)
+        if not use_rope:
+            self.gguf_writer.add_context_length(2**20)
+
+        ## Validation ##
+        d_head = self.find_hparam(["d_head"], optional=True) or 64
+        assert self.hparams.get("hidden_act") in [None, "silu"], "Only SILU activation supported"
+        assert self.d_inner % d_head == 0, f"SSM inner size {self.d_inner} not a multiple of head dim {d_head}"
+
+    def set_vocab(self):
+        self.hparams["pad_vocab_size_multiple"] = 8
+        Mamba2Model.set_vocab(self)
+
+
 @ModelBase.register("BailingMoeForCausalLM")
 class BailingMoeModel(TextModel):
     model_arch = gguf.MODEL_ARCH.BAILINGMOE
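
To illustrate how get_attn_layers() above resolves the hybrid layer split when only a period/offset is given (hypothetical hparams, not from a real config): with 30 blocks, attn_layer_period = 6 and attn_layer_offset = 5, layers 5, 11, 17, 23 and 29 use attention and the rest use Mamba2 SSM, and head_count_kv_vec carries the KV head count only at those indices:

# hypothetical values
block_count, attn_period, attn_offset, head_count_kv = 30, 6, 5, 8

attn_layers = [i for i in range(block_count) if i % attn_period == attn_offset]
ssm_layers = [i for i in range(block_count) if i not in attn_layers]
head_count_kv_vec = [head_count_kv if i in attn_layers else 0 for i in range(block_count)]

assert attn_layers == [5, 11, 17, 23, 29]
assert head_count_kv_vec[5] == 8 and head_count_kv_vec[6] == 0
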
@@ -6687,7 +6821,7 @@ def __init__(self, *args, **kwargs):
         # Use Llama conversion for attention
         self._transformer_model_class = LlamaModel
 
-        # n_group and d_inner are used during reshape_tensors for mamaba2
+        # n_group and d_inner are used during reshape_tensors for mamba2
         self.n_group = self.find_hparam(["n_groups"])
         self.d_inner = self.find_hparam(["mamba_d_ssm"])
         self.d_head = self.find_hparam(["d_head"])
@@ -6943,6 +7077,50 @@ def set_vocab(self):
         chat_template = tokenizer.chat_template.replace("[:]", "")
         self.gguf_writer.add_chat_template(chat_template)
 
+
+@ModelBase.register("Lfm2ForCausalLM")
+@ModelBase.register("LFM2ForCausalLM")
+class LFM2Model(TextModel):
+    model_arch = gguf.MODEL_ARCH.LFM2
+
+    def _add_feed_forward_length(self):
+        ff_dim = self.hparams["block_ff_dim"]
+
+        auto_adjust_ff_dim = self.hparams["block_auto_adjust_ff_dim"]
+        ff_dim = self.hparams["block_ff_dim"]
+        ffn_dim_multiplier = self.hparams["block_ffn_dim_multiplier"]
+        multiple_of = self.hparams["block_multiple_of"]
+
+        if auto_adjust_ff_dim:
+            ff_dim = int(2 * ff_dim / 3)
+            # custom dim factor multiplier
+            if ffn_dim_multiplier is not None:
+                ff_dim = int(ffn_dim_multiplier * ff_dim)
+            ff_dim = multiple_of * ((ff_dim + multiple_of - 1) // multiple_of)
+
+        self.gguf_writer.add_feed_forward_length(ff_dim)
+
+    def set_gguf_parameters(self):
+        # set num_key_value_heads only for attention layers
+        self.hparams["num_key_value_heads"] = [
+            self.hparams["num_key_value_heads"] if layer_type == "full_attention" else 0
+            for layer_type in self.hparams["layer_types"]
+        ]
+
+        super().set_gguf_parameters()
+        self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
+        self.gguf_writer.add_shortconv_l_cache(self.hparams["conv_L_cache"])
+        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["norm_eps"])
+        self._add_feed_forward_length()
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # conv op requires 2d tensor
+        if 'conv.conv' in name:
+            data_torch = data_torch.squeeze(1)
+
+        return [(self.map_tensor_name(name), data_torch)]
+
+
 ###### CONVERSION LOGIC ######
 

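The LFM2 feed-forward sizing in _add_feed_forward_length() above follows the familiar Llama-style adjustment: take 2/3 of the configured width, apply the optional multiplier, then round up to a multiple of block_multiple_of. A worked example with illustrative hparam values (not taken from a real LFM2 config):

# illustrative hparam values
block_ff_dim, auto_adjust_ff_dim, ffn_dim_multiplier, multiple_of = 8192, True, 1.0, 256

ff_dim = block_ff_dim
if auto_adjust_ff_dim:
    ff_dim = int(2 * ff_dim / 3)                                        # 5461
    if ffn_dim_multiplier is not None:
        ff_dim = int(ffn_dim_multiplier * ff_dim)                       # 5461
    ff_dim = multiple_of * ((ff_dim + multiple_of - 1) // multiple_of)  # rounded up to 5632

assert ff_dim == 5632
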
convert_hf_to_gguf_update.py

Lines changed: 2 additions & 0 deletions
@@ -129,6 +129,8 @@ class TOKENIZER_TYPE(IntEnum):
     {"name": "pixtral", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mistral-community/pixtral-12b", },
     {"name": "seed-coder", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ByteDance-Seed/Seed-Coder-8B-Base", },
     {"name": "a.x-4.0", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/skt/A.X-4.0", },
+    {"name": "midm-2.0", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/K-intelligence/Midm-2.0-Base-Instruct", },
+    {"name": "lfm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LiquidAI/LFM2-Tokenizer"},
 ]
 
 # some models are known to be broken upstream, so we will skip them as exceptions
