Commit a61f1e2

server: address review feedback from ngxson
Move minimax-m2 prefix injection logic from server.cpp to chat.cpp via common_chat_stream_state
1 parent b654838 · commit a61f1e2

File tree

    common/chat.cpp
    common/chat.h
    tools/server/server.cpp

3 files changed: +79 -43 lines

common/chat.cpp

Lines changed: 38 additions & 0 deletions
@@ -675,6 +675,35 @@ common_reasoning_format common_reasoning_format_from_name(const std::string & format)
     throw std::runtime_error("Unknown reasoning format: " + format);
 }
 
+void common_chat_stream_state::init(const common_chat_syntax & syntax) {
+    reasoning_prefix_streamed_ = false;
+
+    if (syntax.reasoning_format == COMMON_REASONING_FORMAT_MINIMAX_M2) {
+        reasoning_prefix_ = "<think>\n";
+    } else {
+        reasoning_prefix_.clear();
+    }
+}
+
+std::string common_chat_stream_state::apply_reasoning_prefix(const std::string & text) const {
+    if (reasoning_prefix_.empty()) {
+        return text;
+    }
+
+    std::string result(reasoning_prefix_);
+    result += text;
+    return result;
+}
+
+std::optional<std::string> common_chat_stream_state::consume_reasoning_prefix() {
+    if (!reasoning_prefix_pending()) {
+        return std::nullopt;
+    }
+
+    reasoning_prefix_streamed_ = true;
+    return reasoning_prefix_;
+}
+
 static std::string wrap_code_as_arguments(common_chat_msg_parser & builder, const std::string & code) {
     std::string arguments;
     if (builder.is_partial()) {
@@ -3169,3 +3198,12 @@ common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax)
     }
     return msg;
 }
+
+common_chat_msg common_chat_parse_stream(
+        const std::string & input,
+        bool is_partial,
+        common_chat_stream_state & stream_state,
+        const common_chat_syntax & syntax) {
+    const auto text_to_parse = stream_state.apply_reasoning_prefix(input);
+    return common_chat_parse(text_to_parse, is_partial, syntax);
+}
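
Aside (not part of the commit): the three helpers implement a small contract. apply_reasoning_prefix prepends the prefix on every parse, while consume_reasoning_prefix hands the prefix out at most once for streaming. A self-contained C++ sketch of that contract, mirroring the helpers above rather than copying them:

// Illustrative mirror of the common_chat_stream_state contract
// (sketch only; the real implementation is in the diff above).
#include <cassert>
#include <optional>
#include <string>

struct stream_state_sketch {
    std::string prefix;     // "<think>\n" when reasoning_format is minimax-m2
    bool streamed = false;

    // Parsing path: every (partial) parse sees prefix + generated text.
    std::string apply(const std::string & text) const {
        return prefix.empty() ? text : prefix + text;
    }

    // Streaming path: the prefix is handed out exactly once.
    std::optional<std::string> consume() {
        if (prefix.empty() || streamed) {
            return std::nullopt;
        }
        streamed = true;
        return prefix;
    }
};

int main() {
    stream_state_sketch st{"<think>\n"};
    assert(st.apply("step 1") == "<think>\nstep 1");  // parser always sees the prefix
    assert(st.consume() == "<think>\n");              // first streamed chunk carries it
    assert(!st.consume().has_value());                // later chunks do not
    assert(st.apply("step 1 step 2") == "<think>\nstep 1 step 2");
}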

common/chat.h

Lines changed: 26 additions & 0 deletions
@@ -8,6 +8,7 @@
 #include <string>
 #include <vector>
 #include <map>
+#include <optional>
 
 struct common_chat_templates;
 
@@ -159,6 +160,26 @@ struct common_chat_syntax {
     bool parse_tool_calls = true;
 };
 
+struct common_chat_stream_state {
+    common_chat_stream_state() = default;
+    explicit common_chat_stream_state(const common_chat_syntax & syntax) { init(syntax); }
+
+    void init(const common_chat_syntax & syntax);
+
+    std::string apply_reasoning_prefix(const std::string & text) const;
+
+    std::optional<std::string> consume_reasoning_prefix();
+
+    bool has_reasoning_prefix() const { return !reasoning_prefix_.empty(); }
+    bool reasoning_prefix_pending() const { return has_reasoning_prefix() && !reasoning_prefix_streamed_; }
+    const std::string & reasoning_prefix() const { return reasoning_prefix_; }
+    void mark_reasoning_prefix_streamed() { reasoning_prefix_streamed_ = true; }
+
+  private:
+    std::string reasoning_prefix_;
+    bool reasoning_prefix_streamed_ = false;
+};
+
 // Check if the template supplied via "--chat-template" is supported or not. Returns true if it's valid
 bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);
 
@@ -200,6 +221,11 @@ const char* common_chat_format_name(common_chat_format format);
 const char* common_reasoning_format_name(common_reasoning_format format);
 common_reasoning_format common_reasoning_format_from_name(const std::string & format);
 common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
+common_chat_msg common_chat_parse_stream(
+        const std::string & input,
+        bool is_partial,
+        common_chat_stream_state & stream_state,
+        const common_chat_syntax & syntax);
 
 common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice);

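Aside (not part of the commit): a hypothetical caller of the API declared above. The real call sites are in the tools/server/server.cpp diff below; chunk delivery is elided.

// Hypothetical per-slot usage of the new API (sketch only). The state is
// assumed to have been set up via state.init(syntax) when the task started.
#include "chat.h"

void on_token_sampled(common_chat_stream_state & state,
                      const common_chat_syntax & syntax,
                      const std::string & generated_so_far,
                      bool is_partial) {
    // Streaming path: emit the "<think>\n" prefix once, ahead of any content.
    if (auto prefix = state.consume_reasoning_prefix()) {
        // ... send *prefix to the client as its own chunk ...
    }

    // Parsing path: common_chat_parse_stream prepends the prefix itself,
    // so the raw generated text is passed as-is.
    common_chat_msg msg = common_chat_parse_stream(generated_so_far, is_partial, state, syntax);
    (void) msg; // diff against the previous message and send the deltas
}
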
tools/server/server.cpp

Lines changed: 15 additions & 43 deletions
@@ -1663,8 +1663,7 @@ struct server_slot {
     bool has_next_token = true;
     bool has_new_line = false;
     bool truncated = false;
-    bool minimax_reasoning_prefix_injected = false;
-    bool minimax_reasoning_prefix_streamed = false;
+    common_chat_stream_state reasoning_stream_state;
 
     stop_type stop;
 
@@ -1735,8 +1734,7 @@
         generated_text = "";
         has_new_line = false;
         truncated = false;
-        minimax_reasoning_prefix_injected = false;
-        minimax_reasoning_prefix_streamed = false;
+        reasoning_stream_state = {};
         stop = STOP_TYPE_NONE;
         stopping_word = "";
         n_sent_text = 0;
@@ -1863,14 +1861,12 @@
         GGML_ASSERT(task);
 
         auto previous_msg = chat_msg;
-        std::string text_to_parse = generated_text;
-        if (minimax_reasoning_prefix_injected) {
-            text_to_parse.insert(0, "<think>\n");
-        }
+        const auto text_to_parse = reasoning_stream_state.apply_reasoning_prefix(generated_text);
         SRV_DBG("Parsing chat message: %s\n", text_to_parse.c_str());
-        auto new_msg = common_chat_parse(
-            text_to_parse,
+        auto new_msg = common_chat_parse_stream(
+            generated_text,
             /* is_partial= */ stop != STOP_TYPE_EOS,
+            reasoning_stream_state,
             task->params.oaicompat_chat_syntax);
         if (!new_msg.empty()) {
             new_msg.set_tool_call_ids(generated_tool_call_ids, gen_tool_call_id);
@@ -2843,10 +2839,7 @@ struct server_context {
 
         slot.state = SLOT_STATE_STARTED;
 
-        const bool needs_minimax_prefix =
-            slot.task->params.oaicompat_chat_syntax.reasoning_format == COMMON_REASONING_FORMAT_MINIMAX_M2;
-        slot.minimax_reasoning_prefix_injected = needs_minimax_prefix;
-        slot.minimax_reasoning_prefix_streamed = false;
+        slot.reasoning_stream_state.init(slot.task->params.oaicompat_chat_syntax);
 
         SLT_INF(slot, "%s", "processing task\n");
 
@@ -2908,25 +2901,16 @@
         slot.add_token(result);
         result.text_to_send = std::move(delta_to_send);
 
-        auto stream_with_minimax_prefix = [&](const completion_token_output & chunk) {
-            if (!slot.task->params.stream) {
-                return;
-            }
-
-            if (slot.minimax_reasoning_prefix_injected && !slot.minimax_reasoning_prefix_streamed) {
+        if (send_text && slot.task->params.stream) {
+            if (auto prefix = slot.reasoning_stream_state.consume_reasoning_prefix()) {
                 completion_token_output prefix_chunk{};
                 prefix_chunk.tok = LLAMA_TOKEN_NULL;
                 prefix_chunk.prob = 0.0f;
-                prefix_chunk.text_to_send = "<think>\n";
+                prefix_chunk.text_to_send = *prefix;
                 send_partial_response(slot, prefix_chunk, false);
-                slot.minimax_reasoning_prefix_streamed = true;
             }
 
-            send_partial_response(slot, chunk, false);
-        };
-
-        if (send_text) {
-            stream_with_minimax_prefix(result);
+            send_partial_response(slot, result, false);
         }
     }
 
@@ -3097,11 +3081,7 @@
         return true;
     }
 
-    void send_partial_response(
-            server_slot & slot,
-            const completion_token_output & tkn,
-            bool is_progress,
-            const std::vector<common_chat_msg_diff> * forced_diffs = nullptr) {
+    void send_partial_response(server_slot & slot, const completion_token_output & tkn, bool is_progress) {
         auto res = std::make_unique<server_task_result_cmpl_partial>();
 
         res->id = slot.task->id;
@@ -3119,11 +3099,7 @@
                 res->tokens = { tkn.tok };
             }
 
-            if (forced_diffs) {
-                res->oaicompat_msg_diffs = *forced_diffs;
-            } else {
-                slot.update_chat_msg(res->oaicompat_msg_diffs);
-            }
+            slot.update_chat_msg(res->oaicompat_msg_diffs);
         }
 
         res->n_decoded = slot.n_decoded;
@@ -3154,12 +3130,8 @@
         res->id = slot.task->id;
        res->id_slot = slot.id;
 
-        res->index = slot.task->index;
-        std::string response_content = slot.generated_text;
-        if (slot.minimax_reasoning_prefix_injected) {
-            response_content.insert(0, "<think>\n");
-        }
-        res->content = std::move(response_content);
+        res->index = slot.task->index;
+        res->content = slot.reasoning_stream_state.apply_reasoning_prefix(slot.generated_text);
         res->tokens = std::move(slot.generated_tokens);
         res->timings = slot.get_timings();
         res->prompt = slot.task->tokens.detokenize(ctx, true);

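Aside (not part of the commit): the net streaming behavior for a minimax-m2 slot, sketched standalone with stand-in types (the server uses completion_token_output and LLAMA_TOKEN_NULL). The prefix goes out as one dedicated chunk before the first sampled token, and the non-streaming final response gets the same bytes via apply_reasoning_prefix.

// Standalone mirror of the new streaming flow (stand-in types only).
#include <iostream>
#include <string>
#include <vector>

struct chunk { int tok; std::string text; };

int main() {
    const std::string prefix = "<think>\n"; // set by init() for minimax-m2
    bool streamed = false;

    const std::vector<chunk> deltas = { {101, "First"}, {102, " thought"} };
    for (const chunk & c : deltas) {
        if (!prefix.empty() && !streamed) {
            // Dedicated prefix chunk; tok = -1 stands in for LLAMA_TOKEN_NULL.
            streamed = true;
            std::cout << "chunk(tok=-1): " << prefix;
        }
        std::cout << "chunk(tok=" << c.tok << "): " << c.text << "\n";
    }
    // A non-streaming client receives prefix + full text in one final response.
    std::cout << "final: " << prefix + "First thought" << "\n";
}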