Skip to content

Commit d230722

Browse files
committed
Add partial formatter
1 parent 7adc79c commit d230722

File tree

1 file changed

+107
-11
lines changed

1 file changed

+107
-11
lines changed

tools/main/main.cpp

Lines changed: 107 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -83,6 +83,102 @@ static void sigint_handler(int signo) {
8383
}
8484
#endif
8585

86+
class partial_formatter {
87+
public:
88+
enum output_type {
89+
CONTENT,
90+
REASONING,
91+
};
92+
93+
struct output {
94+
std::string formatted;
95+
output_type type;
96+
};
97+
98+
partial_formatter(const common_chat_syntax & syntax) : syntax_(syntax), had_reasoning_(false) {}
99+
100+
std::vector<output> operator()(const std::string & accumulated) {
101+
common_chat_msg next = common_chat_parse(accumulated, true, syntax_);
102+
103+
auto diffs = common_chat_msg_diff::compute_diffs(previous_, next);
104+
std::vector<output> result;
105+
for (const auto & diff : diffs) {
106+
if (!diff.reasoning_content_delta.empty()) {
107+
result.push_back({diff.reasoning_content_delta, REASONING});
108+
had_reasoning_ = true;
109+
}
110+
if (!diff.content_delta.empty()) {
111+
if (had_reasoning_) {
112+
result.push_back({"\n", REASONING});
113+
had_reasoning_ = false;
114+
}
115+
result.push_back({diff.content_delta, CONTENT});
116+
}
117+
}
118+
previous_ = next;
119+
return result;
120+
}
121+
122+
private:
123+
common_chat_syntax syntax_;
124+
common_chat_msg previous_;
125+
bool had_reasoning_;
126+
};
127+
128+
class chat_formatter {
129+
public:
130+
chat_formatter(
131+
std::vector<common_chat_msg> & chat_msgs,
132+
const common_chat_templates_ptr & chat_templates,
133+
const common_params & params)
134+
: chat_msgs_(chat_msgs),
135+
chat_templates_(chat_templates),
136+
params_(params) {}
137+
138+
std::string operator()(const std::string & role, const std::string & content) {
139+
common_chat_msg new_msg;
140+
new_msg.role = role;
141+
new_msg.content = content;
142+
chat_msgs_.push_back(new_msg);
143+
144+
common_chat_templates_inputs cinputs;
145+
cinputs.use_jinja = params_.use_jinja;
146+
cinputs.messages = chat_msgs_;
147+
cinputs.add_generation_prompt = (role == "user");
148+
cinputs.reasoning_format = params_.reasoning_format;
149+
150+
cinputs.enable_thinking =
151+
params_.use_jinja && params_.reasoning_budget != 0 &&
152+
common_chat_templates_support_enable_thinking(chat_templates_.get());
153+
154+
common_chat_params cparams = common_chat_templates_apply(chat_templates_.get(), cinputs);
155+
156+
if (!partial_formatter_ptr_ && params_.reasoning_format != COMMON_REASONING_FORMAT_NONE) {
157+
common_chat_syntax chat_syntax;
158+
chat_syntax.format = cparams.format;
159+
chat_syntax.reasoning_format = params_.reasoning_format;
160+
chat_syntax.thinking_forced_open = cparams.thinking_forced_open;
161+
chat_syntax.parse_tool_calls = false;
162+
partial_formatter_ptr_ = std::make_unique<partial_formatter>(chat_syntax);
163+
}
164+
165+
std::string formatted = cparams.prompt.substr(formatted_cumulative_.size());
166+
formatted_cumulative_ = cparams.prompt;
167+
168+
LOG_DBG("formatted: '%s'\n", formatted.c_str());
169+
return formatted;
170+
}
171+
172+
partial_formatter * get_partial_formatter() { return partial_formatter_ptr_.get(); }
173+
174+
private:
175+
std::vector<common_chat_msg> & chat_msgs_;
176+
const common_chat_templates_ptr & chat_templates_;
177+
const common_params & params_;
178+
std::unique_ptr<partial_formatter> partial_formatter_ptr_;
179+
std::string formatted_cumulative_;
180+
};
181+
86182
int main(int argc, char ** argv) {
87183
common_params params;
88184
g_params = &params;
@@ -265,15 +361,7 @@ int main(int argc, char ** argv) {
265361
std::vector<llama_token> embd_inp;
266362

267363
bool waiting_for_first_input = false;
268-
auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, const std::string & content) {
269-
common_chat_msg new_msg;
270-
new_msg.role = role;
271-
new_msg.content = content;
272-
auto formatted = common_chat_format_single(chat_templates.get(), chat_msgs, new_msg, role == "user", g_params->use_jinja);
273-
chat_msgs.push_back(new_msg);
274-
LOG_DBG("formatted: '%s'\n", formatted.c_str());
275-
return formatted;
276-
};
364+
chat_formatter chat_add_and_format(chat_msgs, chat_templates, params);
277365

278366
std::string prompt;
279367
{
@@ -709,6 +797,13 @@ int main(int argc, char ** argv) {
709797

710798
if (params.conversation_mode && !waiting_for_first_input && !llama_vocab_is_eog(vocab, id)) {
711799
assistant_ss << common_token_to_piece(ctx, id, false);
800+
801+
if (auto * formatter = chat_add_and_format.get_partial_formatter()) {
802+
auto outputs = (*formatter)(assistant_ss.str());
803+
for (const auto & out : outputs) {
804+
LOG("%s", out.formatted.c_str());
805+
}
806+
}
712807
}
713808

714809
// echo this to console
@@ -740,8 +835,9 @@ int main(int argc, char ** argv) {
740835
for (auto id : embd) {
741836
const std::string token_str = common_token_to_piece(ctx, id, params.special);
742837

743-
// Console/Stream Output
744-
LOG("%s", token_str.c_str());
838+
if (!chat_add_and_format.get_partial_formatter() || assistant_ss.str().empty()) {
839+
LOG("%s", token_str.c_str());
840+
}
745841

746842
// Record Displayed Tokens To Log
747843
// Note: Generated tokens are created one by one hence this check

0 commit comments

Comments
 (0)