Skip to content

Commit 9e567e3

Browse files
authored
Add an endpoint that lists all the saved prompt caches to server (#502)
1 parent 8c1d5a2 commit 9e567e3

File tree

1 file changed

+44
-1
lines changed

1 file changed

+44
-1
lines changed

examples/server/server.cpp

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3390,6 +3390,48 @@ int main(int argc, char ** argv) {
33903390
res.status = 200; // HTTP OK
33913391
};
33923392

3393+
// GET handler that lists every saved prompt cache found in params.slot_save_path.
//
// Responds with a JSON array holding one object per valid cache file:
//   filename    - file name without the directory part
//   filesize    - size of the file in bytes
//   token_count - number of prompt tokens recorded in the file header
//   prompt      - the stored tokens converted back to text via tokens_to_str
//
// Files that are not regular files, are too small, carry a different
// magic/version, are truncated, or cannot be read are silently skipped.
// If the directory itself cannot be iterated, the response is
// {"error": <message>} with HTTP status 500.
const auto list_saved_prompts = [&ctx_server, &params](const httplib::Request& req, httplib::Response& res) {
    res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
    json response = json::array();
    namespace fs = std::filesystem;

    // The header read below consists of three consecutive uint32_t fields:
    // magic, version and the prompt token count.
    constexpr std::uintmax_t header_size = 3 * sizeof(uint32_t);

    try {
        for (const auto& entry : fs::directory_iterator(params.slot_save_path)) {
            // Use the non-throwing overloads so that a single file that vanishes or
            // becomes unreadable mid-scan is skipped instead of failing the whole listing.
            std::error_code ec;
            if (!entry.is_regular_file(ec) || ec) {
                continue;
            }
            const std::uintmax_t file_size = entry.file_size(ec);
            if (ec || file_size < header_size) {
                continue;
            }

            std::ifstream file(entry.path(), std::ios::binary);
            if (!file) {
                continue;
            }

            // Zero-initialise and verify every read: the file may have been truncated
            // after the size check, and comparing indeterminate values is undefined behaviour.
            uint32_t magic = 0;
            uint32_t version = 0;
            uint32_t n_token_count = 0;
            if (!file.read(reinterpret_cast<char*>(&magic), sizeof(magic)) ||
                !file.read(reinterpret_cast<char*>(&version), sizeof(version)) ||
                !file.read(reinterpret_cast<char*>(&n_token_count), sizeof(n_token_count))) {
                continue;
            }

            // Widen before multiplying so the required size cannot wrap where size_t is 32 bits.
            const std::uintmax_t tokens_bytes =
                static_cast<std::uintmax_t>(n_token_count) * sizeof(llama_token);
            if (magic != LLAMA_STATE_SEQ_MAGIC ||
                version != LLAMA_STATE_SEQ_VERSION ||
                file_size < header_size + tokens_bytes) {
                continue;
            }

            // The allocation is bounded by the on-disk file size thanks to the check above.
            std::vector<llama_token> tokens(n_token_count);
            if (!file.read(reinterpret_cast<char*>(tokens.data()),
                           static_cast<std::streamsize>(tokens_bytes))) {
                // Short read: do not report a prompt built from partially filled tokens.
                continue;
            }

            response.push_back({
                {"filename", entry.path().filename().string()},
                {"filesize", file_size},
                {"token_count", n_token_count},
                {"prompt", tokens_to_str(ctx_server.ctx, tokens.cbegin(), tokens.cend())}
            });
        }
    } catch (const std::exception& e) {
        // Reached when the directory cannot be opened/advanced or when building an entry throws.
        res.status = 500;
        response = {{"error", e.what()}};
    }
    res.set_content(response.dump(), "application/json; charset=utf-8");
};
3434+
33933435
auto handle_static_file = [](unsigned char * content, size_t len, const char * mime_type) {
33943436
return [content, len, mime_type](const httplib::Request &, httplib::Response & res) {
33953437
res.set_content(reinterpret_cast<const char*>(content), len, mime_type);
@@ -3448,8 +3490,9 @@ int main(int argc, char ** argv) {
34483490
// Save & load slots
34493491
svr->Get ("/slots", handle_slots);
34503492
if (!params.slot_save_path.empty()) {
3451-
// only enable slot endpoints if slot_save_path is set
3493+
// these endpoints rely on slot_save_path existing
34523494
svr->Post("/slots/:id_slot", handle_slots_action);
3495+
svr->Get ("/list", list_saved_prompts);
34533496
}
34543497

34553498
//

0 commit comments

Comments
 (0)