Commit 9d7d5f2

cli : auto activate conversation mode if chat template is detected
1 parent: 1244cdc

File tree: 3 files changed, +47 -19 lines

- common/arg.cpp
- common/common.h
- examples/main/main.cpp

common/arg.cpp

Lines changed: 12 additions & 8 deletions

@@ -768,15 +768,19 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER}));
     add_opt(common_arg(
         {"-cnv", "--conversation"},
-        string_format(
-            "run in conversation mode:\n"
-            "- does not print special tokens and suffix/prefix\n"
-            "- interactive mode is also enabled\n"
-            "(default: %s)",
-            params.conversation ? "true" : "false"
-        ),
+        "run in conversation mode:\n"
+        "- does not print special tokens and suffix/prefix\n"
+        "- interactive mode is also enabled\n"
+        "(default: auto enabled if chat template is available)",
+        [](common_params & params) {
+            params.conversation_mode = COMMON_CONVERSATION_MODE_ENABLED;
+        }
+    ).set_examples({LLAMA_EXAMPLE_MAIN}));
+    add_opt(common_arg(
+        {"-no-cnv", "--no-conversation"},
+        "force disable conversation mode (default: false)",
         [](common_params & params) {
-            params.conversation = true;
+            params.conversation_mode = COMMON_CONVERSATION_MODE_DISABLED;
         }
     ).set_examples({LLAMA_EXAMPLE_MAIN}));
     add_opt(common_arg(
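
Both options now write the same conversation_mode field, so the parser needs no mutual-exclusion logic: whichever flag is parsed last wins, and omitting both leaves the COMMON_CONVERSATION_MODE_AUTO default from common/common.h in place. A standalone sketch of that outcome (the hand-rolled loop below is illustrative only; the real dispatch goes through the common_arg handlers):

// Minimal sketch, not part of the commit: last-flag-wins semantics of
// -cnv / -no-cnv over the shared conversation_mode field.
#include <cstdio>
#include <cstring>

enum common_conversation_mode {
    COMMON_CONVERSATION_MODE_DISABLED = 0,
    COMMON_CONVERSATION_MODE_ENABLED  = 1,
    COMMON_CONVERSATION_MODE_AUTO     = 2,
};

int main(int argc, char ** argv) {
    common_conversation_mode mode = COMMON_CONVERSATION_MODE_AUTO; // default
    for (int i = 1; i < argc; i++) {
        if (!strcmp(argv[i], "-cnv") || !strcmp(argv[i], "--conversation")) {
            mode = COMMON_CONVERSATION_MODE_ENABLED;  // explicit opt-in
        } else if (!strcmp(argv[i], "-no-cnv") || !strcmp(argv[i], "--no-conversation")) {
            mode = COMMON_CONVERSATION_MODE_DISABLED; // explicit opt-out
        }
    }
    printf("conversation_mode = %d\n", (int) mode); // e.g. "-cnv -no-cnv" prints 0
}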

common/common.h

Lines changed: 8 additions & 1 deletion

@@ -103,6 +103,12 @@ enum dimre_method {
     DIMRE_METHOD_MEAN,
 };
 
+enum common_conversation_mode {
+    COMMON_CONVERSATION_MODE_DISABLED = 0,
+    COMMON_CONVERSATION_MODE_ENABLED  = 1,
+    COMMON_CONVERSATION_MODE_AUTO     = 2,
+};
+
 // sampling parameters
 struct common_params_sampling {
     uint32_t seed = LLAMA_DEFAULT_SEED; // the seed used to initialize llama_sampler

@@ -275,7 +281,6 @@ struct common_params {
     bool special           = false; // enable special token output
     bool interactive       = false; // interactive mode
     bool interactive_first = false; // wait for user input immediately
-    bool conversation      = false; // conversation mode (does not print special tokens and suffix/prefix)
     bool prompt_cache_all  = false; // save user input and generations to prompt cache
     bool prompt_cache_ro   = false; // open the prompt cache read-only and do not update it
 
@@ -301,6 +306,8 @@ struct common_params {
     ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K
     ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V
 
+    common_conversation_mode conversation_mode = COMMON_CONVERSATION_MODE_AUTO;
+
     // multimodal models (see examples/llava)
     std::string mmproj = ""; // path to multimodal projector // NOLINT
     std::vector<std::string> image; // path to image file(s)
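
Pinning COMMON_CONVERSATION_MODE_DISABLED to 0 is what lets the boolean-style checks in examples/main/main.cpp switch from params.conversation to params.conversation_mode without rewriting every condition: an unscoped enum converts contextually to bool, and only the disabled state is zero. A small illustrative check (not part of the commit):

// Illustrative only: DISABLED (0) is falsy, ENABLED (1) and AUTO (2) are
// truthy, so `if (params.conversation_mode)` behaves like the old
// `if (params.conversation)`.
#include <cassert>

enum common_conversation_mode { // mirrors the enum added to common/common.h above
    COMMON_CONVERSATION_MODE_DISABLED = 0,
    COMMON_CONVERSATION_MODE_ENABLED  = 1,
    COMMON_CONVERSATION_MODE_AUTO     = 2,
};

int main() {
    assert(!COMMON_CONVERSATION_MODE_DISABLED); // reads as false
    assert( COMMON_CONVERSATION_MODE_ENABLED);  // reads as true
    assert( COMMON_CONVERSATION_MODE_AUTO);     // nonzero, so also true
    return 0;
}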

examples/main/main.cpp

Lines changed: 27 additions & 10 deletions

@@ -30,6 +30,8 @@
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
+static const char * DEFAULT_SYSTEM_MESSAGE = "You are a helpful assistant";
+
 static llama_context ** g_ctx;
 static llama_model ** g_model;
 static common_sampler ** g_smpl;
@@ -204,8 +206,17 @@ int main(int argc, char ** argv) {
         LOG_WRN("%s: model was trained on only %d context tokens (%d specified)\n", __func__, n_ctx_train, n_ctx);
     }
 
+    // auto enable conversation mode if chat template is available
+    if (
+        params.conversation_mode == COMMON_CONVERSATION_MODE_AUTO
+        && (!common_get_builtin_chat_template(model).empty() || !params.chat_template.empty())
+    ) {
+        LOG_INF("%s: chat template is available, enabling conversation mode (disable it with -no-cnv)\n", __func__);
+        params.conversation_mode = COMMON_CONVERSATION_MODE_ENABLED;
+    }
+
     // print chat template example in conversation mode
-    if (params.conversation) {
+    if (params.conversation_mode) {
         if (params.enable_chat_template) {
             LOG_INF("%s: chat template example:\n%s\n", __func__, common_chat_format_example(model, params.chat_template).c_str());
         } else {
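
This hunk is the core of the change: AUTO resolves to ENABLED only when a template is found, either baked into the model (common_get_builtin_chat_template) or supplied via --chat-template, while explicit -cnv / -no-cnv choices pass through untouched. As a pure function the rule looks roughly like this (the helper name resolve_conversation_mode is mine, not the commit's):

// Sketch only (not from the commit): the auto-enable rule above, isolated
// from main() so it can be tested on its own.
#include <cassert>

enum common_conversation_mode { // mirrors common/common.h
    COMMON_CONVERSATION_MODE_DISABLED = 0,
    COMMON_CONVERSATION_MODE_ENABLED  = 1,
    COMMON_CONVERSATION_MODE_AUTO     = 2,
};

// has_builtin_template stands for !common_get_builtin_chat_template(model).empty(),
// has_user_template for !params.chat_template.empty().
static common_conversation_mode resolve_conversation_mode(
        common_conversation_mode requested,
        bool has_builtin_template,
        bool has_user_template) {
    if (requested == COMMON_CONVERSATION_MODE_AUTO
            && (has_builtin_template || has_user_template)) {
        return COMMON_CONVERSATION_MODE_ENABLED;
    }
    return requested; // explicit -cnv / -no-cnv choices pass through untouched
}

int main() {
    // AUTO + any template source -> conversation mode on
    assert(resolve_conversation_mode(COMMON_CONVERSATION_MODE_AUTO, true,  false) == COMMON_CONVERSATION_MODE_ENABLED);
    assert(resolve_conversation_mode(COMMON_CONVERSATION_MODE_AUTO, false, true)  == COMMON_CONVERSATION_MODE_ENABLED);
    // -no-cnv wins even when a template exists
    assert(resolve_conversation_mode(COMMON_CONVERSATION_MODE_DISABLED, true, true) == COMMON_CONVERSATION_MODE_DISABLED);
    return 0;
}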
@@ -252,8 +263,10 @@ int main(int argc, char ** argv) {
     std::vector<llama_token> embd_inp;
 
     {
-        auto prompt = (params.conversation && params.enable_chat_template && !params.prompt.empty())
-            ? chat_add_and_format(model, chat_msgs, "system", params.prompt) // format the system prompt in conversation mode
+        auto prompt = (params.conversation_mode && params.enable_chat_template)
+            // format the system prompt in conversation mode (fallback to default if empty)
+            ? chat_add_and_format(model, chat_msgs, "system", params.prompt.empty() ? DEFAULT_SYSTEM_MESSAGE : params.prompt)
+            // otherwise use the prompt as is
             : params.prompt;
         if (params.interactive_first || !params.prompt.empty() || session_tokens.empty()) {
             LOG_DBG("tokenize the prompt\n");
@@ -327,7 +340,7 @@ int main(int argc, char ** argv) {
         params.n_keep += add_bos; // always keep the BOS token
     }
 
-    if (params.conversation) {
+    if (params.conversation_mode) {
         params.interactive_first = true;
     }
@@ -451,7 +464,11 @@ int main(int argc, char ** argv) {
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
         LOG_INF( " - Press Ctrl+C to interject at any time.\n");
 #endif
-        LOG_INF( "%s\n", control_message);
+        LOG_INF( "%s", control_message);
+        if (params.conversation_mode && params.enable_chat_template && params.prompt.empty()) {
+            LOG_INF( " - Using default system message. To change it, set a different value via -p PROMPT or -f FILE argument.\n");
+        }
+        LOG_INF("\n");
 
         is_interacting = params.interactive_first;
     }
@@ -763,15 +780,15 @@ int main(int argc, char ** argv) {
         }
 
         // if current token is not EOG, we add it to current assistant message
-        if (params.conversation) {
+        if (params.conversation_mode) {
            const auto id = common_sampler_last(smpl);
            assistant_ss << common_token_to_piece(ctx, id, false);
        }
 
        if (n_past > 0 && is_interacting) {
            LOG_DBG("waiting for user input\n");
 
-           if (params.conversation) {
+           if (params.conversation_mode) {
               LOG("\n> ");
           }
 
@@ -781,7 +798,7 @@ int main(int argc, char ** argv) {
           }
 
           std::string buffer;
-          if (!params.input_prefix.empty() && !params.conversation) {
+          if (!params.input_prefix.empty() && !params.conversation_mode) {
              LOG_DBG("appending input prefix: '%s'\n", params.input_prefix.c_str());
              LOG("%s", params.input_prefix.c_str());
          }
@@ -805,7 +822,7 @@ int main(int argc, char ** argv) {
           // Entering a empty line lets the user pass control back
           if (buffer.length() > 1) {
              // append input suffix if any
-             if (!params.input_suffix.empty() && !params.conversation) {
+             if (!params.input_suffix.empty() && !params.conversation_mode) {
                 LOG_DBG("appending input suffix: '%s'\n", params.input_suffix.c_str());
                 LOG("%s", params.input_suffix.c_str());
             }
@@ -818,7 +835,7 @@ int main(int argc, char ** argv) {
              string_process_escapes(buffer);
           }
 
-          bool format_chat = params.conversation && params.enable_chat_template;
+          bool format_chat = params.conversation_mode && params.enable_chat_template;
           std::string user_inp = format_chat
              ? chat_add_and_format(model, chat_msgs, "user", std::move(buffer))
              : std::move(buffer);
