@@ -105,6 +105,7 @@ struct slot_params {
 
     std::vector<std::string> antiprompt;
     std::vector<std::string> start_strings;
+    size_t start_string_max_len;
     std::vector<std::string> response_fields;
     bool timings_per_token = false;
     bool post_sampling_probs = false;
@@ -247,8 +248,7 @@ struct server_task {
         // params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement
         params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms);
         params.response_fields  = json_value(data, "response_fields", std::vector<std::string>());
-        params.start_strings    = json_value(data, "start_strings", defaults.start_strings);
-
+
         params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k);
         params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p);
         params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p);
@@ -282,6 +282,14 @@ struct server_task {
         params.speculative.n_min = std::max(params.speculative.n_min, 0);
         params.speculative.n_max = std::max(params.speculative.n_max, 0);
 
+        // start strings
+        params.start_strings        = json_value(data, "start_strings", defaults.start_strings);
+        params.start_string_max_len = 0;
+        for (const auto & start_string : params.start_strings) {
+            params.start_string_max_len = std::max(params.start_string_max_len, start_string.size());
+        }
+
+
        // Use OpenAI API logprobs only if n_probs wasn't provided
        if (data.contains("logprobs") && params.sampling.n_probs == defaults.sampling.n_probs) {
            params.sampling.n_probs = json_value(data, "logprobs", defaults.sampling.n_probs);
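Aside: the hunk above moves the start-string parsing next to a one-time computation of the longest start string, so the per-token streaming path (the -2197 hunk further down) no longer rescans the list on every token. A minimal standalone sketch of that precomputation, assuming a hypothetical request body and using nlohmann::json's value() as a stand-in for the server's json_value helper:

#include <algorithm>
#include <nlohmann/json.hpp>
#include <string>
#include <vector>

int main() {
    // hypothetical request body carrying the new parameter
    nlohmann::json data = { { "start_strings", { "<answer>", "Answer:" } } };

    auto start_strings = data.value("start_strings", std::vector<std::string>{});

    // longest start string, computed once per request instead of once per token
    size_t start_string_max_len = 0;
    for (const auto & s : start_strings) {
        start_string_max_len = std::max(start_string_max_len, s.size());
    }
    // start_string_max_len == 8 here ("<answer>")
    return 0;
}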
@@ -1295,6 +1303,8 @@ struct server_slot {
 
     std::string stopping_word;
 
+    bool start_string_found = false;
+
     // sampling
     json json_schema;
 
@@ -1332,6 +1342,7 @@ struct server_slot {
         n_past      = 0;
         n_sent_text = 0;
         task_type   = SERVER_TASK_TYPE_COMPLETION;
+        start_string_found = false;
 
         generated_tokens.clear();
         generated_token_probs.clear();
@@ -2197,11 +2208,8 @@ struct server_context {
             const std::string str_test = slot.generated_text.substr(pos);
             bool send_text = true;
 
-            if (slot.n_sent_text == 0 && slot.has_next_token && !slot.params.start_strings.empty()) {
-                size_t max_start_string_size = 0;
-                for (auto start_string : slot.params.start_strings) {
-                    max_start_string_size = std::max(max_start_string_size, start_string.size());
-                }
+            if (!slot.start_string_found && slot.has_next_token && !slot.params.start_strings.empty()) {
+                size_t max_start_string_size = slot.params.start_string_max_len;
                 size_t search_len = max_start_string_size + token_str.size();
                 size_t search_pos = 0;
                 if (slot.generated_text.size() > search_len) {
@@ -2224,6 +2232,7 @@ struct server_context {
                     slot.generated_text.erase(
                         slot.generated_text.begin(),
                         slot.generated_text.begin() + found_pos + found_string.size());
+                    slot.start_string_found = true;
                 } else {
                     send_text = false;
                 }
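For reference, here is a minimal self-contained sketch of the behavior the slot-side hunks implement (hypothetical names such as check_start_strings; this is not the server's actual code): output is held back until one of the configured start strings appears in the generated text, the search is limited to the tail that the newest token could have completed, and on a match everything up to and including the start string is erased before streaming begins.

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

// Scan `generated` for the earliest occurrence of any start string, looking
// only at the tail that the freshly appended token could have completed.
// On a match, erase everything up to and including the start string and
// return true; until then, suppress output via send_text.
static bool check_start_strings(std::string & generated,
                                const std::string & token_str,
                                const std::vector<std::string> & start_strings,
                                size_t start_string_max_len,
                                bool & send_text) {
    const size_t search_len = start_string_max_len + token_str.size();
    const size_t search_pos = generated.size() > search_len ? generated.size() - search_len : 0;

    size_t found_pos = std::string::npos;
    size_t found_len = 0;
    for (const auto & s : start_strings) {
        const size_t pos = generated.find(s, search_pos);
        if (pos != std::string::npos && pos < found_pos) {
            found_pos = pos;
            found_len = s.size();
        }
    }

    if (found_pos == std::string::npos) {
        send_text = false; // no start string yet: hold the text back
        return false;
    }
    generated.erase(generated.begin(), generated.begin() + found_pos + found_len);
    return true; // trimmed through the start string: streaming can begin
}

int main() {
    const std::vector<std::string> start_strings = { "<answer>" };
    size_t max_len = 0;
    for (const auto & s : start_strings) {
        max_len = std::max(max_len, s.size());
    }

    const std::vector<std::string> tokens = { "Let me think", "... <ans", "wer>42" };
    std::string generated;
    bool found = false;
    for (const auto & token : tokens) {
        generated += token;
        bool send_text = true;
        if (!found) {
            found = check_start_strings(generated, token, start_strings, max_len, send_text);
        }
        if (send_text) {
            std::cout << generated << "\n"; // prints "42" on the final token
            generated.clear();              // pretend the text was flushed
        }
    }
    return 0;
}

Compiled with g++ -std=c++17, this prints only "42": the first two tokens are suppressed, and the third completes "<answer>", which is trimmed away before streaming resumes.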