
Commit 4e6b584

support --start-string
1 parent 44cd8d9

File tree

3 files changed: +51 -1 lines changed

  common/arg.cpp
  common/common.h
  examples/server/server.cpp

common/arg.cpp

Lines changed: 8 additions & 0 deletions
@@ -2837,6 +2837,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             else { std::invalid_argument("invalid value"); }
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN}).set_env("LLAMA_ARG_THINK"));
+    add_opt(common_arg(
+        {"--start-string"}, "STRING",
+        "Start outputting tokens only when the start string has been reached",
+        [](common_params & params, const std::string & value) {
+            params.start_strings.resize(1);
+            params.start_strings[0] = value;
+        }
+    ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_START_STRING"));
     add_opt(common_arg(
         {"--chat-template"}, "JINJA_TEMPLATE",
         string_format(
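
For reference, the flag takes a single string and can also be set through the environment variable registered above. A hypothetical invocation (binary name and model path are placeholders, not part of the patch):

    ./llama-server -m model.gguf --start-string "<answer>"
    LLAMA_ARG_START_STRING="<answer>" ./llama-server -m model.gguf

Note that the parser stores the value as a one-element vector, so the underlying start_strings field can carry several strings even though the flag itself accepts only one.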

common/common.h

Lines changed: 1 addition & 0 deletions
@@ -366,6 +366,7 @@ struct common_params {
     bool use_jinja = false; // NOLINT
     bool enable_chat_template = true;
     common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK;
+    std::vector<std::string> start_strings;

     std::vector<std::string> api_keys;
examples/server/server.cpp

Lines changed: 42 additions & 1 deletion
@@ -104,6 +104,7 @@ struct slot_params {
     std::vector<common_adapter_lora_info> lora;

     std::vector<std::string> antiprompt;
+    std::vector<std::string> start_strings;
     std::vector<std::string> response_fields;
     bool timings_per_token = false;
     bool post_sampling_probs = false;
@@ -161,6 +162,7 @@ struct slot_params {
             {"mirostat",      sampling.mirostat},
             {"mirostat_tau",  sampling.mirostat_tau},
             {"mirostat_eta",  sampling.mirostat_eta},
+            {"start",         start_strings},
             {"stop",          antiprompt},
             {"max_tokens",    n_predict}, // User configured n_predict
             {"n_keep",        n_keep},
@@ -229,6 +231,7 @@ struct server_task {
        slot_params defaults;
        defaults.sampling    = params_base.sampling;
        defaults.speculative = params_base.speculative;
+       defaults.start_strings = params_base.start_strings;

        // enabling this will output extra debug information in the HTTP responses from the server
        params.verbose = params_base.verbosity > 9;
@@ -244,6 +247,7 @@ struct server_task {
        //params.t_max_prompt_ms  = json_value(data, "t_max_prompt_ms",  defaults.t_max_prompt_ms); // TODO: implement
        params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
        params.response_fields  = json_value(data, "response_fields", std::vector<std::string>());
+       params.start_strings    = json_value(data, "start_strings", defaults.start_strings);

        params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
        params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
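
Taken together, these hunks thread the new field from the CLI default into the per-request slot parameters, so a client can also override it per request. A minimal sketch of such a request (endpoint, port, and prompt are illustrative):

    curl http://localhost:8080/completion -d '{
        "prompt": "Think step by step, then answer.",
        "start_strings": ["<answer>"]
    }'

The "start_strings" key matches the json_value lookup above; when it is absent, the server-wide default set by --start-string applies.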
@@ -1998,6 +2002,7 @@ struct server_context {
            SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);

            slot.params.sampling = params_base.sampling;
+           slot.params.start_strings = params_base.start_strings;

            slot.callback_on_release = [this](int) {
                queue_tasks.pop_deferred_task();
@@ -2192,6 +2197,42 @@ struct server_context {
        const std::string str_test = slot.generated_text.substr(pos);
        bool send_text = true;

+       if(slot.n_sent_text == 0 && slot.has_next_token && !slot.params.start_strings.empty())
+       {
+           size_t max_start_string_size = 0;
+           for(auto start_string: slot.params.start_strings)
+           {
+               max_start_string_size = std::max(max_start_string_size, start_string.size());
+           }
+           size_t search_len = max_start_string_size + token_str.size();
+           size_t search_pos = 0;
+           if(slot.generated_text.size() > search_len)
+           {
+               search_pos = slot.generated_text.size() - search_len;
+           }
+
+           auto found_pos = slot.generated_text.npos;
+           bool found = false;
+           std::string found_string;
+           for(auto start_string: slot.params.start_strings)
+           {
+               found_pos = slot.generated_text.find(start_string, search_pos);
+               if(found_pos != slot.generated_text.npos) {
+                   found = true;
+                   found_string = start_string;
+                   break;
+               }
+           }
+
+           if(found && slot.generated_text.size() > (found_pos + found_string.size())) {
+               slot.generated_text.erase(
+                   slot.generated_text.begin(),
+                   slot.generated_text.begin() + found_pos + found_string.size());
+           } else {
+               send_text = false;
+           }
+       }
+
        size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), true);
        if (stop_pos != std::string::npos) {
            slot.generated_text.erase(
@@ -2200,7 +2241,7 @@ struct server_context {
            pos = std::min(slot.n_sent_text, slot.generated_text.size());
        } else if (slot.has_next_token) {
            stop_pos = slot.find_stopping_strings(str_test, token_str.size(), false);
-           send_text = stop_pos == std::string::npos;
+           send_text = send_text && stop_pos == std::string::npos;
        }

        // check if there is any token to predict
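
The added block implements a gate: while nothing has been sent yet (slot.n_sent_text == 0), generated text is buffered and withheld; once one of the start strings appears, the match and everything before it are erased so that only the text after it is streamed. To limit work per token, only the tail of the buffer (longest start string plus the new token) is searched. A self-contained sketch of the same technique, outside the server's slot machinery (all names here are illustrative, not part of the patch):

    #include <algorithm>
    #include <iostream>
    #include <string>
    #include <vector>

    // Sketch of the start-string gate: buffer tokens, emit nothing until one
    // of the start strings appears, then drop the match and what precedes it.
    struct start_string_gate {
        std::vector<std::string> start_strings;
        std::string buffer;     // generated text held back so far
        bool triggered = false; // set once a start string has been seen

        // Feed one decoded token; returns the text that may be sent onward.
        std::string feed(const std::string & token_str) {
            if (triggered || start_strings.empty()) {
                return token_str; // gate already open (or unused): pass through
            }
            buffer += token_str;

            // As in the patch, only the tail of the buffer can contain a new
            // match: longest start string plus the newly appended token.
            size_t max_len = 0;
            for (const auto & s : start_strings) {
                max_len = std::max(max_len, s.size());
            }
            const size_t search_len = max_len + token_str.size();
            const size_t search_pos =
                buffer.size() > search_len ? buffer.size() - search_len : 0;

            for (const auto & s : start_strings) {
                const size_t found = buffer.find(s, search_pos);
                if (found != std::string::npos) {
                    triggered = true;
                    // Keep only what follows the match.
                    std::string out = buffer.substr(found + s.size());
                    buffer.clear();
                    return out;
                }
            }
            return ""; // no start string yet: hold everything back
        }
    };

    int main() {
        start_string_gate gate;
        gate.start_strings = {"<answer>"};

        const std::vector<std::string> tokens = {"thinking... ", "<ans", "wer>", "42"};
        for (const auto & tok : tokens) {
            std::cout << gate.feed(tok); // prints only "42"
        }
        std::cout << "\n";
    }

One behavioral difference worth noting: the committed code additionally waits until at least one character follows the match (the size() > found_pos + found_string.size() check) before releasing text, whereas this sketch opens the gate as soon as the match completes and may return an empty first chunk.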
