Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 148 additions & 17 deletions tools/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,103 @@ static void sigint_handler(int signo) {
}
#endif

// Incrementally parses an accumulating assistant response and yields only the
// newly produced display deltas, classified as reasoning ("thinking") text or
// regular content. Call operator() repeatedly with the full accumulated text;
// it diffs the newly parsed message against the previous parse and returns
// just what changed since the last call.
class partial_formatter {
  public:
    enum output_type {
        CONTENT,
        REASONING,
    };

    struct output {
        std::string formatted; // delta text to display
        output_type type;      // classification of the delta
    };

    // explicit: a common_chat_syntax should never silently convert to a formatter
    explicit partial_formatter(const common_chat_syntax & syntax) : syntax_(syntax) {}

    // Parse the full accumulated assistant text and return the deltas produced
    // since the previous call, in display order.
    std::vector<output> operator()(const std::string & accumulated) {
        // is_partial = true: the accumulated text may end mid-token or mid-tag
        common_chat_msg next = common_chat_parse(accumulated, true, syntax_);

        auto diffs = common_chat_msg_diff::compute_diffs(previous_, next);
        std::vector<output> result;
        for (const auto & diff : diffs) {
            if (!diff.reasoning_content_delta.empty()) {
                result.push_back({diff.reasoning_content_delta, REASONING});
                had_reasoning_ = true;
            }
            if (!diff.content_delta.empty()) {
                // visually separate the reasoning block from the answer once
                // the first piece of regular content arrives
                if (had_reasoning_) {
                    result.push_back({"\n", REASONING});
                    had_reasoning_ = false;
                }
                result.push_back({diff.content_delta, CONTENT});
            }
        }
        // next is no longer needed - move instead of copying its strings
        previous_ = std::move(next);
        return result;
    }

  private:
    common_chat_syntax syntax_;    // parsing rules (format, reasoning handling)
    common_chat_msg    previous_;  // last parsed state, baseline for the diff
    bool had_reasoning_ = false;   // true while a reasoning block is still open
};

// Applies the chat template to the running conversation and returns only the
// newly appended portion of the rendered prompt. Lazily constructs the
// partial_formatter used to stream-parse assistant output.
// NOTE: stores references to chat_msgs, chat_templates and params - the caller
// must keep them alive for the lifetime of this object.
class chat_formatter {
  public:
    chat_formatter(
        std::vector<common_chat_msg> & chat_msgs,
        const common_chat_templates_ptr & chat_templates,
        const common_params & params)
        : chat_msgs_(chat_msgs),
          chat_templates_(chat_templates),
          params_(params) {}

    // Append a message with the given role/content, re-render the template,
    // and return the suffix of the rendered prompt not returned before.
    std::string operator()(const std::string & role, const std::string & content) {
        common_chat_msg new_msg;
        new_msg.role = role;
        new_msg.content = content;
        // move: avoids copying the message strings into the history vector
        chat_msgs_.push_back(std::move(new_msg));

        common_chat_templates_inputs cinputs;
        cinputs.use_jinja = params_.use_jinja;
        cinputs.messages = chat_msgs_;
        // only prompt the model to generate after a user turn
        cinputs.add_generation_prompt = (role == "user");
        cinputs.reasoning_format = params_.reasoning_format;

        cinputs.enable_thinking =
            params_.use_jinja && params_.reasoning_budget != 0 &&
            common_chat_templates_support_enable_thinking(chat_templates_.get());

        common_chat_params cparams = common_chat_templates_apply(chat_templates_.get(), cinputs);

        // create the streaming parser once, as soon as the template format is known
        if (!partial_formatter_ptr_ && params_.reasoning_format != COMMON_REASONING_FORMAT_NONE) {
            common_chat_syntax chat_syntax;
            chat_syntax.format = cparams.format;
            chat_syntax.reasoning_format = params_.reasoning_format;
            chat_syntax.thinking_forced_open = cparams.thinking_forced_open;
            chat_syntax.parse_tool_calls = false;
            partial_formatter_ptr_ = std::make_unique<partial_formatter>(chat_syntax);
        }

        // The template is expected to render previous turns identically, so the
        // new text is the suffix past what was already emitted. Guard against a
        // template that rewrites history: substr would throw std::out_of_range
        // on a shorter prompt, or return a garbage suffix on a diverging one.
        std::string formatted;
        if (cparams.prompt.size() >= formatted_cumulative_.size() &&
            cparams.prompt.compare(0, formatted_cumulative_.size(), formatted_cumulative_) == 0) {
            formatted = cparams.prompt.substr(formatted_cumulative_.size());
        } else {
            // non-incremental rendering: fall back to emitting the full prompt
            formatted = cparams.prompt;
        }
        formatted_cumulative_ = std::move(cparams.prompt);

        LOG_DBG("formatted: '%s'\n", formatted.c_str());
        return formatted;
    }

