Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CODEOWNERS
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
/ggml/src/ggml-impl.h @ggerganov @slaren
/ggml/src/ggml-metal/ @ggerganov
/ggml/src/ggml-opencl/ @lhez @max-krasnyansky
/ggml/src/ggml-hexagon/ @max-krasnyansky
/ggml/src/ggml-hexagon/ @max-krasnyansky @lhez
/ggml/src/ggml-opt.cpp @JohannesGaessler
/ggml/src/ggml-quants.* @ggerganov
/ggml/src/ggml-rpc/ @rgerganov
Expand Down
2 changes: 1 addition & 1 deletion common/arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3248,7 +3248,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_examples({LLAMA_EXAMPLE_EMBEDDING}));
add_opt(common_arg(
{"--embd-output-format"}, "FORMAT",
"empty = default, \"array\" = [[],[]...], \"json\" = openai style, \"json+\" = same \"json\" + cosine similarity matrix",
"empty = default, \"array\" = [[],[]...], \"json\" = openai style, \"json+\" = same \"json\" + cosine similarity matrix, \"raw\" = plain whitespace-delimited output (one embedding per line)",
[](common_params & params, const std::string & value) {
params.embd_out = value;
}
Expand Down
198 changes: 198 additions & 0 deletions common/chat.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,11 @@
#include <minja/chat-template.hpp>
#include <minja/minja.hpp>

#include <algorithm>
#include <cstdio>
#include <cctype>
#include <exception>
#include <functional>
#include <iostream>
#include <optional>
#include <stdexcept>
Expand Down Expand Up @@ -640,6 +643,7 @@ const char * common_chat_format_name(common_chat_format format) {
case COMMON_CHAT_FORMAT_SEED_OSS: return "Seed-OSS";
case COMMON_CHAT_FORMAT_NEMOTRON_V2: return "Nemotron V2";
case COMMON_CHAT_FORMAT_APERTUS: return "Apertus";
case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS: return "LFM2 with JSON tools";
default:
throw std::runtime_error("Unknown chat format");
}
Expand Down Expand Up @@ -986,6 +990,126 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat
return data;
}


// Case-insensitive find
static size_t ifind_string(const std::string & haystack, const std::string & needle, size_t pos = 0) {
auto it = std::search(
haystack.begin() + pos, haystack.end(),
needle.begin(), needle.end(),
[](char a, char b) { return std::tolower(a) == std::tolower(b); }
);
return (it == haystack.end()) ? std::string::npos : std::distance(haystack.begin(), it);
}

static common_chat_params common_chat_params_init_lfm2(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
const auto is_json_schema_provided = !inputs.json_schema.is_null();
const auto is_grammar_provided = !inputs.grammar.empty();
const auto are_tools_provided = inputs.tools.is_array() && !inputs.tools.empty();

// the logic requires potentially modifying the messages
auto tweaked_messages = inputs.messages;

auto replace_json_schema_marker = [](json & messages) -> bool {
static std::string marker1 = "force json schema.\n";
static std::string marker2 = "force json schema.";

if (messages.empty() || messages.at(0).at("role") != "system") {
return false;
}

std::string content = messages.at(0).at("content");

for (const auto & marker : {marker1, marker2}) {
const auto pos = ifind_string(content, marker);
if (pos != std::string::npos) {
content.replace(pos, marker.length(), "");
// inject modified content back into the messages
messages.at(0).at("content") = content;
return true;
}
}

return false;
};

// Lfm2 model does not natively work with json, but can generally understand the tools structure
//
// Example of the pytorch dialog structure:
// <|startoftext|><|im_start|>system
// List of tools: <|tool_list_start|>[{"name": "get_candidate_status", "description": "Retrieves the current status of a candidate in the recruitment process", "parameters": {"type": "object", "properties": {"candidate_id": {"type": "string", "description": "Unique identifier for the candidate"}}, "required": ["candidate_id"]}}]<|tool_list_end|><|im_end|>
// <|im_start|>user
// What is the current status of candidate ID 12345?<|im_end|>
// <|im_start|>assistant
// <|tool_call_start|>[get_candidate_status(candidate_id="12345")]<|tool_call_end|>Checking the current status of candidate ID 12345.<|im_end|>
// <|im_start|>tool
// <|tool_response_start|>{"candidate_id": "12345", "status": "Interview Scheduled", "position": "Clinical Research Associate", "date": "2023-11-20"}<|tool_response_end|><|im_end|>
// <|im_start|>assistant
// The candidate with ID 12345 is currently in the "Interview Scheduled" stage for the position of Clinical Research Associate, with an interview date set for 2023-11-20.<|im_end|>
//
// For the llama server compatibility with json tools semantic,
// the client can add "Follow json schema." line into the system message prompt to force the json output.
//
if (are_tools_provided && (is_json_schema_provided || is_grammar_provided)) {
// server/utils.hpp prohibits that branch for the custom grammar anyways
throw std::runtime_error("Tools call must not use \"json_schema\" or \"grammar\", use non-tool invocation if you want to use custom grammar");
} else if (are_tools_provided && replace_json_schema_marker(tweaked_messages)) {
LOG_INF("%s: Using tools to build a grammar\n", __func__);

data.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array();
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
schemas.push_back({
{"type", "object"},
{"properties", {
{"name", {
{"type", "string"},
{"const", function.at("name")},
}},
{"arguments", function.at("parameters")},
}},
{"required", json::array({"name", "arguments", "id"})},
});
});
auto schema = json {
{"type", "array"},
{"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
{"minItems", 1},
};
if (!inputs.parallel_tool_calls) {
schema["maxItems"] = 1;
}

builder.add_rule("root", "\"<|tool_call_start|>\"" + builder.add_schema("tool_calls", schema) + "\"<|tool_call_end|>\"");
});
// model has no concept of tool selection mode choice,
// if the system prompt rendered correctly it will produce a tool call
// the grammar goes inside the tool call body
data.grammar_lazy = true;
data.grammar_triggers = {{COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, "\\s*<\\|tool_call_start\\|>\\s*\\["}};
data.preserved_tokens = {"<|tool_call_start|>", "<|tool_call_end|>"};
data.format = COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS;
} else if (are_tools_provided && (!is_json_schema_provided && !is_grammar_provided)) {
LOG_INF("%s: Using tools without json schema or grammar\n", __func__);
// output those tokens
data.preserved_tokens = {"<|tool_call_start|>", "<|tool_call_end|>"};
} else if (is_json_schema_provided) {
LOG_INF("%s: Using provided json schema to build a grammar\n", __func__);
data.grammar = json_schema_to_grammar(inputs.json_schema);
} else if (is_grammar_provided) {
LOG_INF("%s: Using provided grammar\n", __func__);
data.grammar = inputs.grammar;
} else {
LOG_INF("%s: Using content relying on the template\n", __func__);
}

