Skip to content

Commit 57db9d7

Browse files
server: address review feedback from ngxson
Move minimax-m2 prefix injection logic from server.cpp to chat.cpp via common_chat_stream_state
1 parent 39351b1 commit 57db9d7

File tree

3 files changed

+79
-43
lines changed

3 files changed

+79
-43
lines changed

common/chat.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -676,6 +676,35 @@ common_reasoning_format common_reasoning_format_from_name(const std::string & fo
676676
throw std::runtime_error("Unknown reasoning format: " + format);
677677
}
678678

679+
void common_chat_stream_state::init(const common_chat_syntax & syntax) {
680+
reasoning_prefix_streamed_ = false;
681+
682+
if (syntax.reasoning_format == COMMON_REASONING_FORMAT_MINIMAX_M2) {
683+
reasoning_prefix_ = "<think>\n";
684+
} else {
685+
reasoning_prefix_.clear();
686+
}
687+
}
688+
689+
std::string common_chat_stream_state::apply_reasoning_prefix(const std::string & text) const {
690+
if (reasoning_prefix_.empty()) {
691+
return text;
692+
}
693+
694+
std::string result(reasoning_prefix_);
695+
result += text;
696+
return result;
697+
}
698+
699+
std::optional<std::string> common_chat_stream_state::consume_reasoning_prefix() {
700+
if (!reasoning_prefix_pending()) {
701+
return std::nullopt;
702+
}
703+
704+
reasoning_prefix_streamed_ = true;
705+
return reasoning_prefix_;
706+
}
707+
679708
static std::string wrap_code_as_arguments(common_chat_msg_parser & builder, const std::string & code) {
680709
std::string arguments;
681710
if (builder.is_partial()) {
@@ -3154,3 +3183,12 @@ common_chat_msg common_chat_parse(const std::string & input, bool is_partial, co
31543183
}
31553184
return msg;
31563185
}
3186+
3187+
// Streaming-aware wrapper around common_chat_parse(): prepends the stream
// state's reasoning prefix (if any) to `input` before delegating to the
// regular parser.
common_chat_msg common_chat_parse_stream(
        const std::string & input,
        bool is_partial,
        common_chat_stream_state & stream_state,
        const common_chat_syntax & syntax) {
    return common_chat_parse(stream_state.apply_reasoning_prefix(input), is_partial, syntax);
}

common/chat.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <string>
99
#include <vector>
1010
#include <map>
11+
#include <optional>
1112

1213
struct common_chat_templates;
1314

@@ -159,6 +160,26 @@ struct common_chat_syntax {
159160
bool parse_tool_calls = true;
160161
};
161162

163+
// State carried across incremental parses of a streamed chat response.
// Some reasoning formats (currently MiniMax-M2 — see init()) require an
// implicit reasoning prefix (e.g. "<think>\n") to be injected in front of
// the generated text; this object owns that prefix and remembers whether it
// has already been sent to the client during streaming.
struct common_chat_stream_state {
    common_chat_stream_state() = default;
    explicit common_chat_stream_state(const common_chat_syntax & syntax) { init(syntax); }

    // (Re)initialize from the chat syntax: chooses the prefix (if any) and
    // clears the "already streamed" flag.
    void init(const common_chat_syntax & syntax);

    // Returns `text` with the prefix prepended; no-op when no prefix is set.
    std::string apply_reasoning_prefix(const std::string & text) const;

    // Returns the prefix exactly once (marking it streamed); std::nullopt on
    // later calls or when no prefix is configured.
    std::optional<std::string> consume_reasoning_prefix();

    bool has_reasoning_prefix() const { return !reasoning_prefix_.empty(); }
    bool reasoning_prefix_pending() const { return has_reasoning_prefix() && !reasoning_prefix_streamed_; }
    const std::string & reasoning_prefix() const { return reasoning_prefix_; }
    void mark_reasoning_prefix_streamed() { reasoning_prefix_streamed_ = true; }

private:
    std::string reasoning_prefix_;           // implicit prefix to inject, e.g. "<think>\n"
    bool reasoning_prefix_streamed_ = false; // true once the prefix has been emitted
};
182+
162183
// Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
163184
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);
164185

@@ -200,6 +221,11 @@ const char* common_chat_format_name(common_chat_format format);
200221
const char* common_reasoning_format_name(common_reasoning_format format);
201222
common_reasoning_format common_reasoning_format_from_name(const std::string & format);
202223
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
224+
common_chat_msg common_chat_parse_stream(
225+
const std::string & input,
226+
bool is_partial,
227+
common_chat_stream_state & stream_state,
228+
const common_chat_syntax & syntax);
203229

204230
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice);
205231

tools/server/server.cpp