    // nullptr when streaming reasoning parsing is disabled or not yet set up
    partial_formatter * get_partial_formatter() { return partial_formatter_ptr_.get(); }
    // full prompt rendered so far (cumulative across all calls)
    const std::string & get_full_prompt() const { return formatted_cumulative_; }

  private:
    std::vector<common_chat_msg> & chat_msgs_;          // shared conversation history (not owned)
    const common_chat_templates_ptr & chat_templates_;  // loaded templates (not owned)
    const common_params & params_;                      // global CLI parameters (not owned)
    std::unique_ptr<partial_formatter> partial_formatter_ptr_;
    std::string formatted_cumulative_;                  // everything rendered so far
};

int main(int argc, char ** argv) {
common_params params;
g_params = &params;
Expand Down Expand Up @@ -265,17 +362,11 @@ int main(int argc, char ** argv) {
std::vector<llama_token> embd_inp;

bool waiting_for_first_input = false;
auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, const std::string & content) {
common_chat_msg new_msg;
new_msg.role = role;
new_msg.content = content;
auto formatted = common_chat_format_single(chat_templates.get(), chat_msgs, new_msg, role == "user", g_params->use_jinja);
chat_msgs.push_back(new_msg);
LOG_DBG("formatted: '%s'\n", formatted.c_str());
return formatted;
};
chat_formatter chat_add_and_format(chat_msgs, chat_templates, params);

std::string prompt;
std::string system_remaining;
std::string prompt_remaining;
{
if (params.conversation_mode && params.enable_chat_template) {
if (!params.system_prompt.empty()) {
Expand All @@ -291,13 +382,9 @@ int main(int argc, char ** argv) {
}

if (!params.system_prompt.empty() || !params.prompt.empty()) {
common_chat_templates_inputs inputs;
inputs.use_jinja = g_params->use_jinja;
inputs.messages = chat_msgs;
inputs.add_generation_prompt = !params.prompt.empty();

prompt = common_chat_templates_apply(chat_templates.get(), inputs).prompt;
prompt = chat_add_and_format.get_full_prompt();
}

} else {
// otherwise use the prompt as is
prompt = params.prompt;
Expand All @@ -315,6 +402,19 @@ int main(int argc, char ** argv) {
LOG_DBG("tokens: %s\n", string_from(ctx, embd_inp).c_str());
}

// Set up content tracking to skip template markup during display
bool skip_template_markup = false;
if (params.conversation_mode && params.enable_chat_template) {
for (const auto & msg : chat_msgs) {
if (msg.role == "system") {
system_remaining = msg.content;
} else if (msg.role == "user") {
prompt_remaining = msg.content;
}
}
skip_template_markup = !system_remaining.empty() || !prompt_remaining.empty();
}

// Should not run without any tokens
if (!waiting_for_first_input && embd_inp.empty()) {
if (add_bos) {
Expand Down Expand Up @@ -709,6 +809,13 @@ int main(int argc, char ** argv) {

if (params.conversation_mode && !waiting_for_first_input && !llama_vocab_is_eog(vocab, id)) {
assistant_ss << common_token_to_piece(ctx, id, false);

if (auto * formatter = chat_add_and_format.get_partial_formatter()) {
auto outputs = (*formatter)(assistant_ss.str());
for (const auto & out : outputs) {
LOG("%s", out.formatted.c_str());
}
}
}

// echo this to console
Expand Down Expand Up @@ -740,8 +847,31 @@ int main(int argc, char ** argv) {
for (auto id : embd) {
const std::string token_str = common_token_to_piece(ctx, id, params.special);

// Console/Stream Output
LOG("%s", token_str.c_str());
if (!chat_add_and_format.get_partial_formatter() || assistant_ss.str().empty()) {
if (skip_template_markup) {
if (!token_str.empty() && !system_remaining.empty() &&
system_remaining.compare(0, token_str.length(), token_str) == 0) {

system_remaining.erase(0, token_str.length());
LOG("%s", token_str.c_str());
if (system_remaining.empty()) {
LOG("\n");
}

} else if (!token_str.empty() && !prompt_remaining.empty() &&
prompt_remaining.compare(0, token_str.length(), token_str) == 0) {

prompt_remaining.erase(0, token_str.length());
LOG("%s", token_str.c_str());
if (prompt_remaining.empty()) {
LOG("\n");
}
}

} else {
LOG("%s", token_str.c_str());
}
}

// Record Displayed Tokens To Log
// Note: Generated tokens are created one by one hence this check
Expand All @@ -760,6 +890,7 @@ int main(int argc, char ** argv) {
if (input_echo && (int) embd_inp.size() == n_consumed) {
console::set_display(console::reset);
display = true;
skip_template_markup = false; // system & prompt processing complete
}

// if not currently processing queued inputs;
Expand Down
Loading