Commit d0dfbf9

Bug fixes for completions and prompt caching in server (#906)

* Bug fixes for completions and prompt caching in server

* Fix compiler warning about redefinition

---------

Co-authored-by: firecoperana <firecoperana>

1 parent 320fc60 commit d0dfbf9

2 files changed: 90 additions & 57 deletions

examples/server/server.cpp

Lines changed: 28 additions & 34 deletions
@@ -15,13 +15,7 @@
 // crash the server in debug mode, otherwise send an http 500 error
 #define CPPHTTPLIB_NO_EXCEPTIONS 1
 #endif
-// increase max payload length to allow use of larger context size
-#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
-// disable Nagle's algorithm
-#define CPPHTTPLIB_TCP_NODELAY true
-#include "httplib.h"
-// Change JSON_ASSERT from assert() to GGML_ASSERT:
-#define JSON_ASSERT GGML_ASSERT
+
 #include <nlohmann/json.hpp>
 #include "index.html.gz.hpp"
 #include "index_llamacpp.html.gz.hpp"
@@ -3050,7 +3044,7 @@ struct server_context {
 GGML_ASSERT(slot.ga_n == 1);

 // reuse any previously computed tokens that are common with the new prompt
-slot.n_past = common_part(slot.cache_tokens.tokens_data(), prompt_tokens.tokens_data());
+slot.n_past = slot.cache_tokens.get_common_prefix(prompt_tokens);

 // push the prompt into the sampling context (do not apply grammar)
 for (int i = 0; i < slot.n_past; ++i) {
@@ -3137,7 +3131,6 @@ struct server_context {
 {
 const auto& chunk = slot.prompt_tokens.find_chunk(slot.n_past);
 slot.cache_tokens.push_back(chunk.get()); // copy
-fprintf(stdout, slot.cache_tokens.detokenize(ctx, true).c_str());
 }

 slot.n_past += n_pos;
@@ -4293,14 +4286,15 @@ int main(int argc, char ** argv) {
 }

 const auto& prompt = data.at("prompt");
-fprintf(stdout, prompt.get<std::string>().c_str());

 // process prompt
 std::vector<server_tokens> inputs;

 if (oaicompat && ctx_server.mctx != nullptr) {
 // This is the case used by OAI compatible chat path with MTMD. TODO It can be moved to the path below.
-printFilesInfo(files);
+#ifndef NDEBUG
+print_files_info(files);
+#endif // !NDEBUG
 inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get<std::string>(), files));
 }
 else {
@@ -4346,46 +4340,46 @@ int main(int argc, char ** argv) {
 if (!result.error) {
 result.oaicompat = oaicompat;
 result.oaicompat_cmpl_id = completion_id;
-json result_array;
+json res_json;
 if (oaicompat) {
 if (result.final_result) {
-result_array = result.to_json_final();
+res_json = result.to_json_final();
 }
 else {
-result_array = result.to_json_partial();
+res_json = result.to_json_partial();
 }
 }
 else {
 // legacy completions
-result_array = result.data;
+res_json = result.data;
 }
-if (result_array.is_array()) {
-for (auto it = result_array.begin(); it != result_array.end(); ++it) {
-if (!it->empty()) {
-const std::string str =
-"data: " +
-it->dump(-1, ' ', false, json::error_handler_t::replace) +
-"\n\n";
-LOG_VERBOSE("data stream", { {"to_send", str} });
-if (!sink.write(str.c_str(), str.size())) {
-ctx_server.queue_results.remove_waiting_task_id(id_task);
-return false;
-}
+if (res_json.is_array()) {
+// chat completions and oai completions
+for (const auto& res : res_json) {
+if (!server_sent_event(sink, res)) {
+// sending failed (HTTP connection closed), cancel the generation
+ctx_server.queue_results.remove_waiting_task_id(id_task);
+return false;
 }
 }
 if (result.stop) {
 successful_completion = true;
 break;
 }
 }
+else {
+// legacy completions
+if (!server_sent_event(sink, res_json)) {
+ctx_server.queue_results.remove_waiting_task_id(id_task);
+return false;
+}
+if (result.stop) {
+break;
+}
+}
 }
 else {
-const std::string str =
-"error: " +
-result.data.dump(-1, ' ', false, json::error_handler_t::replace) +
-"\n\n";
-LOG_VERBOSE("data stream", { {"to_send", str} });
-if (!sink.write(str.c_str(), str.size())) {
+if (!server_sent_event(sink, result.data)) {
 ctx_server.queue_results.remove_waiting_task_id(id_task);
 return false;
 }
@@ -4436,7 +4430,7 @@ int main(int argc, char ** argv) {
 data,
 files,
 res,
-OAICOMPAT_TYPE_CHAT);
+OAICOMPAT_TYPE_COMPLETION);
 };

 const auto handle_models = [&params, &model_meta](const httplib::Request & req, httplib::Response & res) {
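
Editor's note: the large hunk above replaces three hand-rolled "data: ...\n\n" writers with a single server_sent_event() helper (defined in utils.hpp below) and branches on whether the result is a JSON array (chat and OAI-compatible completions) or a single object (legacy completions). The following is a minimal, self-contained sketch of that pattern only; demo_sink and send_sse are illustrative stand-ins, not the server's real types.

#include <cstddef>
#include <cstdio>
#include <string>
#include <nlohmann/json.hpp>

using json = nlohmann::ordered_json;

// Stand-in for httplib::DataSink: write() returns false once the client disconnects.
struct demo_sink {
    bool write(const char* data, size_t len) {
        return fwrite(data, 1, len, stdout) == len;
    }
};

// One SSE frame: "data: <json>" terminated by a blank line.
static bool send_sse(demo_sink& sink, const json& data) {
    const std::string str = "data: " + data.dump() + "\n\n";
    return sink.write(str.c_str(), str.size());
}

int main() {
    demo_sink sink;
    // Array result: chat / OAI-compatible completions emit one frame per element.
    json res_json = json::array({ { {"content", "Hel"} }, { {"content", "lo"} } });
    if (res_json.is_array()) {
        for (const auto& res : res_json) {
            if (!send_sse(sink, res)) {
                return 1; // connection closed: the caller would cancel generation here
            }
        }
    } else {
        send_sse(sink, res_json); // legacy completions: a single JSON object
    }
    return 0;
}

The early return on a failed write mirrors the hunk's behaviour of removing the waiting task and stopping generation once the HTTP connection is gone.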

examples/server/utils.hpp

Lines changed: 62 additions & 23 deletions
@@ -15,6 +15,16 @@
 #include <sstream>
 #include <random>

+// increase max payload length to allow use of larger context size
+#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
+// increase backlog size to avoid connection resets for >> 1 slots
+#define CPPHTTPLIB_LISTEN_BACKLOG 512
+// increase max URI length to handle longer prompts in query string
+#define CPPHTTPLIB_REQUEST_URI_MAX_LENGTH 32768
+// disable Nagle's algorithm
+#define CPPHTTPLIB_TCP_NODELAY true
+#include "httplib.h"
+
 #define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"

 using json = nlohmann::ordered_json;
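
Editor's note on ordering (not part of the diff): cpp-httplib sets each CPPHTTPLIB_* value behind an #ifndef guard inside httplib.h, so overrides like the ones added above only take effect when they are defined before the header is included; keeping them next to the #include "httplib.h" also presumably avoids the redefinition warning mentioned in the commit message. Roughly:

// Illustrative ordering only; values mirror the hunk above.
// The macros must be visible before httplib.h is included, because the header
// only uses them as defaults behind #ifndef guards.
#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
#define CPPHTTPLIB_LISTEN_BACKLOG 512
#define CPPHTTPLIB_REQUEST_URI_MAX_LENGTH 32768
#define CPPHTTPLIB_TCP_NODELAY true
#include "httplib.h" // picks up the overrides above; including it earlier would lock in the defaults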
@@ -411,6 +421,17 @@ static json probs_vector_to_json(const llama_context * ctx, const std::vector<co
 return out;
 }

+static bool server_sent_event(httplib::DataSink& sink, const json& data) {
+const std::string str =
+"data: " +
+data.dump(-1, ' ', false, json::error_handler_t::replace) +
+"\n\n"; // required by RFC 8895 - A message is terminated by a blank line (two line terminators in a row).
+
+LOG_VERBOSE("data stream, to_send: %s", str.c_str());
+
+return sink.write(str.c_str(), str.size());
+}
+
 //
 // OAI utils
 //
@@ -1065,7 +1086,6 @@ struct server_tokens {
 if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE || type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
 GGML_ASSERT(has_mtmd);
 const int n_pos = mtmd_input_chunk_get_n_pos(chunk);
-fprintf(stdout, "n_pos: %d\n", n_pos);
 llama_pos start_pos = tokens.size();
 for (int i = 0; i < n_pos; ++i) {
 tokens.emplace_back(LLAMA_TOKEN_NULL);
@@ -1209,39 +1229,54 @@ struct server_tokens {
 }

 size_t get_common_prefix(const server_tokens& b) const {
-size_t max_idx = std::min(tokens.size(), b.tokens.size());
+const size_t max_idx = std::min(tokens.size(), b.tokens.size());
+
+if (!has_mtmd) {
+for (size_t i = 0; i < max_idx; ++i) {
+if (tokens[i] == b.tokens[i]) {
+continue;
+}
+return i;
+}
+return max_idx;
+}
+
 for (size_t i = 0; i < max_idx; ++i) {
-auto& ai = tokens[i];
-auto& bi = b.tokens[i];
+const llama_token ai = tokens[i];
+const llama_token bi = b.tokens[i];

 if (ai == LLAMA_TOKEN_NULL && bi == LLAMA_TOKEN_NULL) {
-GGML_ASSERT(has_mtmd);
 const auto& a_chunk = find_chunk(i);
 const auto& b_chunk = b.find_chunk(i);
+
 GGML_ASSERT(a_chunk && b_chunk);
-std::string ai_id = mtmd_input_chunk_get_id(a_chunk.get());
-std::string bi_id = mtmd_input_chunk_get_id(b_chunk.get());
-size_t a_pos = mtmd_input_chunk_get_n_pos(a_chunk.get());
-size_t b_pos = mtmd_input_chunk_get_n_pos(b_chunk.get());
-if (ai_id == bi_id && a_pos == b_pos) {
-GGML_ASSERT(a_pos > 0 && "Invalid media chunk"); // should never happen
-i += a_pos - 1; // will be +1 by the for loop
+
+const std::string id_ai = mtmd_input_chunk_get_id(a_chunk.get());
+const std::string id_bi = mtmd_input_chunk_get_id(b_chunk.get());
+
+const size_t pos_a = mtmd_input_chunk_get_n_pos(a_chunk.get());
+const size_t pos_b = mtmd_input_chunk_get_n_pos(b_chunk.get());
+
+if (id_ai == id_bi && pos_a == pos_b) {
+GGML_ASSERT(pos_a > 0 && "Invalid media chunk"); // should never happen
+i += pos_a - 1; // will be +1 by the for loop
 continue;
 }
-else {
-return i;
-}
+
+return i;
 }
-else if (ai == bi) {
+
+if (ai == bi) {
 continue;
 }
-else {
-return i;
-}
+
+return i;
 }
+
 return max_idx; // all tokens are equal
 }

+
 // make sure all text tokens are within the vocab range
 bool validate(const struct llama_context* ctx) const {
 const llama_model* model = llama_get_model(ctx);
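
Editor's note: the new !has_mtmd fast path above is a plain longest-common-prefix scan, and its return value is the number of cached tokens the slot can keep (assigned to slot.n_past in server.cpp). A standalone sketch with made-up token ids, not the project's code:

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

using llama_token = int32_t;

// Longest common prefix of two token sequences: the number of cached tokens
// that can be reused as-is (text-only case; media chunks need the id/pos
// comparison shown in the diff above).
static size_t common_prefix(const std::vector<llama_token>& a,
                            const std::vector<llama_token>& b) {
    const size_t max_idx = std::min(a.size(), b.size());
    for (size_t i = 0; i < max_idx; ++i) {
        if (a[i] != b[i]) {
            return i;
        }
    }
    return max_idx; // one sequence is a prefix of the other
}

int main() {
    // Made-up token ids, just to show the behaviour.
    const std::vector<llama_token> cached = {1, 15043, 29892, 3186, 2};
    const std::vector<llama_token> fresh  = {1, 15043, 29892, 590, 1024};
    std::printf("reusable tokens: %zu\n", common_prefix(cached, fresh)); // prints 3
}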
@@ -1274,10 +1309,12 @@ struct server_tokens {
 llama_pos n_past,
 int32_t seq_id,
 llama_pos& n_pos_out) {
+char buffer[512];
 auto& chunk = find_chunk(n_past);
 const char* name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE
 ? "image" : "audio";
-LOG_INFO("processing %s...\n", name);
+snprintf(buffer, 512, "processing : %s",name);
+LOG_INFO(buffer, {});
 int32_t n_batch = llama_n_batch(ctx);
 int64_t t0 = ggml_time_ms();
 llama_pos new_n_past = n_past;
@@ -1288,9 +1325,11 @@
 n_batch,
 true, // logits last
 &new_n_past);
-LOG_INFO("processed in %" PRId64 " ms\n", ggml_time_ms() - t0);
+snprintf(buffer, 512, "processed in %d ms", ggml_time_ms() - t0);
+LOG_INFO(buffer, {});
 if (result != 0) {
-LOG_ERROR("mtmd_helper_eval failed with status %d", result);
+snprintf(buffer, 512, "mtmd_helper_eval failed with status %d", result);
+LOG_ERROR(buffer, {});
 n_pos_out = n_past;
 return result;
 }
@@ -1422,7 +1461,7 @@ static std::vector<server_tokens> tokenize_input_prompts(const llama_vocab* voca
 return result;
 }
 // Assuming raw_buffer has .data() and .size() members
-inline void printFilesInfo(const std::vector<raw_buffer>& files) {
+inline void print_files_info(const std::vector<raw_buffer>& files) {
 for (size_t i = 0; i < files.size(); ++i) {
 const auto& file = files[i];
 std::cout << "File " << i << ": Size = " << file.size() << " bytes\n";
