
Commit d1d250e

Merge branch 'ggml-org:master' into master

2 parents ab56b24 + 128d522

11 files changed: +181 additions, -74 deletions


common/arg.cpp
Lines changed: 5 additions & 5 deletions

@@ -1932,13 +1932,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         }
     ).set_env("LLAMA_ARG_SWA_FULL"));
     add_opt(common_arg(
-        {"--swa-checkpoints"}, "N",
-        string_format("max number of SWA checkpoints per slot to create (default: %d)\n"
-            "[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)", params.n_swa_checkpoints),
+        {"--ctx-checkpoints", "--swa-checkpoints"}, "N",
+        string_format("max number of context checkpoints to create per slot (default: %d)\n"
+            "[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)", params.n_ctx_checkpoints),
         [](common_params & params, int value) {
-            params.n_swa_checkpoints = value;
+            params.n_ctx_checkpoints = value;
         }
-    ).set_env("LLAMA_ARG_SWA_CHECKPOINTS").set_examples({LLAMA_EXAMPLE_SERVER}));
+    ).set_env("LLAMA_ARG_CTX_CHECKPOINTS").set_examples({LLAMA_EXAMPLE_SERVER}));
     add_opt(common_arg(
         {"--kv-unified", "-kvu"},
         string_format("use single unified KV buffer for the KV cache of all sequences (default: %s)\n"

common/chat.cpp
Lines changed: 79 additions & 0 deletions

@@ -625,6 +625,7 @@ const char * common_chat_format_name(common_chat_format format) {
         case COMMON_CHAT_FORMAT_CONTENT_ONLY: return "Content-only";
         case COMMON_CHAT_FORMAT_GENERIC: return "Generic";
         case COMMON_CHAT_FORMAT_MISTRAL_NEMO: return "Mistral Nemo";
+        case COMMON_CHAT_FORMAT_MAGISTRAL: return "Magistral";
         case COMMON_CHAT_FORMAT_LLAMA_3_X: return "Llama 3.x";
         case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS: return "Llama 3.x with builtin tools";
         case COMMON_CHAT_FORMAT_DEEPSEEK_R1: return "DeepSeek R1";

@@ -984,6 +985,65 @@ static common_chat_params common_chat_params_init_mistral_nemo(const common_chat
     data.format = COMMON_CHAT_FORMAT_MISTRAL_NEMO;
     return data;
 }
+
+static common_chat_params common_chat_params_init_magistral(const common_chat_template & tmpl, const struct templates_params & inputs) {
+    common_chat_params data;
+    data.prompt = apply(tmpl, inputs);
+    data.format = COMMON_CHAT_FORMAT_MAGISTRAL;
+    data.preserved_tokens = {
+        "[THINK]",
+        "[/THINK]",
+    };
+
+    if (inputs.tools.is_array() && !inputs.tools.empty()) {
+        data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
+        data.grammar = build_grammar([&](const common_grammar_builder & builder) {
+            auto schemas = json::array();
+            foreach_function(inputs.tools, [&](const json & tool) {
+                const auto & function = tool.at("function");
+                schemas.push_back({
+                    {"type", "object"},
+                    {"properties", {
+                        {"name", {
+                            {"type", "string"},
+                            {"const", function.at("name")},
+                        }},
+                        {"arguments", function.at("parameters")},
+                        {"id", {
+                            {"type", "string"},
+                            {"pattern", "^[a-zA-Z0-9]{9}$"},
+                        }},
+                    }},
+                    {"required", json::array({"name", "arguments", "id"})},
+                });
+            });
+            auto schema = json {
+                {"type", "array"},
+                {"items", schemas.size() == 1 ? schemas[0] : json {{"anyOf", schemas}}},
+                {"minItems", 1},
+            };
+            if (!inputs.parallel_tool_calls) {
+                schema["maxItems"] = 1;
+            }
+            builder.add_rule("root", "\"[TOOL_CALLS]\" " + builder.add_schema("tool_calls", schema));
+        });
+        data.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "[TOOL_CALLS]"});
+        data.preserved_tokens.push_back("[TOOL_CALLS]");
+    } else {
+        data.grammar_lazy = false;
+        if (!inputs.json_schema.is_null()) {
+            if (!inputs.grammar.empty()) {
+                throw std::runtime_error("Either \"json_schema\" or \"grammar\" can be specified, but not both");
+            }
+            data.grammar = json_schema_to_grammar(inputs.json_schema);
+        } else {
+            data.grammar = inputs.grammar;
+        }
+    }
+
+    return data;
+}
+
 static void common_chat_parse_mistral_nemo(common_chat_msg_parser & builder) {
     if (!builder.syntax().parse_tool_calls) {
         builder.add_content(builder.consume_rest());

@@ -994,6 +1054,18 @@ static void common_chat_parse_mistral_nemo(common_chat_msg_parser & builder) {
     parse_prefixed_json_tool_call_array(builder, prefix);
 }

+static void common_chat_parse_magistral(common_chat_msg_parser & builder) {
+    builder.try_parse_reasoning("[THINK]", "[/THINK]");
+
+    if (!builder.syntax().parse_tool_calls) {
+        builder.add_content(builder.consume_rest());
+        return;
+    }
+
+    static const common_regex prefix(regex_escape("[TOOL_CALLS]"));
+    parse_prefixed_json_tool_call_array(builder, prefix);
+}
+
 static common_chat_params common_chat_params_init_command_r7b(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;

@@ -2702,6 +2774,10 @@ static common_chat_params common_chat_templates_apply_jinja(
         return common_chat_params_init_llama_3_x(tmpl, params, allow_python_tag_builtin_tools);
     }

+    if (src.find("[THINK]") != std::string::npos && src.find("[/THINK]") != std::string::npos) {
+        return common_chat_params_init_magistral(tmpl, params);
+    }
+
     // Plain handler (no tools)
     if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
         return common_chat_params_init_without_tools(tmpl, params);

@@ -2802,6 +2878,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
         case COMMON_CHAT_FORMAT_MISTRAL_NEMO:
             common_chat_parse_mistral_nemo(builder);
             break;
+        case COMMON_CHAT_FORMAT_MAGISTRAL:
+            common_chat_parse_magistral(builder);
+            break;
         case COMMON_CHAT_FORMAT_LLAMA_3_X:
             common_chat_parse_llama_3_1(builder);
             break;
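
For context, a minimal usage sketch of the new Magistral path, mirroring the test case added in tests/test-chat.cpp further down. The brace-initialized syntax argument follows the field order used by that test; the wrapper function and the example strings are hypothetical, not part of this commit.

#include <string>

#include "chat.h" // common/chat.h

// Hypothetical helper: parse one raw Magistral completion into a chat message.
// The [THINK]...[/THINK] span is lifted out as reasoning; a trailing
// [TOOL_CALLS] JSON array, if present, is parsed into tool calls.
static common_chat_msg parse_magistral_reply(const std::string & raw) {
    return common_chat_parse(
        raw,
        /* is_partial= */ false,
        {
            /* .format = */ COMMON_CHAT_FORMAT_MAGISTRAL,
            /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
        });
}

// parse_magistral_reply("[THINK]raisonnement[/THINK]Réponse") yields a message
// whose content is "Réponse" and whose reasoning is "raisonnement", matching
// simple_assist_msg("Réponse", "raisonnement") in the new test below.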

common/chat.h
Lines changed: 1 addition & 0 deletions

@@ -101,6 +101,7 @@ enum common_chat_format {
     COMMON_CHAT_FORMAT_CONTENT_ONLY,
     COMMON_CHAT_FORMAT_GENERIC,
     COMMON_CHAT_FORMAT_MISTRAL_NEMO,
+    COMMON_CHAT_FORMAT_MAGISTRAL,
     COMMON_CHAT_FORMAT_LLAMA_3_X,
     COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
     COMMON_CHAT_FORMAT_DEEPSEEK_R1,

common/common.h
Lines changed: 1 addition & 1 deletion

@@ -424,7 +424,7 @@ struct common_params {
     int32_t timeout_write = timeout_read; // http write timeout in seconds
     int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool)
     int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting
-    int32_t n_swa_checkpoints = 3; // max number of SWA checkpoints per slot
+    int32_t n_ctx_checkpoints = 3; // max number of context checkpoints per slot

     std::string hostname = "127.0.0.1";
     std::string public_path = ""; // NOLINT

include/llama.h
Lines changed: 7 additions & 0 deletions

@@ -543,6 +543,9 @@ extern "C" {
     // Returns true if the model is recurrent (like Mamba, RWKV, etc.)
     LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model);

+    // Returns true if the model is hybrid (like Jamba, Granite, etc.)
+    LLAMA_API bool llama_model_is_hybrid(const struct llama_model * model);
+
     // Returns true if the model is diffusion-based (like LLaDA, Dream, etc.)
     LLAMA_API bool llama_model_is_diffusion(const struct llama_model * model);

@@ -791,8 +794,12 @@ extern "C" {
             size_t   n_token_capacity,
             size_t * n_token_count_out);

+    // for backwards-compat
     #define LLAMA_STATE_SEQ_FLAGS_SWA_ONLY 1

+    // work only with partial states, such as SWA KV cache or recurrent cache (e.g. Mamba)
+    #define LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY 1
+
     typedef uint32_t llama_state_seq_flags;

     LLAMA_API size_t llama_state_seq_get_size_ext(
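
For context, a hedged sketch of how a caller might combine the new llama_model_is_hybrid() check with LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY when sizing a per-sequence checkpoint. The declaration of llama_state_seq_get_size_ext is truncated above, so its parameter list is assumed here, and the helper itself is hypothetical.

#include "llama.h"

// Hypothetical helper: size the per-sequence state blob for a checkpoint.
// For recurrent or hybrid models, only the partial state (SWA KV cache or
// recurrent cache) is serialized; other models save their full state.
static size_t checkpoint_state_size(llama_context * ctx, const llama_model * model, llama_seq_id seq_id) {
    llama_state_seq_flags flags = 0;
    if (llama_model_is_recurrent(model) || llama_model_is_hybrid(model)) {
        flags |= LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY;
    }
    // assumed parameter order: (ctx, seq_id, flags)
    return llama_state_seq_get_size_ext(ctx, seq_id, flags);
}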

src/llama-kv-cache-iswa.cpp
Lines changed: 2 additions & 2 deletions

@@ -220,15 +220,15 @@ bool llama_kv_cache_iswa::get_can_shift() const {
 }

 void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
-    if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
+    if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
         kv_base->state_write(io, seq_id, flags);
     }

     kv_swa->state_write(io, seq_id, flags);
 }

 void llama_kv_cache_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
-    if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) {
+    if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
         kv_base->state_read(io, seq_id, flags);
     }

src/llama-memory-hybrid.cpp
Lines changed: 8 additions & 8 deletions

@@ -175,17 +175,17 @@ std::map<ggml_backend_buffer_type_t, size_t> llama_memory_hybrid::memory_breakdo
 }

 void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
-    GGML_UNUSED(flags);
-
-    mem_attn->state_write(io, seq_id);
-    mem_recr->state_write(io, seq_id);
+    if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
+        mem_attn->state_write(io, seq_id, flags);
+    }
+    mem_recr->state_write(io, seq_id, flags);
 }

 void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) {
-    GGML_UNUSED(flags);
-
-    mem_attn->state_read(io, seq_id);
-    mem_recr->state_read(io, seq_id);
+    if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) {
+        mem_attn->state_read(io, seq_id, flags);
+    }
+    mem_recr->state_read(io, seq_id, flags);
 }

 llama_kv_cache * llama_memory_hybrid::get_mem_attn() const {

src/llama-memory-recurrent.cpp
Lines changed: 4 additions & 1 deletion

@@ -136,6 +136,7 @@ void llama_memory_recurrent::clear(bool data) {
 }

 bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) {
+    //printf("[DEBUG] calling llama_memory_recurrent::seq_rm` with `seq_id=%d, p0=%d, p1=%d`\n", seq_id, p0, p1);
     uint32_t new_head = size;

     if (p0 < 0) {

@@ -156,7 +157,8 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos
         if (tail_id >= 0) {
             const auto & cell = cells[tail_id];
             // partial intersection is invalid
-            if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) {
+            if ((0 < p0 && p0 < cell.pos) || (0 < p1 && p1 <= cell.pos)) {
+                //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false\n");
                 return false;
             }
             // invalidate tails which will be cleared

@@ -167,6 +169,7 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos
     } else {
         // seq_id is negative, then the range should include everything or nothing
         if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits<llama_pos>::max())) {
+            //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: `seq_id` is negative, so returning false\n");
             return false;
         }
     }

src/llama-model.cpp
Lines changed: 4 additions & 0 deletions

@@ -20151,6 +20151,10 @@ bool llama_model_is_recurrent(const llama_model * model) {
     return llm_arch_is_recurrent(model->arch);
 }

+bool llama_model_is_hybrid(const llama_model * model) {
+    return llm_arch_is_hybrid(model->arch);
+}
+
 bool llama_model_is_diffusion(const llama_model * model) {
     return llm_arch_is_diffusion(model->arch);
 }

tests/test-chat.cpp
Lines changed: 12 additions & 0 deletions

@@ -411,6 +411,7 @@ const common_chat_msg message_assist_thoughts_unparsed_md = simple_assis
 const common_chat_msg message_assist_thoughts_unparsed_md_partial = simple_assist_msg("<think>I'm\nthinking</think>Hello, world!\nWhat's up?\n```json\n{}");

 const common_chat_msg message_assist_thoughts_unparsed_r7b = simple_assist_msg("<|START_THINKING|>I'm\nthinking<|END_THINKING|>Hello, world!\nWhat's up?");
+const common_chat_msg message_assist_thoughts_unparsed_magistral = simple_assist_msg("[THINK]raisonnement[/THINK]Réponse");
 const common_chat_msg message_assist_thoughts = simple_assist_msg("Hello, world!\nWhat's up?", "I'm\nthinking");
 const common_chat_msg message_assist_thoughts_unopened_unparsed = simple_assist_msg("I'm\nthinking</think>Hello, world!\nWhat's up?");
 const common_chat_msg message_assist_thoughts_no_content = simple_assist_msg("", "I'm\nthinking");

@@ -745,6 +746,17 @@ static void test_template_output_parsers() {
             tmpls.get(), end_tokens, message_assist_call_id, tools,
             "[TOOL_CALLS][{\"name\": \"special_function\", \"arguments\": {\"arg1\": 1}, \"id\": \"123456789\"}]");
     }
+    {
+        assert_msg_equals(
+            simple_assist_msg("Réponse", "raisonnement"),
+            common_chat_parse(
+                message_assist_thoughts_unparsed_magistral.content,
+                /* is_partial= */ false,
+                {
+                    /* .format = */ COMMON_CHAT_FORMAT_MAGISTRAL,
+                    /* .reasoning_format = */ COMMON_REASONING_FORMAT_AUTO,
+                }));
+    }
     {
         auto tmpls = read_templates("models/templates/Qwen-QwQ-32B.jinja");
         std::vector<std::string> end_tokens{ "<|im_end|>" };