data.prompt = apply(tmpl, inputs, /* messages_override= */ tweaked_messages);
LOG_DBG("%s: Prompt: %s\n", __func__, data.prompt.c_str());

return data;
}

static common_chat_params common_chat_params_init_magistral(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
data.prompt = apply(tmpl, inputs);
Expand Down Expand Up @@ -2499,6 +2623,71 @@ static void common_chat_parse_apertus(common_chat_msg_parser & builder) {
builder.add_content(builder.consume_rest());
}


static void common_chat_parse_lfm2(common_chat_msg_parser & builder) {
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}

// LFM2 format: <|tool_call_start|>[{"name": "get_current_time", "arguments": {"location": "Paris"}}]<|tool_call_end|>
static const common_regex tool_call_start_regex(regex_escape("<|tool_call_start|>"));
static const common_regex tool_call_end_regex(regex_escape("<|tool_call_end|>"));

// Loop through all tool calls
while (auto res = builder.try_find_regex(tool_call_start_regex, std::string::npos, /* add_prelude_to_content= */ true)) {
builder.move_to(res->groups[0].end);

// Parse JSON array format: [{"name": "...", "arguments": {...}}]
auto tool_calls_data = builder.consume_json();

// Consume end marker
builder.consume_spaces();
if (!builder.try_consume_regex(tool_call_end_regex)) {
throw common_chat_msg_partial_exception("Expected <|tool_call_end|>");
}

// Process each tool call in the array
if (tool_calls_data.json.is_array()) {
for (const auto & tool_call : tool_calls_data.json) {
if (!tool_call.is_object()) {
throw common_chat_msg_partial_exception("Tool call must be an object");
}

if (!tool_call.contains("name")) {
throw common_chat_msg_partial_exception("Tool call missing 'name' field");
}

std::string function_name = tool_call.at("name");
std::string arguments = "{}";

if (tool_call.contains("arguments")) {
if (tool_call.at("arguments").is_object()) {
arguments = tool_call.at("arguments").dump();
} else if (tool_call.at("arguments").is_string()) {
arguments = tool_call.at("arguments");
}
}

if (!builder.add_tool_call(function_name, "", arguments)) {
throw common_chat_msg_partial_exception("Incomplete tool call");
}
}
} else {
throw common_chat_msg_partial_exception("Expected JSON array for tool calls");
}

// Consume any trailing whitespace after this tool call
builder.consume_spaces();
}

// Consume any remaining content after all tool calls
auto remaining = builder.consume_rest();
if (!string_strip(remaining).empty()) {
builder.add_content(remaining);
}
}

