Commit f519e17

Revert "Bug fixes for completions and prompt caching in server (ikawrakow#906)"
This reverts commit d0dfbf9.
1 parent 08e3381 commit f519e17

2 files changed (+57, -90 lines)

examples/server/server.cpp

Lines changed: 34 additions & 28 deletions
@@ -15,7 +15,13 @@
 // crash the server in debug mode, otherwise send an http 500 error
 #define CPPHTTPLIB_NO_EXCEPTIONS 1
 #endif
-
+// increase max payload length to allow use of larger context size
+#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
+// disable Nagle's algorithm
+#define CPPHTTPLIB_TCP_NODELAY true
+#include "httplib.h"
+// Change JSON_ASSERT from assert() to GGML_ASSERT:
+#define JSON_ASSERT GGML_ASSERT
 #include <nlohmann/json.hpp>
 #include "index.html.gz.hpp"
 #include "index_llamacpp.html.gz.hpp"
@@ -3044,7 +3050,7 @@ struct server_context {
             GGML_ASSERT(slot.ga_n == 1);
 
             // reuse any previously computed tokens that are common with the new prompt
-            slot.n_past = slot.cache_tokens.get_common_prefix(prompt_tokens);
+            slot.n_past = common_part(slot.cache_tokens.tokens_data(), prompt_tokens.tokens_data());
 
             // push the prompt into the sampling context (do not apply grammar)
             for (int i = 0; i < slot.n_past; ++i) {
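The restored line computes the reusable prompt prefix directly from the two raw token vectors instead of going through the server_tokens::get_common_prefix() method introduced by #906. As a rough sketch of what such a common-prefix helper computes (the name common_prefix_len and the signature below are illustrative assumptions, not the repository's actual common_part declaration), only the tokens after this prefix need to be evaluated again:

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    using llama_token = int32_t; // matches the llama.h typedef

    // length of the longest common prefix of the cached tokens and the new prompt
    static size_t common_prefix_len(const std::vector<llama_token> & cached,
                                    const std::vector<llama_token> & prompt) {
        const size_t max_idx = std::min(cached.size(), prompt.size());
        for (size_t i = 0; i < max_idx; ++i) {
            if (cached[i] != prompt[i]) {
                return i;
            }
        }
        return max_idx;
    }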
@@ -3131,6 +3137,7 @@ struct server_context {
                     {
                         const auto& chunk = slot.prompt_tokens.find_chunk(slot.n_past);
                         slot.cache_tokens.push_back(chunk.get()); // copy
+                        fprintf(stdout, slot.cache_tokens.detokenize(ctx, true).c_str());
                     }
 
                     slot.n_past += n_pos;
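A note on the debug print restored above (and on the fprintf of the raw prompt restored in main() below): the detokenized text is passed directly as the printf format string, so any '%' in the cached prompt would be interpreted as a conversion specifier. A safer form of the same print would route the text through an explicit "%s":

    fprintf(stdout, "%s", slot.cache_tokens.detokenize(ctx, true).c_str());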
@@ -4286,15 +4293,14 @@ int main(int argc, char ** argv) {
         }
 
         const auto& prompt = data.at("prompt");
+        fprintf(stdout, prompt.get<std::string>().c_str());
 
         // process prompt
         std::vector<server_tokens> inputs;
 
         if (oaicompat && ctx_server.mctx != nullptr) {
             // This is the case used by OAI compatible chat path with MTMD. TODO It can be moved to the path below.
-#ifndef NDEBUG
-            print_files_info(files);
-#endif // !NDEBUG
+            printFilesInfo(files);
             inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get<std::string>(), files));
         }
         else {
@@ -4340,46 +4346,46 @@
             if (!result.error) {
                 result.oaicompat = oaicompat;
                 result.oaicompat_cmpl_id = completion_id;
-                json res_json;
+                json result_array;
                 if (oaicompat) {
                     if (result.final_result) {
-                        res_json = result.to_json_final();
+                        result_array = result.to_json_final();
                     }
                     else {
-                        res_json = result.to_json_partial();
+                        result_array = result.to_json_partial();
                    }
                 }
                 else {
                     // legacy completions
-                    res_json = result.data;
+                    result_array = result.data;
                 }
-                if (res_json.is_array()) {
-                    // chat completions and oai completions
-                    for (const auto& res : res_json) {
-                        if (!server_sent_event(sink, res)) {
-                            // sending failed (HTTP connection closed), cancel the generation
-                            ctx_server.queue_results.remove_waiting_task_id(id_task);
-                            return false;
+                if (result_array.is_array()) {
+                    for (auto it = result_array.begin(); it != result_array.end(); ++it) {
+                        if (!it->empty()) {
+                            const std::string str =
+                                "data: " +
+                                it->dump(-1, ' ', false, json::error_handler_t::replace) +
+                                "\n\n";
+                            LOG_VERBOSE("data stream", { {"to_send", str} });
+                            if (!sink.write(str.c_str(), str.size())) {
+                                ctx_server.queue_results.remove_waiting_task_id(id_task);
+                                return false;
+                            }
                         }
                     }
                     if (result.stop) {
                         successful_completion = true;
                         break;
                     }
                 }
-                else {
-                    // legacy completions
-                    if (!server_sent_event(sink, res_json)) {
-                        ctx_server.queue_results.remove_waiting_task_id(id_task);
-                        return false;
-                    }
-                    if (result.stop) {
-                        break;
-                    }
-                }
             }
             else {
-                if (!server_sent_event(sink, result.data)) {
+                const std::string str =
+                    "error: " +
+                    result.data.dump(-1, ' ', false, json::error_handler_t::replace) +
+                    "\n\n";
+                LOG_VERBOSE("data stream", { {"to_send", str} });
+                if (!sink.write(str.c_str(), str.size())) {
                     ctx_server.queue_results.remove_waiting_task_id(id_task);
                     return false;
                 }
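The restored streaming code writes the server-sent-events framing inline: every JSON payload goes out as "data: <json>" terminated by a blank line, and error results use an "error: " prefix instead. Commit d0dfbf9 had wrapped the success path in the server_sent_event() helper that is removed again from utils.hpp further down. A minimal sketch of that framing, assuming the cpp-httplib DataSink and nlohmann::ordered_json types already used in this file:

    #include <string>
    #include "httplib.h"
    #include <nlohmann/json.hpp>

    // One SSE event: "data: <payload>\n\n" -- the trailing blank line terminates the event.
    static bool send_sse_frame(httplib::DataSink & sink, const nlohmann::ordered_json & data) {
        const std::string str =
            "data: " +
            data.dump(-1, ' ', false, nlohmann::ordered_json::error_handler_t::replace) +
            "\n\n";
        return sink.write(str.c_str(), str.size());
    }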
@@ -4430,7 +4436,7 @@ int main(int argc, char ** argv) {
             data,
             files,
             res,
-            OAICOMPAT_TYPE_COMPLETION);
+            OAICOMPAT_TYPE_CHAT);
     };
 
     const auto handle_models = [&params, &model_meta](const httplib::Request & req, httplib::Response & res) {

examples/server/utils.hpp

Lines changed: 23 additions & 62 deletions
@@ -15,16 +15,6 @@
 #include <sstream>
 #include <random>
 
-// increase max payload length to allow use of larger context size
-#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 1048576
-// increase backlog size to avoid connection resets for >> 1 slots
-#define CPPHTTPLIB_LISTEN_BACKLOG 512
-// increase max URI length to handle longer prompts in query string
-#define CPPHTTPLIB_REQUEST_URI_MAX_LENGTH 32768
-// disable Nagle's algorithm
-#define CPPHTTPLIB_TCP_NODELAY true
-#include "httplib.h"
-
 #define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
 
 using json = nlohmann::ordered_json;
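These CPPHTTPLIB_* settings configure the single-header cpp-httplib library and only take effect when they are visible before httplib.h is included, which is why both layouts (this block in utils.hpp, and the one restored to server.cpp above) keep them immediately above the #include. A minimal sketch of the required ordering, reusing two values from the block being removed here:

    // overrides must be defined first: the header only honours macros that already exist
    #define CPPHTTPLIB_LISTEN_BACKLOG 512             // larger accept queue for many parallel slots
    #define CPPHTTPLIB_REQUEST_URI_MAX_LENGTH 32768   // allow long prompts passed in the query string
    #include "httplib.h"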
@@ -421,17 +411,6 @@
     return out;
 }
 
-static bool server_sent_event(httplib::DataSink& sink, const json& data) {
-    const std::string str =
-        "data: " +
-        data.dump(-1, ' ', false, json::error_handler_t::replace) +
-        "\n\n"; // required by RFC 8895 - A message is terminated by a blank line (two line terminators in a row).
-
-    LOG_VERBOSE("data stream, to_send: %s", str.c_str());
-
-    return sink.write(str.c_str(), str.size());
-}
-
 //
 // OAI utils
 //
@@ -1086,6 +1065,7 @@ struct server_tokens {
         if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE || type == MTMD_INPUT_CHUNK_TYPE_AUDIO) {
             GGML_ASSERT(has_mtmd);
             const int n_pos = mtmd_input_chunk_get_n_pos(chunk);
+            fprintf(stdout, "n_pos: %d\n", n_pos);
             llama_pos start_pos = tokens.size();
             for (int i = 0; i < n_pos; ++i) {
                 tokens.emplace_back(LLAMA_TOKEN_NULL);
@@ -1229,54 +1209,39 @@ struct server_tokens {
     }
 
     size_t get_common_prefix(const server_tokens& b) const {
-        const size_t max_idx = std::min(tokens.size(), b.tokens.size());
-
-        if (!has_mtmd) {
-            for (size_t i = 0; i < max_idx; ++i) {
-                if (tokens[i] == b.tokens[i]) {
-                    continue;
-                }
-                return i;
-            }
-            return max_idx;
-        }
-
+        size_t max_idx = std::min(tokens.size(), b.tokens.size());
         for (size_t i = 0; i < max_idx; ++i) {
-            const llama_token ai = tokens[i];
-            const llama_token bi = b.tokens[i];
+            auto& ai = tokens[i];
+            auto& bi = b.tokens[i];
 
             if (ai == LLAMA_TOKEN_NULL && bi == LLAMA_TOKEN_NULL) {
+                GGML_ASSERT(has_mtmd);
                 const auto& a_chunk = find_chunk(i);
                 const auto& b_chunk = b.find_chunk(i);
-
                 GGML_ASSERT(a_chunk && b_chunk);
-
-                const std::string id_ai = mtmd_input_chunk_get_id(a_chunk.get());
-                const std::string id_bi = mtmd_input_chunk_get_id(b_chunk.get());
-
-                const size_t pos_a = mtmd_input_chunk_get_n_pos(a_chunk.get());
-                const size_t pos_b = mtmd_input_chunk_get_n_pos(b_chunk.get());
-
-                if (id_ai == id_bi && pos_a == pos_b) {
-                    GGML_ASSERT(pos_a > 0 && "Invalid media chunk"); // should never happen
-                    i += pos_a - 1; // will be +1 by the for loop
+                std::string ai_id = mtmd_input_chunk_get_id(a_chunk.get());
+                std::string bi_id = mtmd_input_chunk_get_id(b_chunk.get());
+                size_t a_pos = mtmd_input_chunk_get_n_pos(a_chunk.get());
+                size_t b_pos = mtmd_input_chunk_get_n_pos(b_chunk.get());
+                if (ai_id == bi_id && a_pos == b_pos) {
+                    GGML_ASSERT(a_pos > 0 && "Invalid media chunk"); // should never happen
+                    i += a_pos - 1; // will be +1 by the for loop
                     continue;
                 }
-
-                return i;
+                else {
+                    return i;
+                }
             }
-
-            if (ai == bi) {
+            else if (ai == bi) {
                 continue;
            }
-
-            return i;
+            else {
+                return i;
+            }
         }
-
         return max_idx; // all tokens are equal
     }
 
-
     // make sure all text tokens are within the vocab range
     bool validate(const struct llama_context* ctx) const {
         const llama_model* model = llama_get_model(ctx);
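To make the media-chunk branch concrete: when both caches hold the same image or audio chunk starting at position i, both token lists store LLAMA_TOKEN_NULL placeholders there; the chunk ids and n_pos values then match, so the loop advances i by a_pos - 1 and the for-loop's ++i lands on the first position after the chunk, letting the prefix comparison continue with ordinary text tokens. If the chunk ids differ, the common prefix ends at i.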
@@ -1309,12 +1274,10 @@
             llama_pos n_past,
             int32_t seq_id,
             llama_pos& n_pos_out) {
-        char buffer[512];
         auto& chunk = find_chunk(n_past);
         const char* name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE
             ? "image" : "audio";
-        snprintf(buffer, 512, "processing : %s",name);
-        LOG_INFO(buffer, {});
+        LOG_INFO("processing %s...\n", name);
         int32_t n_batch = llama_n_batch(ctx);
         int64_t t0 = ggml_time_ms();
         llama_pos new_n_past = n_past;
@@ -1325,11 +1288,9 @@
             n_batch,
             true, // logits last
             &new_n_past);
-        snprintf(buffer, 512, "processed in %d ms", ggml_time_ms() - t0);
-        LOG_INFO(buffer, {});
+        LOG_INFO("processed in %" PRId64 " ms\n", ggml_time_ms() - t0);
         if (result != 0) {
-            snprintf(buffer, 512, "mtmd_helper_eval failed with status %d", result);
-            LOG_ERROR(buffer, {});
+            LOG_ERROR("mtmd_helper_eval failed with status %d", result);
             n_pos_out = n_past;
             return result;
         }
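Incidentally, the snprintf being removed here formats the int64_t difference ggml_time_ms() - t0 with %d, which is undefined behavior wherever int is narrower than 64 bits; the restored LOG_INFO line uses the PRId64 macro (from <cinttypes>) instead. If the buffer-based logging were ever reinstated, the matching fix would be along these lines:

    snprintf(buffer, sizeof(buffer), "processed in %" PRId64 " ms", ggml_time_ms() - t0);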
@@ -1461,7 +1422,7 @@ static std::vector<server_tokens> tokenize_input_prompts(const llama_vocab* voca
     return result;
 }
 // Assuming raw_buffer has .data() and .size() members
-inline void print_files_info(const std::vector<raw_buffer>& files) {
+inline void printFilesInfo(const std::vector<raw_buffer>& files) {
     for (size_t i = 0; i < files.size(); ++i) {
         const auto& file = files[i];
         std::cout << "File " << i << ": Size = " << file.size() << " bytes\n";