Lines changed: 15 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1663,8 +1663,7 @@ struct server_slot {
16631663
bool has_next_token = true;
16641664
bool has_new_line = false;
16651665
bool truncated = false;
1666-
bool minimax_reasoning_prefix_injected = false;
1667-
bool minimax_reasoning_prefix_streamed = false;
1666+
common_chat_stream_state reasoning_stream_state;
16681667

16691668
stop_type stop;
16701669

@@ -1735,8 +1734,7 @@ struct server_slot {
17351734
generated_text = "";
17361735
has_new_line = false;
17371736
truncated = false;
1738-
minimax_reasoning_prefix_injected = false;
1739-
minimax_reasoning_prefix_streamed = false;
1737+
reasoning_stream_state = {};
17401738
stop = STOP_TYPE_NONE;
17411739
stopping_word = "";
17421740
n_sent_text = 0;
@@ -1863,14 +1861,12 @@ struct server_slot {
18631861
GGML_ASSERT(task);
18641862

18651863
auto previous_msg = chat_msg;
1866-
std::string text_to_parse = generated_text;
1867-
if (minimax_reasoning_prefix_injected) {
1868-
text_to_parse.insert(0, "<think>\n");
1869-
}
1864+
const auto text_to_parse = reasoning_stream_state.apply_reasoning_prefix(generated_text);
18701865
SRV_DBG("Parsing chat message: %s\n", text_to_parse.c_str());
1871-
auto new_msg = common_chat_parse(
1872-
text_to_parse,
1866+
auto new_msg = common_chat_parse_stream(
1867+
generated_text,
18731868
/* is_partial= */ stop != STOP_TYPE_EOS,
1869+
reasoning_stream_state,
18741870
task->params.oaicompat_chat_syntax);
18751871
if (!new_msg.empty()) {
18761872
new_msg.set_tool_call_ids(generated_tool_call_ids, gen_tool_call_id);
@@ -2804,10 +2800,7 @@ struct server_context {
28042800

28052801
slot.state = SLOT_STATE_STARTED;
28062802

2807-
const bool needs_minimax_prefix =
2808-
slot.task->params.oaicompat_chat_syntax.reasoning_format == COMMON_REASONING_FORMAT_MINIMAX_M2;
2809-
slot.minimax_reasoning_prefix_injected = needs_minimax_prefix;
2810-
slot.minimax_reasoning_prefix_streamed = false;
2803+
slot.reasoning_stream_state.init(slot.task->params.oaicompat_chat_syntax);
28112804

28122805
SLT_INF(slot, "%s", "processing task\n");
28132806

@@ -2869,25 +2862,16 @@ struct server_context {
28692862
slot.add_token(result);
28702863
result.text_to_send = std::move(delta_to_send);
28712864

2872-
auto stream_with_minimax_prefix = [&](const completion_token_output & chunk) {
2873-
if (!slot.task->params.stream) {
2874-
return;
2875-
}
2876-
2877-
if (slot.minimax_reasoning_prefix_injected && !slot.minimax_reasoning_prefix_streamed) {
2865+
if (send_text && slot.task->params.stream) {
2866+
if (auto prefix = slot.reasoning_stream_state.consume_reasoning_prefix()) {
28782867
completion_token_output prefix_chunk{};
28792868
prefix_chunk.tok = LLAMA_TOKEN_NULL;
28802869
prefix_chunk.prob = 0.0f;
2881-
prefix_chunk.text_to_send = "<think>\n";
2870+
prefix_chunk.text_to_send = *prefix;
28822871
send_partial_response(slot, prefix_chunk, false);
2883-
slot.minimax_reasoning_prefix_streamed = true;
28842872
}
28852873

2886-
send_partial_response(slot, chunk, false);
2887-
};
2888-
2889-
if (send_text) {
2890-
stream_with_minimax_prefix(result);
2874+
send_partial_response(slot, result, false);
28912875
}
28922876
}
28932877

@@ -3058,11 +3042,7 @@ struct server_context {
30583042
return true;
30593043
}
30603044

3061-
void send_partial_response(
3062-
server_slot & slot,
3063-
const completion_token_output & tkn,
3064-
bool is_progress,
3065-
const std::vector<common_chat_msg_diff> * forced_diffs = nullptr) {
3045+
void send_partial_response(server_slot & slot, const completion_token_output & tkn, bool is_progress) {
30663046
auto res = std::make_unique<server_task_result_cmpl_partial>();
30673047

30683048
res->id = slot.task->id;
@@ -3080,11 +3060,7 @@ struct server_context {
30803060
res->tokens = { tkn.tok };
30813061
}
30823062

3083-
if (forced_diffs) {
3084-
res->oaicompat_msg_diffs = *forced_diffs;
3085-
} else {
3086-
slot.update_chat_msg(res->oaicompat_msg_diffs);
3087-
}
3063+
slot.update_chat_msg(res->oaicompat_msg_diffs);
30883064
}
30893065

30903066
res->n_decoded = slot.n_decoded;
@@ -3115,12 +3091,8 @@ struct server_context {
31153091
res->id = slot.task->id;
31163092
res->id_slot = slot.id;
31173093

3118-
res->index = slot.task->index;
3119-
std::string response_content = slot.generated_text;
3120-
if (slot.minimax_reasoning_prefix_injected) {
3121-
response_content.insert(0, "<think>\n");
3122-
}
3123-
res->content = std::move(response_content);
3094+
res->index = slot.task->index;
3095+
res->content = slot.reasoning_stream_state.apply_reasoning_prefix(slot.generated_text);
31243096
res->tokens = std::move(slot.generated_tokens);
31253097
res->timings = slot.get_timings();
31263098
res->prompt = slot.task->tokens.detokenize(ctx, true);

0 commit comments

Comments
 (0)