common/arg.cpp (9 additions, 0 deletions):

```diff
@@ -949,6 +949,15 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.conversation_mode = COMMON_CONVERSATION_MODE_DISABLED;
         }
     ).set_examples({LLAMA_EXAMPLE_MAIN}));
+    add_opt(common_arg(
+        {"-st", "--single-turn"},
+        "run conversation for a single turn only, then exit when done\n"
+        "will not be interactive if first turn is predefined with --prompt\n"
+        "(default: false)",
+        [](common_params & params) {
+            params.single_turn = true;
+        }
+    ).set_examples({LLAMA_EXAMPLE_MAIN}));
     add_opt(common_arg(
         {"-i", "--interactive"},
         string_format("run in interactive mode (default: %s)", params.interactive ? "true" : "false"),
```
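For readers skimming the diff: `common_arg` binds both flag aliases to a callback that flips a field on `common_params`, and `set_examples({LLAMA_EXAMPLE_MAIN})` scopes the flag to the `main` example. Below is a minimal, self-contained sketch of that dispatch pattern; `params_t` and the option map are illustrative stand-ins, not the llama.cpp types.

```cpp
// Minimal sketch of the flag-to-callback dispatch used above; params_t and
// the opts map are illustrative stand-ins, not the llama.cpp types.
#include <cstdio>
#include <functional>
#include <map>
#include <string>

struct params_t {
    bool single_turn = false;
};

int main(int argc, char ** argv) {
    params_t params;

    const std::function<void(params_t &)> set_single_turn = [](params_t & p) {
        p.single_turn = true;
    };
    const std::map<std::string, std::function<void(params_t &)>> opts = {
        {"-st",           set_single_turn},
        {"--single-turn", set_single_turn},
    };

    for (int i = 1; i < argc; ++i) {
        const auto it = opts.find(argv[i]);
        if (it != opts.end()) {
            it->second(params); // run the handler, mutating params in place
        }
    }

    std::printf("single_turn = %s\n", params.single_turn ? "true" : "false");
    return 0;
}
```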
common/common.h (2 additions, 0 deletions):

```diff
@@ -326,6 +326,8 @@ struct common_params {
     bool warmup = true; // warmup run
     bool check_tensors = false; // validate tensor data
 
+    bool single_turn = false; // single turn chat conversation
+
     ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
     ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
 
```
examples/main/main.cpp (38 additions, 9 deletions):

```diff
@@ -217,8 +217,8 @@ int main(int argc, char ** argv) {
     // print chat template example in conversation mode
     if (params.conversation_mode) {
         if (params.enable_chat_template) {
-            if (!params.prompt.empty()) {
-                LOG_WRN("*** User-specified prompt in conversation mode will be ignored, did you mean to set --system-prompt (-sys) instead?\n");
+            if (!params.prompt.empty() && params.system_prompt.empty()) {
+                LOG_WRN("*** User-specified prompt will pre-start conversation, did you mean to set --system-prompt (-sys) instead?\n");
             }
 
             LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(chat_templates.get(), params.use_jinja).c_str());
```
```diff
@@ -265,7 +265,7 @@ int main(int argc, char ** argv) {
 
     std::vector<llama_token> embd_inp;
 
-    bool waiting_for_first_input = params.conversation_mode && params.enable_chat_template && params.system_prompt.empty();
+    bool waiting_for_first_input = false;
     auto chat_add_and_format = [&chat_msgs, &chat_templates](const std::string & role, const std::string & content) {
         common_chat_msg new_msg;
         new_msg.role = role;
```
```diff
@@ -276,22 +276,36 @@ int main(int argc, char ** argv) {
         return formatted;
     };
 
+    std::string prompt;
     {
-        std::string prompt;
-
         if (params.conversation_mode && params.enable_chat_template) {
-            // format the system prompt in conversation mode (will use template default if empty)
             prompt = params.system_prompt;
 
             if (!prompt.empty()) {
-                prompt = chat_add_and_format("system", prompt);
+                // format the system prompt (will use template default if empty)
+                chat_add_and_format("system", prompt);
             }
+
+            if (!params.prompt.empty()) {
+                // format and append the user prompt
+                chat_add_and_format("user", params.prompt);
+            } else {
+                waiting_for_first_input = true;
+            }
+
+            if (!prompt.empty() || !params.prompt.empty()) {
+                common_chat_templates_inputs inputs;
+                inputs.messages = chat_msgs;
+                inputs.add_generation_prompt = !params.prompt.empty();
+
+                prompt = common_chat_templates_apply(chat_templates.get(), inputs).prompt;
+            }
         } else {
             // otherwise use the prompt as is
             prompt = params.prompt;
         }
 
-        if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
+        if (params.interactive_first || !prompt.empty() || session_tokens.empty()) {
             LOG_DBG("tokenize the prompt\n");
             embd_inp = common_tokenize(ctx, prompt, true, true);
         } else {
```
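To make the new assembly step easier to follow: when `-p` is given in conversation mode, both the system and user messages are formatted up front and a generation prompt is appended so the model answers immediately; with no `-p`, the loop waits for the first input instead. Here is a toy sketch of that control flow, using a hand-rolled ChatML-style renderer in place of `common_chat_templates_apply` (the real engine is template-driven; only the branching mirrors the diff):

```cpp
// Toy sketch of the message-assembly logic above. The hand-rolled ChatML-style
// renderer stands in for common_chat_templates_apply; only the control flow
// (when to add each message and the generation prompt) mirrors the diff.
#include <cstdio>
#include <string>
#include <vector>

struct msg {
    std::string role;
    std::string content;
};

static std::string render(const std::vector<msg> & msgs, bool add_generation_prompt) {
    std::string out;
    for (const auto & m : msgs) {
        out += "<|im_start|>" + m.role + "\n" + m.content + "<|im_end|>\n";
    }
    if (add_generation_prompt) {
        out += "<|im_start|>assistant\n"; // model should reply immediately
    }
    return out;
}

int main() {
    const std::string system_prompt = "You are helpful."; // from -sys
    const std::string user_prompt   = "Hello";            // from -p; empty -> wait for input

    std::vector<msg> msgs;
    if (!system_prompt.empty()) {
        msgs.push_back({"system", system_prompt});
    }
    if (!user_prompt.empty()) {
        msgs.push_back({"user", user_prompt});
    }

    // Mirrors inputs.add_generation_prompt = !params.prompt.empty():
    // only open an assistant turn when -p pre-starts the conversation.
    std::printf("%s", render(msgs, !user_prompt.empty()).c_str());
    return 0;
}
```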
```diff
@@ -304,7 +318,7 @@ int main(int argc, char ** argv) {
     }
 
     // Should not run without any tokens
-    if (!params.conversation_mode && embd_inp.empty()) {
+    if (!waiting_for_first_input && embd_inp.empty()) {
         if (add_bos) {
             embd_inp.push_back(llama_vocab_bos(vocab));
             LOG_WRN("embd_inp was considered empty and bos was added: %s\n", string_from(ctx, embd_inp).c_str());
```
```diff
@@ -372,6 +386,11 @@ int main(int argc, char ** argv) {
         params.interactive = true;
     }
 
+    if (params.single_turn && !params.prompt.empty()) {
+        params.interactive = false;
+        params.interactive_first = false;
+    }
+
     if (params.verbose_prompt) {
         LOG_INF("%s: prompt: '%s'\n", __func__, params.prompt.c_str());
         LOG_INF("%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
```
```diff
@@ -808,6 +827,11 @@ int main(int argc, char ** argv) {
             if (params.conversation_mode && !waiting_for_first_input) {
                 const auto id = common_sampler_last(smpl);
                 assistant_ss << common_token_to_piece(ctx, id, false);
+
+                if (!prompt.empty()) {
+                    prompt.clear();
+                    is_interacting = false;
+                }
             }
 
             if ((n_past > 0 || waiting_for_first_input) && is_interacting) {
```
```diff
@@ -905,6 +929,11 @@ int main(int argc, char ** argv) {
                     common_sampler_reset(smpl);
                 }
                 is_interacting = false;
+
+                if (waiting_for_first_input && params.single_turn) {
+                    params.interactive = false;
+                    params.interactive_first = false;
+                }
                 waiting_for_first_input = false;
             }
         }
```
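The last two hunks cover the no-`-p` case: the first user turn arrives, `--single-turn` switches interactivity off before the turn is processed, and the loop exits after one reply. A toy walk-through of those flags (an assumed simplification of the real generation loop, kept to the flags this PR touches):

```cpp
// Toy walk-through of the end-of-turn flags in the last two hunks; an assumed
// simplification of the real generation loop, kept to the flags this PR touches.
#include <cstdio>

int main() {
    const bool single_turn       = true;  // -st given
    bool interactive             = true;
    bool waiting_for_first_input = true;  // conversation mode, no -p given

    // ... the user types the first message ...

    // Mirrors the hunk at @@ -905,6 +929,11 @@: once the first input arrives,
    // -st turns off interactivity so the loop ends after one reply.
    if (waiting_for_first_input && single_turn) {
        interactive = false;
    }
    waiting_for_first_input = false;

    // ... the model streams its reply; with interactive now false,
    // the main loop falls through and the program exits after this turn.
    std::printf("interactive=%d waiting_for_first_input=%d\n",
                interactive, waiting_for_first_input);
    return 0;
}
```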