
Commit 9e02f27

Merge branch 'master' into master
2 parents: d7e2a73 + 3479efd


52 files changed: +2399 / -419 lines

README.md

Lines changed: 1 addition & 0 deletions
@@ -84,6 +84,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
 - [X] [Mistral 7B](https://huggingface.co/mistralai/Mistral-7B-v0.1)
 - [x] [Mixtral MoE](https://huggingface.co/models?search=mistral-ai/Mixtral)
 - [x] [DBRX](https://huggingface.co/databricks/dbrx-instruct)
+- [x] [Jamba](https://huggingface.co/ai21labs)
 - [X] [Falcon](https://huggingface.co/models?search=tiiuae/falcon)
 - [X] [Chinese LLaMA / Alpaca](https://github.com/ymcui/Chinese-LLaMA-Alpaca) and [Chinese LLaMA-2 / Alpaca-2](https://github.com/ymcui/Chinese-LLaMA-Alpaca-2)
 - [X] [Vigogne (French)](https://github.com/bofenghuang/vigogne)

common/chat.cpp

Lines changed: 198 additions & 0 deletions
@@ -9,8 +9,11 @@
 #include <minja/chat-template.hpp>
 #include <minja/minja.hpp>
 
+#include <algorithm>
 #include <cstdio>
+#include <cctype>
 #include <exception>
+#include <functional>
 #include <iostream>
 #include <optional>
 #include <stdexcept>
@@ -643,6 +646,7 @@ const char * common_chat_format_name(common_chat_format format) {
         case COMMON_CHAT_FORMAT_SEED_OSS: return "Seed-OSS";
         case COMMON_CHAT_FORMAT_NEMOTRON_V2: return "Nemotron V2";
         case COMMON_CHAT_FORMAT_APERTUS: return "Apertus";
+        case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS: return "LFM2 with JSON tools";
         case COMMON_CHAT_FORMAT_GLM_4_5: return "GLM 4.5";
         default:
             throw std::runtime_error("Unknown chat format");
@@ -990,6 +994,126 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat
     return data;
 }
 
+
+// Case-insensitive find
+static size_t ifind_string(const std::string & haystack, const std::string & needle, size_t pos = 0) {
+    auto it = std::search(
+        haystack.begin() + pos, haystack.end(),
+        needle.begin(), needle.end(),
+        [](char a, char b) { return std::tolower(a) == std::tolower(b); }
+    );
+    return (it == haystack.end()) ? std::string::npos : std::distance(haystack.begin(), it);
+}
+
+static common_chat_params common_chat_params_init_lfm2(const common_chat_template & tmpl, const struct templates_params & inputs) {
+    common_chat_params data;
+    const auto is_json_schema_provided = !inputs.json_schema.is_null();
+    const auto is_grammar_provided = !inputs.grammar.empty();
+    const auto are_tools_provided = inputs.tools.is_array() && !inputs.tools.empty();
+
+    // the logic requires potentially modifying the messages
+    auto tweaked_messages = inputs.messages;
+
+    auto replace_json_schema_marker = [](json & messages) -> bool {
+        static std::string marker1 = "force json schema.\n";
+        static std::string marker2 = "force json schema.";
+
+        if (messages.empty() || messages.at(0).at("role") != "system") {
+            return false;
+        }
+
+        std::string content = messages.at(0).at("content");
+
+        for (const auto & marker : {marker1, marker2}) {
+            const auto pos = ifind_string(content, marker);
+            if (pos != std::string::npos) {
+                content.replace(pos, marker.length(), "");
+                // inject modified content back into the messages
+                messages.at(0).at("content") = content;
+                return true;
+            }
+        }
+
+        return false;
+    };
+
+    // Lfm2 model does not natively work with json, but can generally understand the tools structure
+    //
+    // Example of the pytorch dialog structure:
+    // <|startoftext|><|im_start|>system
+    // List of tools: <|tool_list_start|>[{"name": "get_candidate_status", "description": "Retrieves the current status of a candidate in the recruitment process", "parameters": {"type": "object", "properties": {"candidate_id": {"type": "string", "description": "Unique identifier for the candidate"}}, "required": ["candidate_id"]}}]<|tool_list_end|><|im_end|>
+    // <|im_start|>user
+    // What is the current status of candidate ID 12345?<|im_end|>
+    // <|im_start|>assistant
+    // <|tool_call_start|>[get_candidate_status(candidate_id="12345")]<|tool_call_end|>Checking the current status of candidate ID 12345.<|im_end|>
+    // <|im_start|>tool
+    // <|tool_response_start|>{"candidate_id": "12345", "status": "Interview Scheduled", "position": "Clinical Research Associate", "date": "2023-11-20"}<|tool_response_end|><|im_end|>
+    // <|im_start|>assistant
+    // The candidate with ID 12345 is currently in the "Interview Scheduled" stage for the position of Clinical Research Associate, with an interview date set for 2023-11-20.<|im_end|>
+    //
+    // For the llama server compatibility with json tools semantic,
+    // the client can add "Follow json schema." line into the system message prompt to force the json output.
+    //
+    if (are_tools_provided && (is_json_schema_provided || is_grammar_provided)) {
+        // server/utils.hpp prohibits that branch for the custom grammar anyways
+        throw std::runtime_error("Tools call must not use \"json_schema\" or \"grammar\", use non-tool invocation if you want to use custom grammar");
+    } else if (are_tools_provided && replace_json_schema_marker(tweaked_messages)) {
+        LOG_INF("%s: Using tools to build a grammar\n", __func__);
+
+        data.grammar = build_grammar([&](const common_grammar_builder & builder) {
+            auto schemas = json::array();
+            foreach_function(inputs.tools, [&](const json & tool) {
+                const auto & function = tool.at("function");
+                schemas.push_back({
+                    {"type", "object"},
+                    {"properties", {
+                        {"name", {
+                            {"type", "string"},
+                            {"const", function.at("name")},
+                        }},
+                        {"arguments", function.at("parameters")},
+                    }},
+                    {"required", json::array({"name", "arguments", "id"})},
+                });
+            });
+            auto schema = json {
+                {"type", "array"},
+                {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
+                {"minItems", 1},
+            };
+            if (!inputs.parallel_tool_calls) {
+                schema["maxItems"] = 1;
+            }
+
+            builder.add_rule("root", "\"<|tool_call_start|>\"" + builder.add_schema("tool_calls", schema) + "\"<|tool_call_end|>\"");
+        });
+        // model has no concept of tool selection mode choice,
+        // if the system prompt rendered correctly it will produce a tool call
+        // the grammar goes inside the tool call body
+        data.grammar_lazy = true;
+        data.grammar_triggers = {{COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, "\\s*<\\|tool_call_start\\|>\\s*\\["}};
+        data.preserved_tokens = {"<|tool_call_start|>", "<|tool_call_end|>"};
+        data.format = COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS;
+    } else if (are_tools_provided && (!is_json_schema_provided && !is_grammar_provided)) {
+        LOG_INF("%s: Using tools without json schema or grammar\n", __func__);
+        // output those tokens
+        data.preserved_tokens = {"<|tool_call_start|>", "<|tool_call_end|>"};
+    } else if (is_json_schema_provided) {
+        LOG_INF("%s: Using provided json schema to build a grammar\n", __func__);
+        data.grammar = json_schema_to_grammar(inputs.json_schema);
+    } else if (is_grammar_provided) {
+        LOG_INF("%s: Using provided grammar\n", __func__);
+        data.grammar = inputs.grammar;
+    } else {
+        LOG_INF("%s: Using content relying on the template\n", __func__);
+    }
+
+    data.prompt = apply(tmpl, inputs, /* messages_override= */ tweaked_messages);
+    LOG_DBG("%s: Prompt: %s\n", __func__, data.prompt.c_str());
+
+    return data;
+}
+
 static common_chat_params common_chat_params_init_magistral(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
     data.prompt = apply(tmpl, inputs);
@@ -2891,6 +3015,71 @@ static void common_chat_parse_apertus(common_chat_msg_parser & builder) {
     builder.add_content(builder.consume_rest());
 }
 
+
+static void common_chat_parse_lfm2(common_chat_msg_parser & builder) {
+    if (!builder.syntax().parse_tool_calls) {
+        builder.add_content(builder.consume_rest());
+        return;
+    }
+
+    // LFM2 format: <|tool_call_start|>[{"name": "get_current_time", "arguments": {"location": "Paris"}}]<|tool_call_end|>
+    static const common_regex tool_call_start_regex(regex_escape("<|tool_call_start|>"));
+    static const common_regex tool_call_end_regex(regex_escape("<|tool_call_end|>"));
+
+    // Loop through all tool calls
+    while (auto res = builder.try_find_regex(tool_call_start_regex, std::string::npos, /* add_prelude_to_content= */ true)) {
+        builder.move_to(res->groups[0].end);
+
+        // Parse JSON array format: [{"name": "...", "arguments": {...}}]
+        auto tool_calls_data = builder.consume_json();
+
+        // Consume end marker
+        builder.consume_spaces();
+        if (!builder.try_consume_regex(tool_call_end_regex)) {
+            throw common_chat_msg_partial_exception("Expected <|tool_call_end|>");
+        }
+
+        // Process each tool call in the array
+        if (tool_calls_data.json.is_array()) {
+            for (const auto & tool_call : tool_calls_data.json) {
+                if (!tool_call.is_object()) {
+                    throw common_chat_msg_partial_exception("Tool call must be an object");
+                }
+
+                if (!tool_call.contains("name")) {
+                    throw common_chat_msg_partial_exception("Tool call missing 'name' field");
+                }
+
+                std::string function_name = tool_call.at("name");
+                std::string arguments = "{}";
+
+                if (tool_call.contains("arguments")) {
+                    if (tool_call.at("arguments").is_object()) {
+                        arguments = tool_call.at("arguments").dump();
+                    } else if (tool_call.at("arguments").is_string()) {
+                        arguments = tool_call.at("arguments");
+                    }
+                }
+
+                if (!builder.add_tool_call(function_name, "", arguments)) {
+                    throw common_chat_msg_partial_exception("Incomplete tool call");
+                }
+            }
+        } else {
+            throw common_chat_msg_partial_exception("Expected JSON array for tool calls");
+        }
+
+        // Consume any trailing whitespace after this tool call
+        builder.consume_spaces();
+    }
+
+    // Consume any remaining content after all tool calls
+    auto remaining = builder.consume_rest();
+    if (!string_strip(remaining).empty()) {
+        builder.add_content(remaining);
+    }
+}
+
 static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) {
     // Parse thinking tags first - this handles the main reasoning content
     builder.try_parse_reasoning("<seed:think>", "</seed:think>");
@@ -3145,6 +3334,12 @@ static common_chat_params common_chat_templates_apply_jinja(
         return common_chat_params_init_apertus(tmpl, params);
     }
 
+    // LFM2 (w/ tools)
+    if (src.find("List of tools: <|tool_list_start|>[") != std::string::npos &&
+        src.find("]<|tool_list_end|>") != std::string::npos) {
+        return common_chat_params_init_lfm2(tmpl, params);
+    }
+
     // Use generic handler when mixing tools + JSON schema.
     // TODO: support that mix in handlers below.
     if ((params.tools.is_array() && params.json_schema.is_object())) {
@@ -3323,6 +3518,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
         case COMMON_CHAT_FORMAT_APERTUS:
             common_chat_parse_apertus(builder);
             break;
+        case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS:
+            common_chat_parse_lfm2(builder);
+            break;
         case COMMON_CHAT_FORMAT_GLM_4_5:
             common_chat_parse_glm_4_5(builder);
             break;
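
The grammar-constrained branch above only activates when tools are supplied and the system message contains the "force json schema." marker that replace_json_schema_marker() strips (matched case-insensitively). Below is a minimal client-side sketch of that flow, assuming a local llama-server instance serving an LFM2 GGUF; the host, port, and tool definition are illustrative placeholders, and the request shape follows the server's OpenAI-compatible chat API.

```python
# Minimal sketch: drive the LFM2 JSON-tools path via llama-server's
# OpenAI-compatible /v1/chat/completions endpoint.
# Assumptions: a llama-server instance with an LFM2 model listens on localhost:8080;
# the tool definition mirrors the example in the code comment above.
import json
import requests

tools = [{
    "type": "function",
    "function": {
        "name": "get_candidate_status",
        "description": "Retrieves the current status of a candidate in the recruitment process",
        "parameters": {
            "type": "object",
            "properties": {
                "candidate_id": {"type": "string", "description": "Unique identifier for the candidate"},
            },
            "required": ["candidate_id"],
        },
    },
}]

payload = {
    "messages": [
        # The marker below is what replace_json_schema_marker() looks for; it is removed
        # from the prompt and switches the handler to the grammar-constrained branch.
        {"role": "system", "content": "You are a recruiting assistant. force json schema."},
        {"role": "user", "content": "What is the current status of candidate ID 12345?"},
    ],
    "tools": tools,
}

resp = requests.post("http://localhost:8080/v1/chat/completions", json=payload, timeout=60)
resp.raise_for_status()
# With the grammar active, tool calls come back parsed by common_chat_parse_lfm2().
print(json.dumps(resp.json()["choices"][0]["message"], indent=2))
```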

common/chat.h

Lines changed: 1 addition & 0 deletions
@@ -116,6 +116,7 @@ enum common_chat_format {
     COMMON_CHAT_FORMAT_SEED_OSS,
     COMMON_CHAT_FORMAT_NEMOTRON_V2,
     COMMON_CHAT_FORMAT_APERTUS,
+    COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS,
     COMMON_CHAT_FORMAT_GLM_4_5,
 
     COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats

convert_hf_to_gguf.py

Lines changed: 48 additions & 8 deletions
@@ -742,6 +742,12 @@ def set_gguf_parameters(self):
         if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None:
             self.gguf_writer.add_expert_used_count(n_experts_used)
             logger.info(f"gguf: experts used count = {n_experts_used}")
+        if (n_expert_groups := self.hparams.get("n_group")) is not None:
+            self.gguf_writer.add_expert_group_count(n_expert_groups)
+            logger.info(f"gguf: expert groups count = {n_expert_groups}")
+        if (n_group_used := self.hparams.get("topk_group")) is not None:
+            self.gguf_writer.add_expert_group_used_count(n_group_used)
+            logger.info(f"gguf: expert groups used count = {n_group_used}")
 
         if (head_dim := self.hparams.get("head_dim")) is not None:
             self.gguf_writer.add_key_length(head_dim)
@@ -1497,6 +1503,17 @@ def get_audio_config(self) -> dict[str, Any] | None:
     def set_type(self):
         self.gguf_writer.add_type(gguf.GGUFType.MMPROJ)
 
+    def prepare_metadata(self, vocab_only: bool):
+        super().prepare_metadata(vocab_only=vocab_only)
+
+        output_type: str = self.ftype.name.partition("_")[2]
+
+        if self.fname_out.is_dir():
+            fname_default: str = gguf.naming_convention(self.metadata.name, self.metadata.basename, self.metadata.finetune, self.metadata.version, size_label=None, output_type=output_type, model_type=None)
+            self.fname_out = self.fname_out / f"mmproj-{fname_default}.gguf"
+        else:
+            self.fname_out = self.fname_out.parent / gguf.fill_templated_filename(self.fname_out.name, output_type)
+
     def set_gguf_parameters(self):
         self.gguf_writer.add_file_type(self.ftype)
 
@@ -2443,18 +2460,21 @@ def set_gguf_parameters(self):
 )
 class LlavaVisionModel(MmprojModel):
     img_break_tok_id = -1
+    use_break_tok = True
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         if self.hparams.get("model_type") == "pixtral":
             # layer_norm_eps is not in config.json, it is hard-coded in modeling_pixtral.py
             self.hparams["layer_norm_eps"] = self.hparams.get("layer_norm_eps", 1e-5)
-            self.img_break_tok_id = self.get_token_id("[IMG_BREAK]")
+            if self.use_break_tok:
+                self.img_break_tok_id = self.get_token_id("[IMG_BREAK]")
         elif self.is_mistral_format:
             # hparams is already vision config here so norm_eps is only defined in global_config.
             self.hparams["norm_eps"] = self.global_config.get("norm_eps", None)
             assert self.hparams["norm_eps"] is not None, "norm_eps not found in params.json"
-            self.img_break_tok_id = self.find_vparam(["image_break_token_id"])
+            if self.use_break_tok:
+                self.img_break_tok_id = self.find_vparam(["image_break_token_id"])
         else:
             raise ValueError(f"Unsupported model type: {self.hparams['model_type']}")
         logger.info(f"Image break token id: {self.img_break_tok_id}")
@@ -3945,6 +3965,10 @@ def _get_cls_out_tensor(self, data_torch: Tensor) -> Tensor:
         return torch.stack([true_row, false_row], dim=0)
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        if "model.vision_" in name:
+            # skip multimodal tensors
+            return []
+
         if self.is_rerank:
             is_tied_head = self.is_tied_embeddings and "embed_tokens" in name
             is_real_head = not self.is_tied_embeddings and "lm_head" in name
@@ -8222,8 +8246,6 @@ def set_gguf_parameters(self):
         self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
         self.gguf_writer.add_expert_count(hparams["num_experts"])
         self.gguf_writer.add_expert_shared_count(hparams["num_shared_experts"])
-        self.gguf_writer.add_expert_group_count(hparams["n_group"])
-        self.gguf_writer.add_expert_group_used_count(hparams["topk_group"])
         self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
 
         if hparams["score_function"] == "sigmoid":
@@ -8943,6 +8965,13 @@ def set_vocab(self):
 class GptOssModel(TextModel):
     model_arch = gguf.MODEL_ARCH.GPT_OSS
 
+    # TODO: remove once MXFP4 is supported more generally
+    def dequant_model(self):
+        quant_config = self.hparams.get("quantization_config")
+        if quant_config is not None and quant_config.get("quant_method") == "mxfp4":
+            return
+        return super().dequant_model()
+
     def transform_nibble_layout(self, tensor):
         assert tensor.dtype == torch.uint8
         assert tensor.shape[-1] == 16
@@ -9413,6 +9442,21 @@ def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", "
         return super().map_tensor_name(name, try_suffixes)
 
 
+@ModelBase.register("LightOnOCRForConditionalGeneration")
+class LightOnOCRVisionModel(LlavaVisionModel):
+    is_mistral_format = False
+    use_break_tok = False
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.LIGHTONOCR)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
+        name = name.replace("model.vision_encoder.", "vision_tower.")
+        name = name.replace("model.vision_projection.", "multi_modal_projector.")
+        return super().modify_tensors(data_torch, name, bid)
+
+
 @ModelBase.register("KimiVLForConditionalGeneration")
 class KimiVLModel(MmprojModel):
     def __init__(self, *args, **kwargs):
@@ -9722,10 +9766,6 @@ def main() -> None:
 
     logger.info(f"Loading model: {dir_model.name}")
 
-    if args.mmproj:
-        if "mmproj" not in fname_out.name:
-            fname_out = ModelBase.add_prefix_to_filename(fname_out, "mmproj-")
-
     is_mistral_format = args.mistral_format
     if is_mistral_format and not _mistral_common_installed:
         raise ImportError(_mistral_import_error_msg)
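
As a side note on the LightOnOCRVisionModel.modify_tensors() hunk above: the class only remaps checkpoint tensor-name prefixes before deferring to the shared LLaVA handling. A standalone sketch of that remapping follows, using illustrative tensor names rather than ones taken from a real checkpoint.

```python
# Standalone sketch of the prefix remapping done in LightOnOCRVisionModel.modify_tensors();
# the example tensor names below are illustrative, not from a real checkpoint.
def remap_lightonocr(name: str) -> str:
    # vision encoder weights are renamed to the prefix the LLaVA/mmproj mapping expects
    name = name.replace("model.vision_encoder.", "vision_tower.")
    # same for the projector that bridges vision features into the language model
    name = name.replace("model.vision_projection.", "multi_modal_projector.")
    return name

for example in (
    "model.vision_encoder.transformer.layers.0.attention.q_proj.weight",
    "model.vision_projection.linear_1.weight",
    "model.layers.0.self_attn.q_proj.weight",  # non-vision tensors pass through unchanged
):
    print(example, "->", remap_lightonocr(example))
```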
