@@ -104,6 +104,7 @@ struct slot_params {
     std::vector<common_adapter_lora_info> lora;

     std::vector<std::string> antiprompt;
+    std::vector<std::string> start_strings;
     std::vector<std::string> response_fields;
     bool timings_per_token   = false;
     bool post_sampling_probs = false;
@@ -161,6 +162,7 @@ struct slot_params {
161162 {" mirostat" , sampling.mirostat },
162163 {" mirostat_tau" , sampling.mirostat_tau },
163164 {" mirostat_eta" , sampling.mirostat_eta },
165+ {" start" , start_strings},
164166 {" stop" , antiprompt},
165167 {" max_tokens" , n_predict}, // User configured n_predict
166168 {" n_keep" , n_keep},
@@ -229,6 +231,7 @@ struct server_task {
         slot_params defaults;
         defaults.sampling      = params_base.sampling;
         defaults.speculative   = params_base.speculative;
+        defaults.start_strings = params_base.start_strings;

         // enabling this will output extra debug information in the HTTP responses from the server
         params.verbose = params_base.verbosity > 9;
@@ -244,6 +247,7 @@ struct server_task {
         // params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
         params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
         params.response_fields  = json_value(data, "response_fields",  std::vector<std::string>());
+        params.start_strings    = json_value(data, "start_strings",    defaults.start_strings);

         params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
         params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
@@ -1998,6 +2002,7 @@ struct server_context {
         SLT_INF(slot, "new slot n_ctx_slot = %d\n", slot.n_ctx);

         slot.params.sampling      = params_base.sampling;
+        slot.params.start_strings = params_base.start_strings;

         slot.callback_on_release = [this](int) {
             queue_tasks.pop_deferred_task();
@@ -2192,6 +2197,42 @@ struct server_context {
             const std::string str_test = slot.generated_text.substr(pos);
             bool send_text = true;

+            if (slot.n_sent_text == 0 && slot.has_next_token && !slot.params.start_strings.empty()) {
+                // nothing has been sent yet: suppress output until one of the
+                // start strings appears in the generated text
+                size_t max_start_string_size = 0;
+                for (const auto & start_string : slot.params.start_strings) {
+                    max_start_string_size = std::max(max_start_string_size, start_string.size());
+                }
+                // a match completed by the newest token must begin within this tail window
+                size_t search_len = max_start_string_size + token_str.size();
+                size_t search_pos = 0;
+                if (slot.generated_text.size() > search_len) {
+                    search_pos = slot.generated_text.size() - search_len;
+                }
+
+                size_t found_pos = std::string::npos;
+                bool found = false;
+                std::string found_string;
+                // first match in list order wins
+                for (const auto & start_string : slot.params.start_strings) {
+                    found_pos = slot.generated_text.find(start_string, search_pos);
+                    if (found_pos != std::string::npos) {
+                        found = true;
+                        found_string = start_string;
+                        break;
+                    }
+                }
+
+                if (found && slot.generated_text.size() > found_pos + found_string.size()) {
+                    // drop everything up to and including the start string
+                    slot.generated_text.erase(
+                        slot.generated_text.begin(),
+                        slot.generated_text.begin() + found_pos + found_string.size());
+                } else {
+                    // no start string yet (or nothing after it) - keep buffering
+                    send_text = false;
+                }
+            }
+
             size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), true);
             if (stop_pos != std::string::npos) {
                 slot.generated_text.erase(
@@ -2200,7 +2241,7 @@ struct server_context {
                 pos = std::min(slot.n_sent_text, slot.generated_text.size());
             } else if (slot.has_next_token) {
                 stop_pos = slot.find_stopping_strings(str_test, token_str.size(), false);
-                send_text = stop_pos == std::string::npos;
+                send_text = send_text && stop_pos == std::string::npos;
             }

             // check if there is any token to predict
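
With this patch, a completion request can pass "start_strings" alongside the existing "stop" field: nothing is streamed until one of the start strings appears in the generated text, and the match itself (plus everything before it) is trimmed from the output. A minimal illustrative request body (the tag values are made up for the example; the other fields are the server's existing /completion parameters):

    {
        "prompt": "Reason step by step, then give the final answer between tags.",
        "start_strings": ["<answer>"],
        "stop": ["</answer>"],
        "n_predict": 128
    }

If the model generates "Let me think... <answer>42</answer>", the client receives only "42".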