diff --git a/examples/server/server.cpp b/examples/server/server.cpp index be3968901..40fe9f1d9 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -788,7 +788,7 @@ struct server_slot { pos = text.find(word, from_pos); } else { - pos = string_find_partial_stop(word, text); + pos = string_find_partial_stop(text, word); } if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) { @@ -1960,31 +1960,28 @@ struct server_context { size_t pos = std::min(slot.n_sent_text, slot.generated_text.size()); const std::string str_test = slot.generated_text.substr(pos); - bool is_stop_full = false; + bool send_text = true; size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), true); if (stop_pos != std::string::npos) { - is_stop_full = true; slot.generated_text.erase( slot.generated_text.begin() + pos + stop_pos, slot.generated_text.end()); - // Update n_sent_text to not exceed the new generated_text size - slot.n_sent_text = std::min(slot.n_sent_text, slot.generated_text.size()); - pos = slot.n_sent_text; - } else { - is_stop_full = false; - stop_pos = slot.find_stopping_strings(str_test, token_str.size(), false); + pos = std::min(slot.n_sent_text, slot.generated_text.size()); + } + else if (slot.has_next_token && !llama_token_is_eog(model, result.tok)) { + stop_pos = slot.find_stopping_strings(str_test, token_str.size(), false); + send_text = stop_pos == std::string::npos; } // check if there is any token to predict - if (stop_pos == std::string::npos || (!slot.has_next_token && !is_stop_full && stop_pos > 0)) { + if (send_text) { // no send the stop word in the response result.text_to_send = slot.generated_text.substr(pos, std::string::npos); slot.n_sent_text += result.text_to_send.size(); // add the token to slot queue and cache - } else if (stop_pos != std::string::npos) { - // Handle partial stop - update n_sent_text to the end of the current text - slot.n_sent_text = slot.generated_text.size(); + } else { + result.text_to_send = ""; } slot.add_token_string(result);