static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) {
// Parse thinking tags first - this handles the main reasoning content
builder.try_parse_reasoning("<seed:think>", "</seed:think>");
Expand Down Expand Up @@ -2748,6 +2937,12 @@ static common_chat_params common_chat_templates_apply_jinja(
return common_chat_params_init_apertus(tmpl, params);
}

// LFM2 (w/ tools)
if (src.find("List of tools: <|tool_list_start|>[") != std::string::npos &&
src.find("]<|tool_list_end|>") != std::string::npos) {
return common_chat_params_init_lfm2(tmpl, params);
}

// Use generic handler when mixing tools + JSON schema.
// TODO: support that mix in handlers below.
if ((params.tools.is_array() && params.json_schema.is_object())) {
Expand Down Expand Up @@ -2926,6 +3121,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
case COMMON_CHAT_FORMAT_APERTUS:
common_chat_parse_apertus(builder);
break;
case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS:
common_chat_parse_lfm2(builder);
break;
default:
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
}
Expand Down
1 change: 1 addition & 0 deletions common/chat.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ enum common_chat_format {
COMMON_CHAT_FORMAT_SEED_OSS,
COMMON_CHAT_FORMAT_NEMOTRON_V2,
COMMON_CHAT_FORMAT_APERTUS,
COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS,

COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
};
Expand Down
22 changes: 19 additions & 3 deletions common/json-schema-to-grammar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -601,7 +601,10 @@ class SchemaConverter {
}

std::string _resolve_ref(const std::string & ref) {
std::string ref_name = ref.substr(ref.find_last_of('/') + 1);
auto it = ref.find('#');
std::string ref_fragment = it != std::string::npos ? ref.substr(it + 1) : ref;
static const std::regex nonalphanumeric_regex(R"([^a-zA-Z0-9-]+)");
std::string ref_name = "ref" + std::regex_replace(ref_fragment, nonalphanumeric_regex, "-");
if (_rules.find(ref_name) == _rules.end() && _refs_being_resolved.find(ref) == _refs_being_resolved.end()) {
_refs_being_resolved.insert(ref);
json resolved = _refs[ref];
Expand Down Expand Up @@ -774,11 +777,24 @@ class SchemaConverter {
std::vector<std::string> tokens = string_split(pointer, "/");
for (size_t i = 1; i < tokens.size(); ++i) {
std::string sel = tokens[i];
if (target.is_null() || !target.contains(sel)) {
if (target.is_object() && target.contains(sel)) {
target = target[sel];
} else if (target.is_array()) {
size_t sel_index;
try {
sel_index = std::stoul(sel);
} catch (const std::invalid_argument & e) {
sel_index = target.size();
}
if (sel_index >= target.size()) {
_errors.push_back("Error resolving ref " + ref + ": " + sel + " not in " + target.dump());
return;
}
target = target[sel_index];
} else {
_errors.push_back("Error resolving ref " + ref + ": " + sel + " not in " + target.dump());
return;
}
target = target[sel];
}
_refs[ref] = target;
}
Expand Down
26 changes: 24 additions & 2 deletions convert_hf_to_gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -2460,18 +2460,21 @@ def set_gguf_parameters(self):
)
class LlavaVisionModel(MmprojModel):
img_break_tok_id = -1
use_break_tok = True

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
if self.hparams.get("model_type") == "pixtral":
# layer_norm_eps is not in config.json, it is hard-coded in modeling_pixtral.py
self.hparams["layer_norm_eps"] = self.hparams.get("layer_norm_eps", 1e-5)
self.img_break_tok_id = self.get_token_id("[IMG_BREAK]")
if self.use_break_tok:
self.img_break_tok_id = self.get_token_id("[IMG_BREAK]")
elif self.is_mistral_format:
# hparams is already vision config here so norm_eps is only defined in global_config.
self.hparams["norm_eps"] = self.global_config.get("norm_eps", None)
assert self.hparams["norm_eps"] is not None, "norm_eps not found in params.json"
self.img_break_tok_id = self.find_vparam(["image_break_token_id"])
if self.use_break_tok:
self.img_break_tok_id = self.find_vparam(["image_break_token_id"])
else:
raise ValueError(f"Unsupported model type: {self.hparams['model_type']}")
logger.info(f"Image break token id: {self.img_break_tok_id}")
Expand Down Expand Up @@ -3962,6 +3965,10 @@ def _get_cls_out_tensor(self, data_torch: Tensor) -> Tensor:
return torch.stack([true_row, false_row], dim=0)

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if "model.vision_" in name:
# skip multimodal tensors
return []

if self.is_rerank:
is_tied_head = self.is_tied_embeddings and "embed_tokens" in name
is_real_head = not self.is_tied_embeddings and "lm_head" in name
Expand Down Expand Up @@ -9435,6 +9442,21 @@ def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", "
return super().map_tensor_name(name, try_suffixes)


@ModelBase.register("LightOnOCRForConditionalGeneration")
class LightOnOCRVisionModel(LlavaVisionModel):
is_mistral_format = False
use_break_tok = False

def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.LIGHTONOCR)

def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
name = name.replace("model.vision_encoder.", "vision_tower.")
name = name.replace("model.vision_projection.", "multi_modal_projector.")
return super().modify_tensors(data_torch, name, bid)


@ModelBase.register("KimiVLForConditionalGeneration")
class KimiVLModel(MmprojModel):
def __init__(self, *args, **kwargs):
Expand Down
Loading
Loading