diff --git a/llamacpp/native/CMakeLists.txt b/llamacpp/native/CMakeLists.txt
index 47bed79cf..6c587b820 100644
--- a/llamacpp/native/CMakeLists.txt
+++ b/llamacpp/native/CMakeLists.txt
@@ -8,22 +8,43 @@ project(
 option(DDLLAMA_BUILD_SERVER "Build the DD llama.cpp server executable" ON)
 option(DDLLAMA_BUILD_UTILS "Build utilities, e.g. nv-gpu-info" OFF)
 
-set(DDLLAMA_PATCH_COMMAND "patch" CACHE STRING "patch command")
 set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
 set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
 
 if (DDLLAMA_BUILD_SERVER)
-    set(LLAMA_BUILD_COMMON ON)
+    # Build upstream llama.cpp with the server enabled.
+    # Only set these options if they are not already defined, so consumers can override them.
+    if(NOT DEFINED LLAMA_BUILD_COMMON)
+        set(LLAMA_BUILD_COMMON ON CACHE BOOL "Build common utils library")
+    endif()
+    if(NOT DEFINED LLAMA_BUILD_TOOLS)
+        set(LLAMA_BUILD_TOOLS ON CACHE BOOL "Build tools")
+    endif()
+    if(NOT DEFINED LLAMA_BUILD_SERVER)
+        set(LLAMA_BUILD_SERVER ON CACHE BOOL "Build server")
+    endif()
     add_subdirectory(vendor/llama.cpp)
-    # Get build info and set version for mtmd just like it's done in llama.cpp/CMakeLists.txt
-    include(vendor/llama.cpp/cmake/build-info.cmake)
-    if (NOT DEFINED LLAMA_BUILD_NUMBER)
-        set(LLAMA_BUILD_NUMBER ${BUILD_NUMBER})
+
+    # Create a custom target that copies llama-server to com.docker.llama-server
+    if (WIN32)
+        set(LLAMA_SERVER_DST "${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/com.docker.llama-server.exe")
+    else()
+        set(LLAMA_SERVER_DST "${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/com.docker.llama-server")
     endif()
-    set(LLAMA_INSTALL_VERSION 0.0.${LLAMA_BUILD_NUMBER})
-    add_subdirectory(vendor/llama.cpp/tools/mtmd)
-    add_subdirectory(src/server)
+
+    add_custom_command(OUTPUT "${LLAMA_SERVER_DST}"
+        COMMAND ${CMAKE_COMMAND} -E copy "$<TARGET_FILE:llama-server>" "${LLAMA_SERVER_DST}"
+        DEPENDS llama-server
+        COMMENT "Creating com.docker.llama-server from llama-server"
+    )
+
+    add_custom_target(com.docker.llama-server ALL DEPENDS "${LLAMA_SERVER_DST}")
+
+    # Install the renamed binary; install(TARGETS) does not support RENAME,
+    # so install the copy produced by the custom command above.
+    install(PROGRAMS "${LLAMA_SERVER_DST}"
+        DESTINATION bin)
 endif()
 
 if (WIN32 AND DDLLAMA_BUILD_UTILS)
diff --git a/llamacpp/native/README.md b/llamacpp/native/README.md
index ea13b15db..06caf77bc 100644
--- a/llamacpp/native/README.md
+++ b/llamacpp/native/README.md
@@ -1,5 +1,7 @@
 # Native llama-server
 
+This project builds the upstream llama.cpp server (`llama-server`) directly from the llama.cpp submodule and renames it to `com.docker.llama-server`.
+
 ## Building
 
 cmake -B build
@@ -15,7 +17,7 @@
 This project uses llama.cpp as a git submodule located at `vendor/llama.cpp`, which points to the official llama.cpp repository at https://github.com/ggml-org/llama.cpp.git.
 
-The project applies custom patches to llama.cpp's server implementation (`server.cpp` and `utils.hpp`) to integrate with the Docker model-runner architecture. These patches are maintained in `src/server/server.patch`.
+We build the upstream `llama-server` binary directly, without any modifications.
 
 ### Prerequisites
@@ -45,32 +47,20 @@ If the submodule is already initialized, this command is safe to run and will en
    popd
    ```
 
-3. **Apply the custom llama-server patch:**
+3. **Build and test:**
 
    ```bash
-   make -C src/server clean
-   make -C src/server
-   ```
-
-   This will:
-   - Clean the previous patched files
-   - Copy the new `server.cpp` and `utils.hpp` from the updated llama.cpp
-   - Apply our custom patches from `src/server/server.patch`
-
-4. **Build and test:**
+   # Build from the native directory
+   cmake -B build
+   cmake --build build --parallel 8 --config Release
 
-   ```bash
-   # Build from the native directory
-   cmake -B build
-   cmake --build build --parallel 8 --config Release
-
    # Test the build
    ./build/bin/com.docker.llama-server --model <path-to-model>
   ```
 
   Make sure everything builds cleanly without errors.
 
-5. **Commit the submodule update:**
+4. **Commit the submodule update:**
 
   ```bash
   git add vendor/llama.cpp
diff --git a/llamacpp/native/src/server/CMakeLists.txt b/llamacpp/native/src/server/CMakeLists.txt
deleted file mode 100644
index 95d89b5cb..000000000
--- a/llamacpp/native/src/server/CMakeLists.txt
+++ /dev/null
@@ -1,31 +0,0 @@
-set(TARGET com.docker.llama-server)
-
-option(LLAMA_SERVER_SSL "Build SSL support for the server" OFF)
-
-include_directories(
-    ../../vendor/llama.cpp
-    ../../vendor/llama.cpp/common
-    ../../vendor/llama.cpp/include
-    ../../vendor/llama.cpp/ggml/include
-)
-include_directories(${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_BINARY_DIR})
-
-if (MINGW)
-    # fix: https://github.com/ggml-org/llama.cpp/actions/runs/9651004652/job/26617901362?pr=8006
-    add_compile_definitions(_WIN32_WINNT=${GGML_WIN_VER})
-endif()
-
-file(GLOB TARGET_SRCS "*.cpp")
-
-add_executable(${TARGET} ${TARGET_SRCS})
-install(TARGETS ${TARGET} RUNTIME)
-
-target_include_directories(${TARGET} PRIVATE ../../vendor/llama.cpp/tools/mtmd)
-target_include_directories(${TARGET} PRIVATE ${CMAKE_SOURCE_DIR})
-target_link_libraries(${TARGET} PRIVATE common mtmd ${CMAKE_THREAD_LIBS_INIT} cpp-httplib)
-
-if (WIN32)
-    TARGET_LINK_LIBRARIES(${TARGET} PRIVATE ws2_32)
-endif()
-
-target_compile_features(${TARGET} PRIVATE cxx_std_17)
diff --git a/llamacpp/native/src/server/Makefile b/llamacpp/native/src/server/Makefile
deleted file mode 100644
index aad5b71b0..000000000
--- a/llamacpp/native/src/server/Makefile
+++ /dev/null
@@ -1,18 +0,0 @@
-LLAMA_SERVER_DIR = ../../vendor/llama.cpp/tools/server/
-SERVER_FILES = server-common server-context server-http server-models server-queue server-task server
-HEADERS = $(addsuffix .h, $(filter-out server, $(SERVER_FILES)))
-SOURCES = $(addsuffix .cpp, $(SERVER_FILES))
-
-.PHONY: clean all
-
-all: $(HEADERS) $(SOURCES)
-
-%.h: $(LLAMA_SERVER_DIR)/%.h
-	cp $< $@
-
-%.cpp: $(LLAMA_SERVER_DIR)/%.cpp
-	cp $< $@
-	@if [ "$@" = "server-http.cpp" ]; then patch $@ < server-http.patch; fi
-
-clean:
-	rm -f $(HEADERS) $(SOURCES)
diff --git a/llamacpp/native/src/server/README.md b/llamacpp/native/src/server/README.md
deleted file mode 100644
index c4c5f0499..000000000
--- a/llamacpp/native/src/server/README.md
+++ /dev/null
@@ -1,24 +0,0 @@
-This is a fork of the example inference server implemented as part of the
-llama.cpp project. It has to be kept up to date with the vendored llama.cpp.
-To do this, after bumping the llama.cpp submodule:
-
-1. run `make` in `native/src/server`
-2. if necessary resolve any conflicts manually and update the patch file
-3. commit the changes (eventually we'll get rid of this step in favour of a
-   fully automated workflow)
-
-The primary objective of this fork is to quickly add any `llama-server` changes
-required by the model runner and to maintain a minimal subset of non-upstreamable
-changes.
Currently we've upstreamed: - -* unix socket support for mac and linux -* unix socket support for windows - -We may want to upstream: - -* making webui optional during compilation - -Changes that we don't want to upstream: - -* name change in headers returned by our `llama-server` -* support for reading the socket name from `DD_INF_UDS` diff --git a/llamacpp/native/src/server/server-common.cpp b/llamacpp/native/src/server/server-common.cpp deleted file mode 100644 index ab6b3aa7c..000000000 --- a/llamacpp/native/src/server/server-common.cpp +++ /dev/null @@ -1,1688 +0,0 @@ -#include "common.h" -#include "log.h" -#include "llama.h" -#include "mtmd.h" -#include "mtmd-helper.h" -#include "chat.h" -#include "arg.h" // for common_remote_get_content; TODO: use download.h only -#include "base64.hpp" - -#include "server-common.h" - -#include -#include -#include - -json format_error_response(const std::string & message, const enum error_type type) { - std::string type_str; - int code = 500; - switch (type) { - case ERROR_TYPE_INVALID_REQUEST: - type_str = "invalid_request_error"; - code = 400; - break; - case ERROR_TYPE_AUTHENTICATION: - type_str = "authentication_error"; - code = 401; - break; - case ERROR_TYPE_NOT_FOUND: - type_str = "not_found_error"; - code = 404; - break; - case ERROR_TYPE_SERVER: - type_str = "server_error"; - code = 500; - break; - case ERROR_TYPE_PERMISSION: - type_str = "permission_error"; - code = 403; - break; - case ERROR_TYPE_NOT_SUPPORTED: - type_str = "not_supported_error"; - code = 501; - break; - case ERROR_TYPE_UNAVAILABLE: - type_str = "unavailable_error"; - code = 503; - break; - case ERROR_TYPE_EXCEED_CONTEXT_SIZE: - type_str = "exceed_context_size_error"; - code = 400; - break; - } - return json { - {"code", code}, - {"message", message}, - {"type", type_str}, - }; -} - -// -// random string / id -// - -std::string random_string() { - static const std::string str("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"); - - std::random_device rd; - std::mt19937 generator(rd()); - - std::string result(32, ' '); - - for (int i = 0; i < 32; ++i) { - result[i] = str[generator() % str.size()]; - } - - return result; -} - -std::string gen_chatcmplid() { - return "chatcmpl-" + random_string(); -} - -std::string gen_tool_call_id() { - return random_string(); -} - -// -// lora utils -// - -bool lora_all_alora(const std::vector & loras) { - bool found_alora = false; - for (const auto & lora : loras) { - if (lora.scale != 0) { - if (llama_adapter_get_alora_n_invocation_tokens(lora.ptr) == 0) { - return false; - } - found_alora = true; - } - } - return found_alora; -} - -bool lora_should_clear_cache( - const std::vector & current, - const std::vector & next) { - - // This should always be called after determining that the two sets are - // _not_ equal. This assert is therefore some slightly wasted work and - // should be safe to remove as long as this method is called correctly. 
- GGML_ASSERT(!are_lora_equal(current, next)); - - return ( - !(lora_get_enabled_ids(current).empty() || lora_all_alora(current)) || - !lora_all_alora(next)); -} - -std::vector parse_lora_request( - const std::vector & lora_base, - const json & data) { - std::vector lora(lora_base); - int max_idx = lora.size(); - - // clear existing value - for (auto & entry : lora) { - entry.scale = 0.0f; - } - - // set value - for (const auto & entry : data) { - int id = json_value(entry, "id", -1); - float scale = json_value(entry, "scale", 0.0f); - if (0 <= id && id < max_idx) { - lora[id].scale = scale; - } else { - throw std::runtime_error("invalid adapter id"); - } - } - - return lora; -} - -bool are_lora_equal( - const std::vector & l1, - const std::vector & l2) { - if (l1.size() != l2.size()) { - return false; - } - for (size_t i = 0; i < l1.size(); ++i) { - // we don't check lora.path to reduce the time complexity - if (l1[i].scale != l2[i].scale || l1[i].ptr != l2[i].ptr) { - return false; - } - } - return true; -} - -std::vector lora_get_enabled_ids(const std::vector & loras) { - std::vector enabled_ids; - for (size_t i = 0; i < loras.size(); ++i) { - if (loras[i].scale > 0) { - enabled_ids.push_back(i); - } - } - return enabled_ids; -} - -// -// base64 utils (TODO: use the base64::decode from base64.hpp) -// - -static const std::string base64_chars = - "ABCDEFGHIJKLMNOPQRSTUVWXYZ" - "abcdefghijklmnopqrstuvwxyz" - "0123456789+/"; - -static inline bool is_base64(uint8_t c) { - return (isalnum(c) || (c == '+') || (c == '/')); -} - -static inline raw_buffer base64_decode(const std::string & encoded_string) { - int i = 0; - int j = 0; - int in_ = 0; - - int in_len = encoded_string.size(); - - uint8_t char_array_4[4]; - uint8_t char_array_3[3]; - - raw_buffer ret; - - while (in_len-- && (encoded_string[in_] != '=') && is_base64(encoded_string[in_])) { - char_array_4[i++] = encoded_string[in_]; in_++; - if (i == 4) { - for (i = 0; i < 4; i++) { - char_array_4[i] = base64_chars.find(char_array_4[i]); - } - - char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4); - char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); - char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; - - for (i = 0; (i < 3); i++) { - ret.push_back(char_array_3[i]); - } - - i = 0; - } - } - - if (i) { - for (j = i; j < 4; j++) { - char_array_4[j] = 0; - } - - for (j = 0; j < 4; j++) { - char_array_4[j] = base64_chars.find(char_array_4[j]); - } - - char_array_3[0] = ((char_array_4[0] ) << 2) + ((char_array_4[1] & 0x30) >> 4); - char_array_3[1] = ((char_array_4[1] & 0xf) << 4) + ((char_array_4[2] & 0x3c) >> 2); - char_array_3[2] = ((char_array_4[2] & 0x3) << 6) + char_array_4[3]; - - for (j = 0; j < i - 1; j++) { - ret.push_back(char_array_3[j]); - } - } - - return ret; -} - -// -// server_tokens implementation -// - -server_tokens::server_tokens(mtmd::input_chunks & mtmd_chunks, bool has_mtmd) : has_mtmd(has_mtmd) { - for (size_t i = 0; i < mtmd_chunks.size(); ++i) { - push_back(mtmd_chunks[i]); - } -} - -server_tokens::server_tokens(const llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) { -} - -llama_pos server_tokens::pos_next() const { - if (!has_mtmd) { - return tokens.size(); - } - - llama_pos res = tokens.size(); - - for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ++it) { - const auto & chunk = it->second; - res += mtmd_input_chunk_get_n_pos(chunk.get()) - mtmd_input_chunk_get_n_tokens(chunk.get()); - } - - 
return res; -} - -std::string server_tokens::str() const { - std::ostringstream oss; - oss << "tokens: "; - for (size_t idx = 0; idx < tokens.size(); ++idx) { - llama_token t = tokens[idx]; - oss << "idx:" << idx << " "; - if (t == LLAMA_TOKEN_NULL) { - oss << " "; - } else { - oss << t << " "; - } - } - oss << "\n"; - oss << "image idx: "; - for (const auto & it : map_idx_to_media) { - oss << it.first << ", "; - } - return oss.str(); -} - -const mtmd::input_chunk_ptr & server_tokens::find_chunk(size_t idx) const { - auto it = map_idx_to_media.find(idx); - if (it != map_idx_to_media.end()) { - return it->second; - } - throw std::runtime_error("Chunk not found"); -} - -void server_tokens::push_back(llama_token tok) { - if (tok == LLAMA_TOKEN_NULL) { - throw std::runtime_error("Invalid token"); - } - tokens.emplace_back(tok); -} - -void server_tokens::push_back(const mtmd_input_chunk * chunk) { - auto type = mtmd_input_chunk_get_type(chunk); - if (type == MTMD_INPUT_CHUNK_TYPE_IMAGE || type == MTMD_INPUT_CHUNK_TYPE_AUDIO) { - GGML_ASSERT(has_mtmd); - const size_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk); - size_t start_idx = tokens.size(); - for (size_t i = 0; i < n_tokens; ++i) { - tokens.emplace_back(LLAMA_TOKEN_NULL); - } - mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk)); - map_idx_to_media[start_idx] = std::move(new_chunk); - } else if (type == MTMD_INPUT_CHUNK_TYPE_TEXT) { - size_t n_tokens; - const auto * text_tokens = mtmd_input_chunk_get_tokens_text(chunk, &n_tokens); - for (size_t i = 0; i < n_tokens; ++i) { - push_back(text_tokens[i]); - } - } else { - GGML_ABORT("Invalid chunk type"); - } -} - -void server_tokens::push_back(server_tokens & tokens) { - size_t start_idx = size(); - for (size_t i = 0; i < tokens.size(); i++) { - push_back(tokens[i]); - } - if (tokens.has_mtmd) { - // Assert if we are copying MTMD chunks to a server_tokens that does not have mtmd. - // We could also just check, but this will prevent silently dropping MTMD data. - GGML_ASSERT(has_mtmd); - for (auto it = tokens.map_idx_to_media.begin(); it != tokens.map_idx_to_media.end(); ) { - auto * chunk = tokens.map_idx_to_media[it->first].get(); - mtmd::input_chunk_ptr new_chunk(mtmd_input_chunk_copy(chunk)); - map_idx_to_media[start_idx + it->first] = std::move(new_chunk); - } - } -} - -void server_tokens::insert(const llama_tokens & inp_tokens) { - GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled - tokens.insert(tokens.end(), inp_tokens.begin(), inp_tokens.end()); -} - -const llama_tokens & server_tokens::get_text_tokens() const { - GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled - return tokens; -} - -void server_tokens::set_token(llama_pos pos, llama_token id) { - GGML_ASSERT(!has_mtmd); // only allow this if mtmd is disabled - tokens[pos] = id; -} - -void server_tokens::keep_first(size_t n) { - GGML_ASSERT(n <= tokens.size()); - if (has_mtmd) { - if (n == tokens.size()) { - return; // nothing to do - } - // we throw an error if we try to remove a token in the middle of an image - // for ex. 
with input of 5 text tokens and 2 images: - // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] - // n 1 2 3 4 5 6 7 8 9 10 - // allowed to resize ^ ^ - // disallowed to resize ^ ^ ^ - if (n > 0) { - // make sure we never remove tokens in the middle of an image - // note that the case where we keep a full image at the end is allowed: - // tokens[n - 1] == LLAMA_TOKEN_NULL && tokens[n] != LLAMA_TOKEN_NULL - if (tokens[n - 1] == LLAMA_TOKEN_NULL && tokens[n] == LLAMA_TOKEN_NULL) { - find_chunk(n - 1); // will throw an error if the token is not begin-of-chunk - } - } - // remove all image chunks that are not used anymore - for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ) { - size_t idx = it->first; - if (idx >= n) { - it = map_idx_to_media.erase(it); - } else { - ++it; - } - } - } - tokens.resize(n); -} - -std::string server_tokens::detokenize(const llama_context * ctx, bool special) const { - llama_tokens text_tokens; - text_tokens.reserve(tokens.size()); - for (const auto & t : tokens) { - if (t != LLAMA_TOKEN_NULL) { - text_tokens.push_back(t); - } - } - return common_detokenize(ctx, text_tokens, special); -} - -size_t server_tokens::get_common_prefix(const server_tokens & b) const { - const size_t max_idx = std::min(tokens.size(), b.tokens.size()); - - if (!has_mtmd) { - for (size_t i = 0; i < max_idx; ++i) { - if (tokens[i] == b.tokens[i]) { - continue; - } - - return i; - } - - return max_idx; - } - - for (size_t i = 0; i < max_idx; ++i) { - const llama_token ai = tokens[i]; - const llama_token bi = b.tokens[i]; - - if (ai == LLAMA_TOKEN_NULL && bi == LLAMA_TOKEN_NULL) { - const auto & a_chunk = find_chunk(i); - const auto & b_chunk = b.find_chunk(i); - - GGML_ASSERT(a_chunk && b_chunk); - - const std::string id_ai = mtmd_input_chunk_get_id(a_chunk.get()); - const std::string id_bi = mtmd_input_chunk_get_id(b_chunk.get()); - - const size_t n_tok_a = mtmd_input_chunk_get_n_tokens(a_chunk.get()); - const size_t n_tok_b = mtmd_input_chunk_get_n_tokens(b_chunk.get()); - - if (id_ai == id_bi && n_tok_a == n_tok_b) { - GGML_ASSERT(n_tok_a > 0 && "Invalid media chunk"); // should never happen - i += n_tok_a - 1; // will be +1 by the for loop - continue; - } - - return i; - } - - if (ai == bi) { - continue; - } - - return i; - } - - return max_idx; // all tokens are equal -} - -bool server_tokens::validate(const struct llama_context * ctx) const { - const llama_model * model = llama_get_model(ctx); - const llama_vocab * vocab = llama_model_get_vocab(model); - const int32_t n_vocab = llama_vocab_n_tokens(vocab); - - for (size_t i = 0; i < tokens.size(); ++i) { - const auto & t = tokens[i]; - if (t == LLAMA_TOKEN_NULL) { - try { - const auto & chunk = find_chunk(i); - size_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk.get()); - i += n_tokens - 1; // will be +1 by the for loop - } catch (const std::exception & e) { - return false; - } - } else if (t < 0 || t >= n_vocab) { - return false; - } - } - return true; -} - -int32_t server_tokens::process_chunk( - llama_context * ctx, - mtmd_context * mctx, - size_t idx, - llama_pos pos, - int32_t seq_id, - size_t & n_tokens_out) const { - const auto & chunk = find_chunk(idx); - const char * name = mtmd_input_chunk_get_type(chunk.get()) == MTMD_INPUT_CHUNK_TYPE_IMAGE - ? 
"image" : "audio"; - SRV_INF("processing %s...\n", name); - int32_t n_batch = llama_n_batch(ctx); - int64_t t0 = ggml_time_ms(); - llama_pos new_n_past; // unused for now - int32_t result = mtmd_helper_eval_chunk_single(mctx, ctx, - chunk.get(), - pos, - seq_id, - n_batch, - true, // logits last - &new_n_past); - SRV_INF("%s processed in %" PRId64 " ms\n", name, ggml_time_ms() - t0); - if (result != 0) { - LOG_ERR("mtmd_helper_eval failed with status %d", result); - n_tokens_out = 0; - return result; - } - n_tokens_out = mtmd_input_chunk_get_n_tokens(chunk.get()); - return 0; -} - -server_tokens server_tokens::clone() const { - server_tokens res; - res.has_mtmd = has_mtmd; - res.tokens = tokens; - for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ++it) { - size_t idx = it->first; - const mtmd::input_chunk_ptr & chunk = it->second; - res.map_idx_to_media[idx] = mtmd::input_chunk_ptr(mtmd_input_chunk_copy(chunk.get())); - } - return res; -} - -// -// tokenizer and input processing utils -// - -bool json_is_array_of_numbers(const json & data) { - if (data.is_array()) { - for (const auto & e : data) { - if (!e.is_number_integer()) { - return false; - } - } - return true; - } - return false; -} - -bool json_is_array_of_mixed_numbers_strings(const json & data) { - bool seen_string = false; - bool seen_number = false; - if (data.is_array()) { - for (const auto & e : data) { - seen_string |= e.is_string(); - seen_number |= e.is_number_integer(); - if (seen_number && seen_string) { - return true; - } - } - } - return false; -} - -bool json_is_array_and_contains_numbers(const json & data) { - if (data.is_array()) { - for (const auto & e : data) { - if (e.is_number_integer()) { - return true; - } - } - return false; - } - return false; -} - -json json_get_nested_values(const std::vector & paths, const json & js) { - json result = json::object(); - - for (const std::string & path : paths) { - json current = js; - const auto keys = string_split(path, /*separator*/ '/'); - bool valid_path = true; - for (const std::string & k : keys) { - if (valid_path && current.is_object() && current.contains(k)) { - current = current[k]; - } else { - valid_path = false; - } - } - if (valid_path) { - result[path] = current; - } - } - return result; -} - -llama_tokens tokenize_mixed(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special) { - // If `add_bos` is true, we only add BOS, when json_prompt is a string, - // or the first element of the json_prompt array is a string. 
- llama_tokens prompt_tokens; - - if (json_prompt.is_array()) { - bool first = true; - for (const auto & p : json_prompt) { - if (p.is_string()) { - auto s = p.template get(); - - llama_tokens p; - if (first) { - p = common_tokenize(vocab, s, add_special, parse_special); - first = false; - } else { - p = common_tokenize(vocab, s, false, parse_special); - } - - prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end()); - } else { - if (first) { - first = false; - } - - prompt_tokens.push_back(p.template get()); - } - } - } else { - auto s = json_prompt.template get(); - prompt_tokens = common_tokenize(vocab, s, add_special, parse_special); - } - - return prompt_tokens; -} - -size_t validate_utf8(const std::string& text) { - size_t len = text.size(); - if (len == 0) return 0; - - // Check the last few bytes to see if a multi-byte character is cut off - for (size_t i = 1; i <= 4 && i <= len; ++i) { - unsigned char c = text[len - i]; - // Check for start of a multi-byte sequence from the end - if ((c & 0xE0) == 0xC0) { - // 2-byte character start: 110xxxxx - // Needs at least 2 bytes - if (i < 2) return len - i; - } else if ((c & 0xF0) == 0xE0) { - // 3-byte character start: 1110xxxx - // Needs at least 3 bytes - if (i < 3) return len - i; - } else if ((c & 0xF8) == 0xF0) { - // 4-byte character start: 11110xxx - // Needs at least 4 bytes - if (i < 4) return len - i; - } - } - - // If no cut-off multi-byte character is found, return full length - return len; -} - -// Computes FNV-1a hash of the data -static std::string fnv_hash(const uint8_t * data, size_t len) { - const uint64_t fnv_prime = 0x100000001b3ULL; - uint64_t hash = 0xcbf29ce484222325ULL; - - for (size_t i = 0; i < len; ++i) { - hash ^= data[i]; - hash *= fnv_prime; - } - return std::to_string(hash); -} - -server_tokens process_mtmd_prompt(mtmd_context * mctx, std::string prompt, std::vector files) { - mtmd::bitmaps bitmaps; - for (auto & file : files) { - mtmd::bitmap bmp(mtmd_helper_bitmap_init_from_buf(mctx, file.data(), file.size())); - if (!bmp.ptr) { - throw std::runtime_error("Failed to load image or audio file"); - } - // calculate bitmap hash (for KV caching) - std::string hash = fnv_hash(bmp.data(), bmp.n_bytes()); - bmp.set_id(hash.c_str()); - bitmaps.entries.push_back(std::move(bmp)); - } - // process prompt - std::vector inputs; - // multimodal - mtmd_input_text inp_txt = { - prompt.c_str(), - /* add_special */ true, - /* parse_special */ true, - }; - mtmd::input_chunks chunks(mtmd_input_chunks_init()); - auto bitmaps_c_ptr = bitmaps.c_ptr(); - int32_t tokenized = mtmd_tokenize(mctx, - chunks.ptr.get(), - &inp_txt, - bitmaps_c_ptr.data(), - bitmaps_c_ptr.size()); - if (tokenized != 0) { - throw std::runtime_error("Failed to tokenize prompt"); - } - auto result = server_tokens(chunks, true); - return result; -} - -/** - * break the input "prompt" object into multiple prompt if needed, then tokenize them - * use tokenize_input_prompts() if the input could be an array. 
- * this supports these cases: - * - "prompt": "string" - * - "prompt": [12, 34, 56] - * - "prompt": [12, 34, "string", 56, 78] - * - "prompt": { "prompt_string": "string", "multimodal_data": [ "base64" ] } - */ -static server_tokens tokenize_input_subprompt(const llama_vocab * vocab, mtmd_context * mctx, const json & json_prompt, bool add_special, bool parse_special) { - constexpr char JSON_STRING_PROMPT_KEY[] = "prompt_string"; - constexpr char JSON_MTMD_DATA_KEY[] = "multimodal_data"; - const bool has_mtmd = mctx != nullptr; - if (json_prompt.is_string() || json_is_array_of_mixed_numbers_strings(json_prompt)) { - // string or mixed - llama_tokens tmp = tokenize_mixed(vocab, json_prompt, add_special, parse_special); - return server_tokens(tmp, false); - } else if (json_is_array_of_numbers(json_prompt)) { - // array of tokens - llama_tokens tmp = json_prompt.get(); - return server_tokens(tmp, false); - } else if (json_prompt.contains(JSON_STRING_PROMPT_KEY)) { - // JSON object with prompt key. - if (json_prompt.contains(JSON_MTMD_DATA_KEY)) { - if (!has_mtmd) - throw std::runtime_error("Multimodal data provided, but model does not support multimodal requests."); - - // JSON object with prompt and multimodal key. - std::vector files; - for (const auto & entry : json_prompt.at(JSON_MTMD_DATA_KEY)) { - files.push_back(base64_decode(entry)); - } - return process_mtmd_prompt(mctx, json_prompt.at(JSON_STRING_PROMPT_KEY), files); - } else { - // Not multimodal, but contains a subobject. - llama_tokens tmp = tokenize_mixed(vocab, json_prompt.at(JSON_STRING_PROMPT_KEY), add_special, parse_special); - return server_tokens(tmp, false); - } - } else { - throw std::runtime_error("\"prompt\" elements must be a string, a list of tokens, a JSON object containing a prompt string, or a list of mixed strings & tokens."); - } -} - -std::vector tokenize_input_prompts(const llama_vocab * vocab, mtmd_context * mctx, const json & json_prompt, bool add_special, bool parse_special) { - std::vector result; - if (json_prompt.is_array() && !json_is_array_and_contains_numbers(json_prompt)) { - result.reserve(json_prompt.size()); - for (const auto & p : json_prompt) { - result.push_back(tokenize_input_subprompt(vocab, mctx, p,add_special, parse_special)); - } - } else { - result.push_back(tokenize_input_subprompt(vocab, mctx, json_prompt, add_special, parse_special)); - } - if (result.empty()) { - throw std::runtime_error("\"prompt\" must not be empty"); - } - return result; -} - -// -// OAI utils -// - -// used by /completions endpoint -json oaicompat_completion_params_parse(const json & body) { - json llama_params; - - if (!body.contains("prompt")) { - throw std::runtime_error("\"prompt\" is required"); - } - - // Handle "stop" field - if (body.contains("stop") && body.at("stop").is_string()) { - llama_params["stop"] = json::array({body.at("stop").get()}); - } else { - llama_params["stop"] = json_value(body, "stop", json::array()); - } - - // Handle "echo" field - if (json_value(body, "echo", false)) { - throw std::runtime_error("Only no echo is supported"); - } - - // Params supported by OAI but unsupported by llama.cpp - static const std::vector unsupported_params { "best_of", "suffix" }; - for (const auto & param : unsupported_params) { - if (body.contains(param)) { - throw std::runtime_error("Unsupported param: " + param); - } - } - - // Copy remaining properties to llama_params - for (const auto & item : body.items()) { - // Exception: if "n_predict" is present, we overwrite the value specified earlier by 
"max_tokens" - if (!llama_params.contains(item.key()) || item.key() == "n_predict") { - llama_params[item.key()] = item.value(); - } - } - - return llama_params; -} - -// media_path always end with '/', see arg.cpp -static void handle_media( - std::vector & out_files, - json & media_obj, - const std::string & media_path) { - std::string url = json_value(media_obj, "url", std::string()); - if (string_starts_with(url, "http")) { - // download remote image - // TODO @ngxson : maybe make these params configurable - common_remote_params params; - params.headers.push_back("User-Agent: llama.cpp/" + build_info); - params.max_size = 1024 * 1024 * 10; // 10MB - params.timeout = 10; // seconds - SRV_INF("downloading image from '%s'\n", url.c_str()); - auto res = common_remote_get_content(url, params); - if (200 <= res.first && res.first < 300) { - SRV_INF("downloaded %zu bytes\n", res.second.size()); - raw_buffer data; - data.insert(data.end(), res.second.begin(), res.second.end()); - out_files.push_back(data); - } else { - throw std::runtime_error("Failed to download image"); - } - - } else if (string_starts_with(url, "file://")) { - if (media_path.empty()) { - throw std::invalid_argument("file:// URLs are not allowed unless --media-path is specified"); - } - // load local image file - std::string file_path = url.substr(7); // remove "file://" - raw_buffer data; - if (!fs_validate_filename(file_path, true)) { - throw std::invalid_argument("file path is not allowed: " + file_path); - } - SRV_INF("loading image from local file '%s'\n", (media_path + file_path).c_str()); - std::ifstream file(media_path + file_path, std::ios::binary); - if (!file) { - throw std::invalid_argument("file does not exist or cannot be opened: " + file_path); - } - data.assign((std::istreambuf_iterator(file)), std::istreambuf_iterator()); - out_files.push_back(data); - - } else { - // try to decode base64 image - std::vector parts = string_split(url, /*separator*/ ','); - if (parts.size() != 2) { - throw std::runtime_error("Invalid url value"); - } else if (!string_starts_with(parts[0], "data:image/")) { - throw std::runtime_error("Invalid url format: " + parts[0]); - } else if (!string_ends_with(parts[0], "base64")) { - throw std::runtime_error("url must be base64 encoded"); - } else { - auto base64_data = parts[1]; - auto decoded_data = base64_decode(base64_data); - out_files.push_back(decoded_data); - } - } -} - -// used by /chat/completions endpoint -json oaicompat_chat_params_parse( - json & body, /* openai api json semantics */ - const oaicompat_parser_options & opt, - std::vector & out_files) -{ - json llama_params; - - auto tools = json_value(body, "tools", json()); - auto has_tools = tools.is_array() && !tools.empty(); - auto stream = json_value(body, "stream", false); - auto tool_choice = json_value(body, "tool_choice", std::string("auto")); - - if (!opt.use_jinja) { - if (has_tools) { - throw std::runtime_error("tools param requires --jinja flag"); - } - if (tool_choice != "auto") { - throw std::runtime_error("tool_choice param requires --jinja flag"); - } - } - - // Handle "stop" field - if (body.contains("stop") && body.at("stop").is_string()) { - llama_params["stop"] = json::array({body.at("stop").get()}); - } else { - llama_params["stop"] = json_value(body, "stop", json::array()); - } - - auto json_schema = json_value(body, "json_schema", json()); - auto grammar = json_value(body, "grammar", std::string()); - if (!json_schema.is_null() && !grammar.empty()) { - throw std::runtime_error("Cannot use both 
json_schema and grammar"); - } - - // Handle "response_format" field - if (body.contains("response_format")) { - json response_format = json_value(body, "response_format", json::object()); - std::string response_type = json_value(response_format, "type", std::string()); - if (response_type == "json_object") { - json_schema = json_value(response_format, "schema", json::object()); - } else if (response_type == "json_schema") { - auto schema_wrapper = json_value(response_format, "json_schema", json::object()); - json_schema = json_value(schema_wrapper, "schema", json::object()); - } else if (!response_type.empty() && response_type != "text") { - throw std::invalid_argument("response_format type must be one of \"text\" or \"json_object\", but got: " + response_type); - } - } - - // get input files - if (!body.contains("messages")) { - throw std::invalid_argument("'messages' is required"); - } - json & messages = body.at("messages"); - if (!messages.is_array()) { - throw std::invalid_argument("Expected 'messages' to be an array"); - } - for (auto & msg : messages) { - std::string role = json_value(msg, "role", std::string()); - if (role != "assistant" && !msg.contains("content")) { - throw std::invalid_argument("All non-assistant messages must contain 'content'"); - } - if (role == "assistant") { - if (!msg.contains("content") && !msg.contains("tool_calls")) { - throw std::invalid_argument("Assistant message must contain either 'content' or 'tool_calls'!"); - } - if (!msg.contains("content")) { - continue; // avoid errors with no content - } - } - json & content = msg.at("content"); - if (content.is_string() || content.is_null()) { - continue; - } - - if (!content.is_array()) { - throw std::invalid_argument("Expected 'content' to be a string or an array"); - } - - for (auto & p : content) { - std::string type = json_value(p, "type", std::string()); - if (type == "image_url") { - if (!opt.allow_image) { - throw std::runtime_error("image input is not supported - hint: if this is unexpected, you may need to provide the mmproj"); - } - - json image_url = json_value(p, "image_url", json::object()); - handle_media(out_files, image_url, opt.media_path); - - // replace this chunk with a marker - p["type"] = "text"; - p["text"] = mtmd_default_marker(); - p.erase("image_url"); - - } else if (type == "input_audio") { - if (!opt.allow_audio) { - throw std::runtime_error("audio input is not supported - hint: if this is unexpected, you may need to provide the mmproj"); - } - - json input_audio = json_value(p, "input_audio", json::object()); - std::string data = json_value(input_audio, "data", std::string()); - std::string format = json_value(input_audio, "format", std::string()); - // while we also support flac, we don't allow it here so we matches the OAI spec - if (format != "wav" && format != "mp3") { - throw std::invalid_argument("input_audio.format must be either 'wav' or 'mp3'"); - } - auto decoded_data = base64_decode(data); // expected to be base64 encoded - out_files.push_back(decoded_data); - - // TODO: add audio_url support by reusing handle_media() - - // replace this chunk with a marker - p["type"] = "text"; - p["text"] = mtmd_default_marker(); - p.erase("input_audio"); - - } else if (type != "text") { - throw std::invalid_argument("unsupported content[].type"); - } - } - } - - common_chat_templates_inputs inputs; - inputs.messages = common_chat_msgs_parse_oaicompat(messages); - inputs.tools = common_chat_tools_parse_oaicompat(tools); - inputs.tool_choice = 
common_chat_tool_choice_parse_oaicompat(tool_choice); - inputs.json_schema = json_schema.is_null() ? "" : json_schema.dump(); - inputs.grammar = grammar; - inputs.use_jinja = opt.use_jinja; - inputs.parallel_tool_calls = json_value(body, "parallel_tool_calls", false); - inputs.add_generation_prompt = json_value(body, "add_generation_prompt", true); - inputs.reasoning_format = opt.reasoning_format; - if (body.contains("reasoning_format")) { - inputs.reasoning_format = common_reasoning_format_from_name(body.at("reasoning_format").get()); - } - inputs.enable_thinking = opt.enable_thinking; - if (!inputs.tools.empty() && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) { - if (body.contains("grammar")) { - throw std::invalid_argument("Cannot use custom grammar constraints with tools."); - } - llama_params["parse_tool_calls"] = true; - } - - // merge the template args provided from command line with the args provided in the user request - auto chat_template_kwargs_object = json_value(body, "chat_template_kwargs", json::object()); - inputs.chat_template_kwargs = opt.chat_template_kwargs; - for (const auto & item : chat_template_kwargs_object.items()) { - inputs.chat_template_kwargs[item.key()] = item.value().dump(); - } - - // parse the "enable_thinking" kwarg to override the default value - auto enable_thinking_kwarg = json_value(inputs.chat_template_kwargs, "enable_thinking", std::string("")); - if (enable_thinking_kwarg == "true") { - inputs.enable_thinking = true; - } else if (enable_thinking_kwarg == "false") { - inputs.enable_thinking = false; - } else if (!enable_thinking_kwarg.empty() && enable_thinking_kwarg[0] == '"') { - throw std::invalid_argument("invalid type for \"enable_thinking\" (expected boolean, got string)"); - } - - // if the assistant message appears at the end of list, we do not add end-of-turn token - // for ex. 
this can be useful to modify the reasoning process in reasoning models - bool prefill_assistant_message = !inputs.messages.empty() && inputs.messages.back().role == "assistant" && opt.prefill_assistant; - common_chat_msg last_message; - if (prefill_assistant_message) { - last_message = inputs.messages.back(); - inputs.messages.pop_back(); - - /* sanity check, max one assistant message at the end of the list */ - if (!inputs.messages.empty() && inputs.messages.back().role == "assistant"){ - throw std::invalid_argument("Cannot have 2 or more assistant messages at the end of the list."); - } - - /* TODO: test this properly */ - inputs.reasoning_format = COMMON_REASONING_FORMAT_NONE; - - if ( inputs.enable_thinking ) { - throw std::invalid_argument("Assistant response prefill is incompatible with enable_thinking."); - } - - inputs.add_generation_prompt = true; - } - - // Apply chat template to the list of messages - auto chat_params = common_chat_templates_apply(opt.tmpls, inputs); - - /* Append assistant prefilled message */ - if (prefill_assistant_message) { - if (!last_message.content_parts.empty()) { - for (auto & p : last_message.content_parts) { - chat_params.prompt += p.text; - } - } else { - chat_params.prompt += last_message.content; - } - } - - llama_params["chat_format"] = static_cast(chat_params.format); - llama_params["prompt"] = chat_params.prompt; - if (!chat_params.grammar.empty()) { - llama_params["grammar"] = chat_params.grammar; - } - llama_params["grammar_lazy"] = chat_params.grammar_lazy; - auto grammar_triggers = json::array(); - for (const auto & trigger : chat_params.grammar_triggers) { - server_grammar_trigger ct(trigger); - grammar_triggers.push_back(ct.to_json()); - } - llama_params["grammar_triggers"] = grammar_triggers; - llama_params["preserved_tokens"] = chat_params.preserved_tokens; - llama_params["thinking_forced_open"] = chat_params.thinking_forced_open; - for (const auto & stop : chat_params.additional_stops) { - llama_params["stop"].push_back(stop); - } - if (!chat_params.parser.empty()) { - llama_params["chat_parser"] = chat_params.parser; - } - - // Handle "logprobs" field - // TODO: The response format of this option is not yet OAI-compatible, but seems like no one really using it; We may need to fix it in the future - if (json_value(body, "logprobs", false)) { - if (has_tools && stream) { - throw std::invalid_argument("logprobs is not supported with tools + stream"); - } - llama_params["n_probs"] = json_value(body, "top_logprobs", 20); - } else if (body.contains("top_logprobs") && !body.at("top_logprobs").is_null()) { - throw std::invalid_argument("top_logprobs requires logprobs to be set to true"); - } - - // Copy remaining properties to llama_params - // This allows user to use llama.cpp-specific params like "mirostat", ... via OAI endpoint. 
- // See "launch_slot_with_task()" for a complete list of params supported by llama.cpp - for (const auto & item : body.items()) { - // Exception: if "n_predict" is present, we overwrite the value specified earlier by "max_tokens" - if (!llama_params.contains(item.key()) || item.key() == "n_predict") { - llama_params[item.key()] = item.value(); - } - } - - return llama_params; -} - -json convert_anthropic_to_oai(const json & body) { - json oai_body; - - // Convert system prompt - json oai_messages = json::array(); - auto system_param = json_value(body, "system", json()); - if (!system_param.is_null()) { - std::string system_content; - - if (system_param.is_string()) { - system_content = system_param.get(); - } else if (system_param.is_array()) { - for (const auto & block : system_param) { - if (json_value(block, "type", std::string()) == "text") { - system_content += json_value(block, "text", std::string()); - } - } - } - - oai_messages.push_back({ - {"role", "system"}, - {"content", system_content} - }); - } - - // Convert messages - if (!body.contains("messages")) { - throw std::runtime_error("'messages' is required"); - } - const json & messages = body.at("messages"); - if (messages.is_array()) { - for (const auto & msg : messages) { - std::string role = json_value(msg, "role", std::string()); - - if (!msg.contains("content")) { - if (role == "assistant") { - continue; - } - oai_messages.push_back(msg); - continue; - } - - const json & content = msg.at("content"); - - if (content.is_string()) { - oai_messages.push_back(msg); - continue; - } - - if (!content.is_array()) { - oai_messages.push_back(msg); - continue; - } - - json tool_calls = json::array(); - json converted_content = json::array(); - json tool_results = json::array(); - bool has_tool_calls = false; - - for (const auto & block : content) { - std::string type = json_value(block, "type", std::string()); - - if (type == "text") { - converted_content.push_back(block); - } else if (type == "image") { - json source = json_value(block, "source", json::object()); - std::string source_type = json_value(source, "type", std::string()); - - if (source_type == "base64") { - std::string media_type = json_value(source, "media_type", std::string("image/jpeg")); - std::string data = json_value(source, "data", std::string()); - std::ostringstream ss; - ss << "data:" << media_type << ";base64," << data; - - converted_content.push_back({ - {"type", "image_url"}, - {"image_url", { - {"url", ss.str()} - }} - }); - } else if (source_type == "url") { - std::string url = json_value(source, "url", std::string()); - converted_content.push_back({ - {"type", "image_url"}, - {"image_url", { - {"url", url} - }} - }); - } - } else if (type == "tool_use") { - tool_calls.push_back({ - {"id", json_value(block, "id", std::string())}, - {"type", "function"}, - {"function", { - {"name", json_value(block, "name", std::string())}, - {"arguments", json_value(block, "input", json::object()).dump()} - }} - }); - has_tool_calls = true; - } else if (type == "tool_result") { - std::string tool_use_id = json_value(block, "tool_use_id", std::string()); - - auto result_content = json_value(block, "content", json()); - std::string result_text; - if (result_content.is_string()) { - result_text = result_content.get(); - } else if (result_content.is_array()) { - for (const auto & c : result_content) { - if (json_value(c, "type", std::string()) == "text") { - result_text += json_value(c, "text", std::string()); - } - } - } - - tool_results.push_back({ - {"role", "tool"}, - 
{"tool_call_id", tool_use_id}, - {"content", result_text} - }); - } - } - - if (!converted_content.empty() || has_tool_calls) { - json new_msg = {{"role", role}}; - if (!converted_content.empty()) { - new_msg["content"] = converted_content; - } else if (has_tool_calls) { - new_msg["content"] = ""; - } - if (!tool_calls.empty()) { - new_msg["tool_calls"] = tool_calls; - } - oai_messages.push_back(new_msg); - } - - for (const auto & tool_msg : tool_results) { - oai_messages.push_back(tool_msg); - } - } - } - - oai_body["messages"] = oai_messages; - - // Convert tools - if (body.contains("tools")) { - const json & tools = body.at("tools"); - if (tools.is_array()) { - json oai_tools = json::array(); - for (const auto & tool : tools) { - oai_tools.push_back({ - {"type", "function"}, - {"function", { - {"name", json_value(tool, "name", std::string())}, - {"description", json_value(tool, "description", std::string())}, - {"parameters", tool.contains("input_schema") ? tool.at("input_schema") : json::object()} - }} - }); - } - oai_body["tools"] = oai_tools; - } - } - - // Convert tool_choice - if (body.contains("tool_choice")) { - const json & tc = body.at("tool_choice"); - if (tc.is_object()) { - std::string type = json_value(tc, "type", std::string()); - if (type == "auto") { - oai_body["tool_choice"] = "auto"; - } else if (type == "any" || type == "tool") { - oai_body["tool_choice"] = "required"; - } - } - } - - // Convert stop_sequences to stop - if (body.contains("stop_sequences")) { - oai_body["stop"] = body.at("stop_sequences"); - } - - // Handle max_tokens (required in Anthropic, but we're permissive) - if (body.contains("max_tokens")) { - oai_body["max_tokens"] = body.at("max_tokens"); - } else { - oai_body["max_tokens"] = 4096; - } - - // Pass through common params - for (const auto & key : {"temperature", "top_p", "top_k", "stream"}) { - if (body.contains(key)) { - oai_body[key] = body.at(key); - } - } - - // Handle Anthropic-specific thinking param - if (body.contains("thinking")) { - json thinking = json_value(body, "thinking", json::object()); - std::string thinking_type = json_value(thinking, "type", std::string()); - if (thinking_type == "enabled") { - int budget_tokens = json_value(thinking, "budget_tokens", 10000); - oai_body["thinking_budget_tokens"] = budget_tokens; - } - } - - // Handle Anthropic-specific metadata param - if (body.contains("metadata")) { - json metadata = json_value(body, "metadata", json::object()); - std::string user_id = json_value(metadata, "user_id", std::string()); - if (!user_id.empty()) { - oai_body["__metadata_user_id"] = user_id; - } - } - - return oai_body; -} - -json format_embeddings_response_oaicompat( - const json & request, - const std::string & model_name, - const json & embeddings, - bool use_base64) { - json data = json::array(); - int32_t n_tokens = 0; - int i = 0; - for (const auto & elem : embeddings) { - json embedding_obj; - - if (use_base64) { - const auto& vec = json_value(elem, "embedding", json::array()).get>(); - const char* data_ptr = reinterpret_cast(vec.data()); - size_t data_size = vec.size() * sizeof(float); - embedding_obj = { - {"embedding", base64::encode(data_ptr, data_size)}, - {"index", i++}, - {"object", "embedding"}, - {"encoding_format", "base64"} - }; - } else { - embedding_obj = { - {"embedding", json_value(elem, "embedding", json::array())}, - {"index", i++}, - {"object", "embedding"} - }; - } - data.push_back(embedding_obj); - - n_tokens += json_value(elem, "tokens_evaluated", 0); - } - - json res = json { - 
{"model", json_value(request, "model", model_name)}, - {"object", "list"}, - {"usage", json { - {"prompt_tokens", n_tokens}, - {"total_tokens", n_tokens} - }}, - {"data", data} - }; - - return res; -} - -json format_response_rerank( - const json & request, - const std::string & model_name, - const json & ranks, - bool is_tei_format, - std::vector & texts, - int top_n) { - int32_t n_tokens = 0; - bool return_text = is_tei_format && json_value(request, "return_text", false); - std::vector elements; // Temporary vector to hold unsorted elements - std::string score_label = is_tei_format ? "score" : "relevance_score"; - for (const auto & rank : ranks) { - int index = json_value(rank, "index", 0); - json elem = json{ - {"index", index}, - {score_label, json_value(rank, "score", 0.0)}, - }; - n_tokens += json_value(rank, "tokens_evaluated", 0); - if (return_text) { - elem["text"] = std::move(texts[index]); - } - elements.push_back(elem); - } - - std::sort(elements.begin(), elements.end(), [score_label](const json& a, const json& b) { - return json_value(a, score_label, 0.0) > json_value(b, score_label, 0.0); - }); - - elements.resize(std::min(top_n, (int)elements.size())); - json results = elements; - - if (is_tei_format) return results; - - json res = json{ - {"model", json_value(request, "model", model_name)}, - {"object", "list"}, - {"usage", json{ - {"prompt_tokens", n_tokens}, - {"total_tokens", n_tokens} - }}, - {"results", results} - }; - - return res; -} - - -// -// other utils -// - -std::vector get_token_probabilities(llama_context * ctx, int idx) { - std::vector cur; - const auto * logits = llama_get_logits_ith(ctx, idx); - - const llama_model * model = llama_get_model(ctx); - const llama_vocab * vocab = llama_model_get_vocab(model); - - const int n_vocab = llama_vocab_n_tokens(vocab); - - cur.resize(n_vocab); - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; - } - - // sort tokens by logits - std::sort(cur.begin(), cur.end(), [](const llama_token_data & a, const llama_token_data & b) { - return a.logit > b.logit; - }); - - // apply softmax - float max_l = cur[0].logit; - float cum_sum = 0.0f; - for (size_t i = 0; i < cur.size(); ++i) { - float p = expf(cur[i].logit - max_l); - cur[i].p = p; - cum_sum += p; - } - for (size_t i = 0; i < cur.size(); ++i) { - cur[i].p /= cum_sum; - } - - return cur; -} - -std::string safe_json_to_str(const json & data) { - return data.dump(-1, ' ', false, json::error_handler_t::replace); -} - -// TODO: reuse llama_detokenize -template -static std::string tokens_to_str(llama_context * ctx, Iter begin, Iter end) { - std::string ret; - for (; begin != end; ++begin) { - ret += common_token_to_piece(ctx, *begin); - } - - return ret; -} - -std::string tokens_to_str(llama_context * ctx, const llama_tokens & tokens) { - return tokens_to_str(ctx, tokens.begin(), tokens.end()); -} - -// format incomplete utf-8 multibyte character for output -std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token) { - std::string out = token == LLAMA_TOKEN_NULL ? 
"" : common_token_to_piece(ctx, token); - - // if the size is 1 and first bit is 1, meaning it's a partial character - // (size > 1 meaning it's already a known token) - if (out.size() == 1 && (out[0] & 0x80) == 0x80) { - std::stringstream ss; - ss << std::hex << (out[0] & 0xff); - std::string res(ss.str()); - out = "byte: \\x" + res; - } - - return out; -} - -// format server-sent event (SSE), return the formatted string to send -// note: if data is a json array, it will be sent as multiple events, one per item -std::string format_oai_sse(const json & data) { - std::ostringstream ss; - auto send_single = [&ss](const json & data) { - ss << "data: " << - safe_json_to_str(data) << - "\n\n"; // required by RFC 8895 - A message is terminated by a blank line (two line terminators in a row). - }; - - if (data.is_array()) { - for (const auto & item : data) { - send_single(item); - } - } else { - send_single(data); - } - - return ss.str(); -} - -std::string format_anthropic_sse(const json & data) { - std::ostringstream ss; - - auto send_event = [&ss](const json & event_obj) { - if (event_obj.contains("event") && event_obj.contains("data")) { - ss << "event: " << event_obj.at("event").get() << "\n"; - ss << "data: " << safe_json_to_str(event_obj.at("data")) << "\n\n"; - } else { - ss << "data: " << safe_json_to_str(event_obj) << "\n\n"; - } - }; - - if (data.is_array()) { - for (const auto & event : data) { - send_event(event); - } - } else { - send_event(data); - } - - return ss.str(); -} - -bool is_valid_utf8(const std::string & str) { - const unsigned char* bytes = reinterpret_cast(str.data()); - const unsigned char* end = bytes + str.length(); - - while (bytes < end) { - if (*bytes <= 0x7F) { - // 1-byte sequence (0xxxxxxx) - bytes++; - } else if ((*bytes & 0xE0) == 0xC0) { - // 2-byte sequence (110xxxxx 10xxxxxx) - if (end - bytes < 2 || (bytes[1] & 0xC0) != 0x80) - return false; - bytes += 2; - } else if ((*bytes & 0xF0) == 0xE0) { - // 3-byte sequence (1110xxxx 10xxxxxx 10xxxxxx) - if (end - bytes < 3 || (bytes[1] & 0xC0) != 0x80 || (bytes[2] & 0xC0) != 0x80) - return false; - bytes += 3; - } else if ((*bytes & 0xF8) == 0xF0) { - // 4-byte sequence (11110xxx 10xxxxxx 10xxxxxx 10xxxxxx) - if (end - bytes < 4 || (bytes[1] & 0xC0) != 0x80 || - (bytes[2] & 0xC0) != 0x80 || (bytes[3] & 0xC0) != 0x80) - return false; - bytes += 4; - } else { - // Invalid UTF-8 lead byte - return false; - } - } - - return true; -} - -llama_tokens format_prompt_infill( - const llama_vocab * vocab, - const json & input_prefix, - const json & input_suffix, - const json & input_extra, - const int n_batch, - const int n_predict, - const int n_ctx, - const bool spm_infill, - const llama_tokens & tokens_prompt - ) { - // TODO: optimize this block by reducing memory allocations and movement - - // use FIM repo-level pattern: - // ref: https://arxiv.org/pdf/2409.12186 - // - // [FIM_REP]myproject - // [FIM_SEP]filename0 - // extra chunk 0 - // [FIM_SEP]filename1 - // extra chunk 1 - // ... 
- // [FIM_SEP]filename - // [FIM_PRE]prefix[FIM_SUF]suffix[FIM_MID]prompt - // - llama_tokens extra_tokens; - extra_tokens.reserve(n_ctx); - - auto tokens_prefix = tokenize_mixed(vocab, input_prefix, false, false); - auto tokens_suffix = tokenize_mixed(vocab, input_suffix, false, false); - - if (llama_vocab_fim_rep(vocab) != LLAMA_TOKEN_NULL) { - // TODO: make project name an input - static const auto k_fim_repo = common_tokenize(vocab, "myproject\n", false, false); - - extra_tokens.push_back(llama_vocab_fim_rep(vocab)); - extra_tokens.insert(extra_tokens.end(), k_fim_repo.begin(), k_fim_repo.end()); - } - for (const auto & chunk : input_extra) { - // { "text": string, "filename": string } - const std::string text = json_value(chunk, "text", std::string()); - const std::string filename = json_value(chunk, "filename", std::string("tmp")); - - if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) { - const auto k_fim_file = common_tokenize(vocab, filename + "\n", false, false); - - extra_tokens.insert(extra_tokens.end(), llama_vocab_fim_sep(vocab)); - extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end()); - } else { - // chunk separator in binary form to avoid confusing the AI - static const char k_chunk_prefix_str[] = {0x0a, 0x0a, 0x2d, 0x2d, 0x2d, 0x20, 0x73, 0x6e, 0x69, 0x70, 0x70, 0x65, 0x74, 0x20, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, 0x00}; - static const auto k_chunk_prefix_tokens = common_tokenize(vocab, k_chunk_prefix_str, false, false); - - extra_tokens.insert(extra_tokens.end(), k_chunk_prefix_tokens.begin(), k_chunk_prefix_tokens.end()); - } - - const auto chunk_tokens = common_tokenize(vocab, text, false, false); - extra_tokens.insert(extra_tokens.end(), chunk_tokens.begin(), chunk_tokens.end()); - } - - if (llama_vocab_fim_sep(vocab) != LLAMA_TOKEN_NULL) { - // TODO: current filename - static const auto k_fim_file = common_tokenize(vocab, "filename\n", false, false); - - extra_tokens.insert(extra_tokens.end(), llama_vocab_fim_sep(vocab)); - extra_tokens.insert(extra_tokens.end(), k_fim_file.begin(), k_fim_file.end()); - } - - // for now pick FIM context to fit in a batch (ratio prefix:suffix = 3:1, TODO: configurable?) - const int n_prefix_take = std::min(tokens_prefix.size(), 3*(n_batch/4)); - const int n_suffix_take = std::min(tokens_suffix.size(), std::max(0, (n_batch/4) - (2 + tokens_prompt.size()))); - - SRV_DBG("n_prefix_take = %d, n_suffix_take = %d, total = %d\n", n_prefix_take, n_suffix_take, (n_prefix_take + n_suffix_take)); - - // fill the rest of the context with extra chunks - const int n_extra_take = std::min(std::max(0, n_ctx - (n_batch) - 2*n_predict), extra_tokens.size()); - - tokens_prefix.erase(tokens_prefix.begin(), tokens_prefix.begin() + tokens_prefix.size() - n_prefix_take); - tokens_suffix.resize(n_suffix_take); - - tokens_prefix.insert(tokens_prefix.begin(), llama_vocab_fim_pre(vocab)); - tokens_prefix.insert(tokens_prefix.end(), tokens_prompt.begin(), tokens_prompt.end()); - tokens_suffix.insert(tokens_suffix.begin(), llama_vocab_fim_suf(vocab)); - - auto embd_inp = spm_infill ? tokens_suffix : tokens_prefix; - auto embd_end = spm_infill ? 
tokens_prefix : tokens_suffix; - - if (llama_vocab_get_add_bos(vocab)) { - embd_inp.insert(embd_inp.begin(), llama_vocab_bos(vocab)); - } - - SRV_DBG("extra: n_ctx = %d, n_extra_take = %d, n_extra = %d\n", n_ctx, n_extra_take, (int) extra_tokens.size()); - - // put the extra context before the FIM prefix - embd_inp.insert(embd_inp.begin(), extra_tokens.end() - n_extra_take, extra_tokens.end()); - - embd_inp.insert(embd_inp.end(), embd_end.begin(), embd_end.end()); - embd_inp.push_back(llama_vocab_fim_mid(vocab)); - - return embd_inp; -} - -server_tokens format_prompt_rerank( - const struct llama_model * model, - const struct llama_vocab * vocab, - mtmd_context * mctx, - const std::string & query, - const std::string & doc) { - server_tokens result = {}; - - const char * rerank_prompt = llama_model_chat_template(model, "rerank"); - - if (rerank_prompt != nullptr) { - std::string prompt = rerank_prompt; - string_replace_all(prompt, "{query}" , query); - string_replace_all(prompt, "{document}", doc ); - server_tokens tokens = tokenize_input_subprompt(vocab, mctx, prompt, false, true); - result.push_back(tokens); - } else { - // Get EOS token - use SEP token as fallback if EOS is not available - server_tokens query_tokens = tokenize_input_subprompt(vocab, mctx, query, false, false); - server_tokens doc_tokens = tokenize_input_subprompt(vocab, mctx, doc, false, false); - llama_token eos_token = llama_vocab_eos(vocab); - if (eos_token == LLAMA_TOKEN_NULL) { - eos_token = llama_vocab_sep(vocab); - } - - if (llama_vocab_get_add_bos(vocab)) { - result.push_back(llama_vocab_bos(vocab)); - } - result.push_back(query_tokens); - if (llama_vocab_get_add_eos(vocab)) { - result.push_back(eos_token); - } - if (llama_vocab_get_add_sep(vocab)) { - result.push_back(llama_vocab_sep(vocab)); - } - result.push_back(doc_tokens); - if (llama_vocab_get_add_eos(vocab)) { - result.push_back(eos_token); - } - } - - return result; -} diff --git a/llamacpp/native/src/server/server-common.h b/llamacpp/native/src/server/server-common.h deleted file mode 100644 index 0629bb5ed..000000000 --- a/llamacpp/native/src/server/server-common.h +++ /dev/null @@ -1,363 +0,0 @@ -#pragma once - -#include "common.h" -#include "log.h" -#include "llama.h" -#include "chat.h" -#include "mtmd.h" - -#define JSON_ASSERT GGML_ASSERT -#include - -#include -#include -#include - -const static std::string build_info("b" + std::to_string(LLAMA_BUILD_NUMBER) + "-" + LLAMA_COMMIT); - -using json = nlohmann::ordered_json; - -#define SLT_INF(slot, fmt, ...) LOG_INF("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) -#define SLT_CNT(slot, fmt, ...) LOG_CNT("" fmt, __VA_ARGS__) -#define SLT_WRN(slot, fmt, ...) LOG_WRN("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) -#define SLT_ERR(slot, fmt, ...) LOG_ERR("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) -#define SLT_DBG(slot, fmt, ...) LOG_DBG("slot %12.*s: id %2d | task %d | " fmt, 12, __func__, (slot).id, ((slot).task ? (slot).task->id : -1), __VA_ARGS__) - -#define SRV_INF(fmt, ...) LOG_INF("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) -#define SRV_CNT(fmt, ...) LOG_CNT("" fmt, __VA_ARGS__) -#define SRV_WRN(fmt, ...) LOG_WRN("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) -#define SRV_ERR(fmt, ...) LOG_ERR("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) -#define SRV_DBG(fmt, ...) 
LOG_DBG("srv %12.*s: " fmt, 12, __func__, __VA_ARGS__) - -using raw_buffer = std::vector; - -template -static T json_value(const json & body, const std::string & key, const T & default_value) { - // Fallback null to default value - if (body.contains(key) && !body.at(key).is_null()) { - try { - return body.at(key); - } catch (NLOHMANN_JSON_NAMESPACE::detail::type_error const & err) { - LOG_WRN("Wrong type supplied for parameter '%s'. Expected '%s', using default value: %s\n", key.c_str(), json(default_value).type_name(), err.what()); - return default_value; - } - } else { - return default_value; - } -} - -// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11 -enum error_type { - ERROR_TYPE_INVALID_REQUEST, - ERROR_TYPE_AUTHENTICATION, - ERROR_TYPE_SERVER, - ERROR_TYPE_NOT_FOUND, - ERROR_TYPE_PERMISSION, - ERROR_TYPE_UNAVAILABLE, // custom error - ERROR_TYPE_NOT_SUPPORTED, // custom error - ERROR_TYPE_EXCEED_CONTEXT_SIZE, // custom error -}; - -// thin wrapper around common_grammar_trigger with (de)serialization functions -struct server_grammar_trigger { - common_grammar_trigger value; - - server_grammar_trigger() = default; - server_grammar_trigger(const common_grammar_trigger & value) : value(value) {} - server_grammar_trigger(const json & in) { - value.type = (common_grammar_trigger_type) in.at("type").get(); - value.value = in.at("value").get(); - if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) { - value.token = (llama_token) in.at("token").get(); - } - } - - json to_json() const { - json out { - {"type", (int) value.type}, - {"value", value.value}, - }; - if (value.type == COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN) { - out["token"] = (int) value.token; - } - return out; - } -}; - -json format_error_response(const std::string & message, const enum error_type type); - -// -// random string / id -// - -std::string random_string(); -std::string gen_chatcmplid(); -std::string gen_tool_call_id(); - -// -// lora utils -// - -// check whether the given lora set has only aloras activated (empty => false) -bool lora_all_alora(const std::vector & loras); - -// if the two sets of loras are different, they require a cache clear unless the -// change is only from aloras to aloras. -bool lora_should_clear_cache( - const std::vector & current, - const std::vector & next); - -std::vector parse_lora_request( - const std::vector & lora_base, - const json & data); - -bool are_lora_equal( - const std::vector & l1, - const std::vector & l2); - -// get the ids of all enabled loras -std::vector lora_get_enabled_ids(const std::vector & loras); - -// -// server_tokens -// - -/** - * server_tokens is a helper to manage the input tokens and image for the server. - * it is made this way to simplify the logic of KV cache management. - */ -struct server_tokens { - bool has_mtmd = false; - -private: // disallow accessing these members directly, risking out-of-sync - - // map a **start** index in tokens to the image chunk - // note: the order need to be in-sync with tokens - std::map map_idx_to_media; - - // list of tokens - // if the token is LLAMA_TOKEN_NULL, it indicates that this position is occupied by media chunk - // otherwise, it is a normal text token - // note: a non-text chunk can occupy multiple tokens (aka memory cells) in the token list - // note(2): for M-RoPE, an image can occupy different number of pos; do not assume 1-to-1 mapping tokens <-> pos - llama_tokens tokens; - - // for ex. 
with input of 5 text tokens and 2 images (each image occupies 3 tokens and 2 pos): - // [0] [1] [2] [3] [4] [img0] [img0] [img0] [img1] [img1] [img1] - // idx 0 1 2 3 4 5 6 7 8 9 10 - // pos 0 1 2 3 4 5 5 5 7 7 7 - // map_idx_to_media will contain: {5, img0}, {8, img1} - -public: - server_tokens() = default; - ~server_tokens() = default; - - // Prevent copying - // TODO: server_tokens should be copyable - remove this: - server_tokens(const server_tokens&) = delete; - server_tokens& operator=(const server_tokens&) = delete; - - // Allow moving (usually implicitly generated if members are movable) - server_tokens(server_tokens&&) = default; - server_tokens& operator=(server_tokens&&) = default; - - // Allow accessing elements using [] operator - llama_token operator[](size_t index) { return tokens[index]; } - const llama_token& operator[](size_t index) const { return tokens[index]; } - - server_tokens(mtmd::input_chunks & mtmd_chunks, bool has_mtmd); - server_tokens(const llama_tokens & tokens, bool has_mtmd); - - // for debugging - std::string str() const; - - llama_pos pos_next() const; - const mtmd::input_chunk_ptr & find_chunk(size_t idx) const; - - void push_back(llama_token tok); - - // will create a copy of the chunk if it contains non-text data - void push_back(const mtmd_input_chunk * chunk); - - // appends server tokens, updates the media map. copies media chunks. - void push_back(server_tokens & tokens); - - // for compatibility with context shift and prompt truncation - void insert(const llama_tokens & inp_tokens); - - // for compatibility with speculative decoding, ctx shift, slot save/load - const llama_tokens & get_text_tokens() const; - - // for compatibility with speculative decoding - void set_token(llama_pos pos, llama_token id); - - size_t size() const { return tokens.size(); } - - bool empty() const { return tokens.empty(); } - - void clear() { - map_idx_to_media.clear(); - tokens.clear(); - } - - void keep_first(size_t n); - - std::string detokenize(const llama_context * ctx, bool special) const; - - size_t get_common_prefix(const server_tokens & b) const; - - // make sure all text tokens are within the vocab range - bool validate(const struct llama_context * ctx) const; - - // encode and decode the image chunk - int32_t process_chunk( - llama_context * ctx, - mtmd_context * mctx, - size_t idx, - llama_pos pos, - int32_t seq_id, - size_t & n_tokens_out) const; - - server_tokens clone() const; -}; - - -// -// tokenizer and input processing utils -// - -bool json_is_array_of_numbers(const json & data); - -// is array having BOTH numbers & strings? -bool json_is_array_of_mixed_numbers_strings(const json & data); - -// does array have any individual integers/tokens? 
-bool json_is_array_and_contains_numbers(const json & data); - -// get value by path(key1 / key2) -json json_get_nested_values(const std::vector & paths, const json & js); - -/** - * this handles 2 cases: - * - only string, example: "string" - * - mixed string and tokens, example: [12, 34, "string", 56, 78] - */ -llama_tokens tokenize_mixed(const llama_vocab * vocab, const json & json_prompt, bool add_special, bool parse_special); - -// return the last index of character that can form a valid string -// if the last character is potentially cut in half, return the index before the cut -// if validate_utf8(text) == text.size(), then the whole text is valid utf8 -size_t validate_utf8(const std::string& text); - -// process mtmd prompt, return the server_tokens containing both text tokens and media chunks -server_tokens process_mtmd_prompt(mtmd_context * mctx, std::string prompt, std::vector files); - -/** - * break the input "prompt" object into multiple prompt if needed, then tokenize them - * this supports these cases: - * - "prompt": "string" - * - "prompt": [12, 34, 56] - * - "prompt": [12, 34, "string", 56, 78] - * - "prompt": { "prompt_string": "string", "multimodal_data": [ "base64" ] } - * and multiple prompts (multi-tasks): - * - "prompt": ["string1", "string2"] - * - "prompt": ["string1", [12, 34, 56]] - * - "prompt": [[12, 34, 56], [78, 90, 12]] - * - "prompt": [[12, 34, "string", 56, 78], [12, 34, 56], { "prompt_string": "string", "multimodal_data": [ "base64" ]}] - */ -std::vector tokenize_input_prompts( - const llama_vocab * vocab, - mtmd_context * mctx, - const json & json_prompt, - bool add_special, - bool parse_special); - -// -// OAI utils -// - -// used by /completions endpoint -json oaicompat_completion_params_parse(const json & body); - -struct oaicompat_parser_options { - bool use_jinja; - bool prefill_assistant; - common_reasoning_format reasoning_format; - std::map chat_template_kwargs; - common_chat_templates * tmpls; - bool allow_image; - bool allow_audio; - bool enable_thinking = true; - std::string media_path; -}; - -// used by /chat/completions endpoint -json oaicompat_chat_params_parse( - json & body, /* openai api json semantics */ - const oaicompat_parser_options & opt, - std::vector & out_files); - -// convert Anthropic Messages API format to OpenAI Chat Completions API format -json convert_anthropic_to_oai(const json & body); - -// TODO: move it to server-task.cpp -json format_embeddings_response_oaicompat( - const json & request, - const std::string & model_name, - const json & embeddings, - bool use_base64 = false); - -// TODO: move it to server-task.cpp -json format_response_rerank( - const json & request, - const std::string & model_name, - const json & ranks, - bool is_tei_format, - std::vector & texts, - int top_n); - -// -// other utils -// - -std::vector get_token_probabilities(llama_context * ctx, int idx); - -std::string safe_json_to_str(const json & data); - -std::string tokens_to_str(llama_context * ctx, const llama_tokens & tokens); - -// format incomplete utf-8 multibyte character for output -std::string tokens_to_output_formatted_string(const llama_context * ctx, const llama_token token); - -// format server-sent event (SSE), return the formatted string to send -// note: if data is a json array, it will be sent as multiple events, one per item -std::string format_oai_sse(const json & data); - -// format Anthropic-style SSE with event types -std::string format_anthropic_sse(const json & data); - -bool is_valid_utf8(const std::string & str); - 
-// -// formatting output responses -// TODO: move these to server-task.cpp -// - -llama_tokens format_prompt_infill( - const llama_vocab * vocab, - const json & input_prefix, - const json & input_suffix, - const json & input_extra, - const int n_batch, - const int n_predict, - const int n_ctx, - const bool spm_infill, - const llama_tokens & tokens_prompt); - -// format rerank task: [BOS]query[EOS][SEP]doc[EOS]. -server_tokens format_prompt_rerank( - const struct llama_model * model, - const struct llama_vocab * vocab, - mtmd_context * mctx, - const std::string & query, - const std::string & doc); diff --git a/llamacpp/native/src/server/server-context.cpp b/llamacpp/native/src/server/server-context.cpp deleted file mode 100644 index 5a67f508d..000000000 --- a/llamacpp/native/src/server/server-context.cpp +++ /dev/null @@ -1,3796 +0,0 @@ -#include "server-context.h" -#include "server-common.h" -#include "server-http.h" -#include "server-task.h" -#include "server-queue.h" - -#include "arg.h" -#include "common.h" -#include "llama.h" -#include "log.h" -#include "sampling.h" -#include "speculative.h" -#include "mtmd.h" -#include "mtmd-helper.h" - -#include -#include -#include -#include -#include - -// fix problem with std::min and std::max -#if defined(_WIN32) -#define WIN32_LEAN_AND_MEAN -#ifndef NOMINMAX -# define NOMINMAX -#endif -#include -#endif - -using json = nlohmann::ordered_json; - -constexpr int HTTP_POLLING_SECONDS = 1; - -// state diagram: https://github.com/ggml-org/llama.cpp/pull/9283 -enum slot_state { - SLOT_STATE_IDLE, - SLOT_STATE_WAIT_OTHER, // after assigning a task, but waiting for parent slot to process prompt - SLOT_STATE_STARTED, // after assigning a task and about to process prompt - SLOT_STATE_PROCESSING_PROMPT, - SLOT_STATE_DONE_PROMPT, - SLOT_STATE_GENERATING, -}; - -enum server_state { - SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet - SERVER_STATE_READY, // Server is ready and model is loaded -}; - -static bool server_task_type_need_embd(server_task_type task_type) { - switch (task_type) { - case SERVER_TASK_TYPE_EMBEDDING: - case SERVER_TASK_TYPE_RERANK: - return true; - default: - return false; - } -} - -static bool server_task_type_need_logits(server_task_type task_type) { - switch (task_type) { - case SERVER_TASK_TYPE_COMPLETION: - case SERVER_TASK_TYPE_INFILL: - return true; - default: - return false; - } -} - -struct server_slot { - int id; - - llama_batch batch_spec = {}; - - // TODO: change to unique_ptrs for consistency: - llama_context * ctx = nullptr; - llama_context * ctx_dft = nullptr; - - // multimodal - mtmd_context * mctx = nullptr; - - common_speculative * spec = nullptr; - - std::unique_ptr task; - std::unique_ptr task_prev; // used for debugging - - // used to determine the slot that has been used the longest - int64_t t_last_used = -1; - - // generation props - int32_t n_ctx = 0; // context size per slot - int32_t n_keep = 0; - int32_t n_decoded = 0; - int32_t n_remaining = -1; - int32_t i_batch = -1; - - int32_t n_prompt_tokens_cache = 0; - int32_t n_prompt_tokens_processed = 0; - - size_t last_nl_pos = 0; - - std::string generated_text; - llama_tokens generated_tokens; - - // idx of draft tokens in the main batch - // non-empty if we went to evaluate draft tokens - // ref: https://github.com/ggml-org/llama.cpp/pull/17808 - std::vector i_batch_dft; - - std::vector generated_token_probs; - - bool has_next_token = true; - bool has_new_line = false; - bool truncated = false; - - stop_type stop; - - std::string 
stopping_word; - - // state - slot_state state = SLOT_STATE_IDLE; - - server_prompt prompt; - - void prompt_save(server_prompt_cache & prompt_cache) const { - GGML_ASSERT(prompt.data.size() == 0); - - const size_t cur_size = llama_state_seq_get_size_ext(ctx, id, 0); - - SRV_WRN(" - saving prompt with length %d, total state size = %.3f MiB\n", - (int) prompt.tokens.size(), cur_size / (1024.0 * 1024.0)); - - auto * cur = prompt_cache.alloc(prompt, cur_size); - if (cur == nullptr) { - return; - } - - llama_state_seq_get_data_ext(ctx, cur->data.data(), cur_size, id, 0); - } - - bool prompt_load(server_prompt_cache & prompt_cache, const server_tokens & tokens) { - bool res = prompt_cache.load(prompt, tokens, ctx, id); - if (!res) { - SLT_WRN(*this, "%s", "failed to load prompt from cache\n"); - } - - return res; - } - - std::vector lora; - int32_t alora_invocation_start = -1; - - // sampling - json json_schema; - - struct common_sampler * smpl = nullptr; - - llama_token sampled; // in speculative mode, this is the last accepted token - llama_tokens drafted; - - // stats - size_t n_sent_text = 0; // number of sent text character - - int64_t t_start_process_prompt; - int64_t t_start_generation; - - double t_prompt_processing; // ms - double t_token_generation; // ms - - std::function callback_on_release; - - // Speculative decoding stats - int32_t n_draft_total = 0; // Total draft tokens generated - int32_t n_draft_accepted = 0; // Draft tokens actually accepted - - void reset() { - SLT_DBG(*this, "%s", "\n"); - - n_prompt_tokens_cache = 0; - - last_nl_pos = 0; - generated_text = ""; - has_new_line = false; - truncated = false; - stop = STOP_TYPE_NONE; - stopping_word = ""; - n_sent_text = 0; - - drafted.clear(); - i_batch_dft.clear(); - generated_tokens.clear(); - generated_token_probs.clear(); - json_schema = json(); - - // clear speculative decoding stats - n_draft_total = 0; - n_draft_accepted = 0; - - task.reset(); - task_prev.reset(); - - // clear alora start - alora_invocation_start = -1; - } - - bool need_embd() const { - GGML_ASSERT(task); - - return server_task_type_need_embd(task->type); - } - - bool need_logits() const { - GGML_ASSERT(task); - - return server_task_type_need_logits(task->type); - } - - // if the context does not have a memory module then all embeddings have to be computed within a single ubatch - // also we cannot split if the pooling would require any past tokens - bool can_split() const { - return - !need_embd() || - (llama_get_memory(ctx) && llama_pooling_type(ctx) == LLAMA_POOLING_TYPE_LAST); - } - - bool can_batch_with(server_slot & other_slot) const { - GGML_ASSERT(task); - - return task->type == other_slot.task->type && are_lora_equal(lora, other_slot.lora); - } - - bool has_budget(const common_params & global_params) { - GGML_ASSERT(task); - - if (task->params.n_predict == -1 && global_params.n_predict == -1) { - return true; // limitless - } - - n_remaining = -1; - - if (task->params.n_predict != -1) { - n_remaining = task->params.n_predict - n_decoded; - } else if (global_params.n_predict != -1) { - n_remaining = global_params.n_predict - n_decoded; - } - - return n_remaining > 0; // no budget - } - - bool is_processing() const { - return state != SLOT_STATE_IDLE; - } - - bool can_speculate() const { - return ctx_dft; - } - - void add_token(const completion_token_output & token) { - if (!is_processing()) { - SLT_WRN(*this, "%s", "slot is not processing\n"); - return; - } - generated_token_probs.push_back(token); - } - - int get_n_draft_max() const { - if 
(!can_speculate()) { - return 0; - } - - // determine the max draft that fits the current slot state - int n_draft_max = task->params.speculative.n_max; - - // note: slot.prompt is not yet expanded with the `id` token sampled above - // also, need to leave space for 1 extra token to allow context shifts - n_draft_max = std::min(n_draft_max, n_ctx - prompt.n_tokens() - 2); - - if (n_remaining > 0) { - n_draft_max = std::min(n_draft_max, n_remaining - 1); - } - - SLT_DBG(*this, "max possible draft: %d\n", n_draft_max); - - if (n_draft_max < task->params.speculative.n_min) { - SLT_DBG(*this, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, task->params.speculative.n_min); - n_draft_max = 0; - } - return n_draft_max; - } - - // note: a slot can also be either a parent or a child - bool is_parent() const { - return is_processing() && task->n_children > 0; - } - - bool is_child() const { - return is_processing() && task->id_parent >= 0; - } - - void release() { - if (is_processing()) { - GGML_ASSERT(task); - - SLT_INF(*this, "stop processing: n_tokens = %d, truncated = %d\n", prompt.n_tokens(), truncated); - - t_last_used = ggml_time_us(); - t_token_generation = (ggml_time_us() - t_start_generation) / 1e3; - state = SLOT_STATE_IDLE; - - task_prev = std::move(task); - task.reset(); - - callback_on_release(id); - } - } - - result_timings get_timings() const { - result_timings timings; - timings.cache_n = n_prompt_tokens_cache; - - timings.prompt_n = n_prompt_tokens_processed; - timings.prompt_ms = t_prompt_processing; - timings.prompt_per_token_ms = t_prompt_processing / n_prompt_tokens_processed; - timings.prompt_per_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; - - timings.predicted_n = n_decoded; - timings.predicted_ms = t_token_generation; - timings.predicted_per_token_ms = t_token_generation / n_decoded; - timings.predicted_per_second = 1e3 / t_token_generation * n_decoded; - - // Add speculative metrics - if (n_draft_total > 0) { - timings.draft_n = n_draft_total; - timings.draft_n_accepted = n_draft_accepted; - } - - return timings; - } - - size_t find_stopping_strings(const std::string & text, const size_t last_token_size, bool is_full_stop) { - GGML_ASSERT(task); - - size_t stop_pos = std::string::npos; - - for (const std::string & word : task->params.antiprompt) { - size_t pos; - - if (is_full_stop) { - const size_t tmp = word.size() + last_token_size; - const size_t from_pos = text.size() > tmp ? 
text.size() - tmp : 0; - - pos = text.find(word, from_pos); - } else { - // otherwise, partial stop - pos = string_find_partial_stop(text, word); - } - - if (pos != std::string::npos && (stop_pos == std::string::npos || pos < stop_pos)) { - if (is_full_stop) { - stop = STOP_TYPE_WORD; - stopping_word = word; - has_next_token = false; - } - stop_pos = pos; - } - } - - return stop_pos; - } - - void print_timings() const { - const double t_prompt = t_prompt_processing / n_prompt_tokens_processed; - const double n_prompt_second = 1e3 / t_prompt_processing * n_prompt_tokens_processed; - - const double t_gen = t_token_generation / n_decoded; - const double n_gen_second = 1e3 / t_token_generation * n_decoded; - - SLT_INF(*this, - "\n" - "prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n" - " eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n" - " total time = %10.2f ms / %5d tokens\n", - t_prompt_processing, n_prompt_tokens_processed, t_prompt, n_prompt_second, - t_token_generation, n_decoded, t_gen, n_gen_second, - t_prompt_processing + t_token_generation, n_prompt_tokens_processed + n_decoded); - - if (n_draft_total > 0) { - const float draft_ratio = (float) n_draft_accepted / n_draft_total; - SLT_CNT(*this, - "draft acceptance rate = %0.5f (%5d accepted / %5d generated)\n", - draft_ratio, n_draft_accepted, n_draft_total - ); - } - } - - json to_json(bool only_metrics = false) const { - json res; - - res = { - {"id", id}, - {"n_ctx", n_ctx}, - {"speculative", can_speculate()}, - {"is_processing", is_processing()}, - }; - - const auto & ptask = task ? task : task_prev; - - if (ptask) { - res["id_task"] = ptask->id; - res["params"] = ptask->params.to_json(only_metrics); - res["next_token"] = { - { - {"has_next_token", has_next_token}, - {"has_new_line", has_new_line}, - {"n_remain", n_remaining}, - {"n_decoded", n_decoded}, - } - }; - - if (!only_metrics) { - res["prompt"] = ptask->tokens.detokenize(ctx, true); - res["generated"] = generated_text; - } - } - - return res; - } - - void copy_state_to(server_slot & other) const { - llama_memory_seq_rm(llama_get_memory(ctx), other.id, 0, -1); - llama_memory_seq_cp(llama_get_memory(ctx), id, other.id, 0, -1); - other.n_decoded = n_decoded; - other.n_remaining = n_remaining; - other.i_batch = i_batch; - other.n_prompt_tokens_cache = n_prompt_tokens_cache; - other.n_prompt_tokens_processed = n_prompt_tokens_processed; - other.prompt = prompt.clone(); - } -}; - - - -// -// server_metrics -// - -struct server_metrics { - int64_t t_start = 0; - - uint64_t n_prompt_tokens_processed_total = 0; - uint64_t t_prompt_processing_total = 0; - uint64_t n_tokens_predicted_total = 0; - uint64_t t_tokens_generation_total = 0; - - uint64_t n_tokens_max = 0; - - uint64_t n_prompt_tokens_processed = 0; - uint64_t t_prompt_processing = 0; - - uint64_t n_tokens_predicted = 0; - uint64_t t_tokens_generation = 0; - - uint64_t n_decode_total = 0; - uint64_t n_busy_slots_total = 0; - - void init() { - t_start = ggml_time_us(); - } - - void on_prompt_eval(const server_slot & slot) { - n_prompt_tokens_processed_total += slot.n_prompt_tokens_processed; - n_prompt_tokens_processed += slot.n_prompt_tokens_processed; - t_prompt_processing += slot.t_prompt_processing; - t_prompt_processing_total += slot.t_prompt_processing; - - n_tokens_max = std::max(n_tokens_max, (uint64_t) slot.prompt.n_tokens()); - } - - void on_prediction(const server_slot & slot) { - n_tokens_predicted_total += slot.n_decoded; - 
n_tokens_predicted += slot.n_decoded; - t_tokens_generation += slot.t_token_generation; - t_tokens_generation_total += slot.t_token_generation; - } - - void on_decoded(const std::vector & slots) { - n_decode_total++; - for (const auto & slot : slots) { - if (slot.is_processing()) { - n_busy_slots_total++; - } - n_tokens_max = std::max(n_tokens_max, (uint64_t) slot.prompt.n_tokens()); - } - } - - void reset_bucket() { - n_prompt_tokens_processed = 0; - t_prompt_processing = 0; - n_tokens_predicted = 0; - t_tokens_generation = 0; - } -}; - - -// -// server_context_impl (private implementation) -// - -struct server_context_impl { - common_params params_base; - - // note: keep these alive - they determine the lifetime of the model, context, etc. - common_init_result llama_init; - common_init_result llama_init_dft; - - llama_model * model = nullptr; - llama_context * ctx = nullptr; - - // multimodal - mtmd_context * mctx = nullptr; - - const llama_vocab * vocab = nullptr; - bool vocab_dft_compatible = true; - - llama_model * model_dft = nullptr; - - llama_context_params cparams_dft; - - llama_batch batch {}; - - bool add_bos_token = true; - - int32_t n_ctx; // total context for all clients / slots - - // slots / clients - std::vector slots; - - int slots_debug = 0; - - server_queue queue_tasks; - server_response queue_results; - - std::unique_ptr prompt_cache; - - server_metrics metrics; - - // Necessary similarity of prompt for slot selection - float slot_prompt_similarity = 0.0f; - - std::string model_name; // name of the loaded model, to be used by API - - common_chat_templates_ptr chat_templates; - oaicompat_parser_options oai_parser_opt; - - ~server_context_impl() { - mtmd_free(mctx); - - // Clear any sampling context - for (server_slot & slot : slots) { - common_sampler_free(slot.smpl); - slot.smpl = nullptr; - - llama_free(slot.ctx_dft); - slot.ctx_dft = nullptr; - - common_speculative_free(slot.spec); - slot.spec = nullptr; - - llama_batch_free(slot.batch_spec); - } - - llama_batch_free(batch); - } - - // load the model and initialize llama_context - bool load_model(const common_params & params) { - SRV_INF("loading model '%s'\n", params.model.path.c_str()); - - params_base = params; - - llama_init = common_init_from_params(params_base); - - model = llama_init.model.get(); - ctx = llama_init.context.get(); - - if (model == nullptr) { - SRV_ERR("failed to load model, '%s'\n", params_base.model.path.c_str()); - return false; - } - - vocab = llama_model_get_vocab(model); - - n_ctx = llama_n_ctx(ctx); - - add_bos_token = llama_vocab_get_add_bos(vocab); - - if (params_base.has_speculative()) { - SRV_INF("loading draft model '%s'\n", params_base.speculative.model.path.c_str()); - - auto params_dft = params_base; - - params_dft.devices = params_base.speculative.devices; - params_dft.model = params_base.speculative.model; - params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? 
llama_n_ctx_seq(ctx) : params_base.speculative.n_ctx; - params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers; - params_dft.n_parallel = 1; - params_dft.cache_type_k = params_base.speculative.cache_type_k; - params_dft.cache_type_v = params_base.speculative.cache_type_v; - - params_dft.cpuparams.n_threads = params_base.speculative.cpuparams.n_threads; - params_dft.cpuparams_batch.n_threads = params_base.speculative.cpuparams_batch.n_threads; - params_dft.tensor_buft_overrides = params_base.speculative.tensor_buft_overrides; - - llama_init_dft = common_init_from_params(params_dft); - - model_dft = llama_init_dft.model.get(); - - if (model_dft == nullptr) { - SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.path.c_str()); - return false; - } - - vocab_dft_compatible = common_speculative_are_compatible(ctx, llama_init_dft.context.get()); - if (!vocab_dft_compatible) { - SRV_INF("the draft model '%s' is not compatible with the target model '%s'. tokens will be translated between the draft and target models.\n", params_base.speculative.model.path.c_str(), params_base.model.path.c_str()); - } - - const int n_ctx_dft = llama_n_ctx(llama_init_dft.context.get()); - - cparams_dft = common_context_params_to_llama(params_dft); - cparams_dft.n_batch = n_ctx_dft; - - // the context is not needed - we will create one for each slot - llama_init_dft.context.reset(); - } - - chat_templates = common_chat_templates_init(model, params_base.chat_template); - try { - common_chat_format_example(chat_templates.get(), params.use_jinja, params.default_template_kwargs); - } catch (const std::exception & e) { - SRV_WRN("%s: Chat template parsing error: %s\n", __func__, e.what()); - SRV_WRN("%s: The chat template that comes with this model is not yet supported, falling back to chatml. 
This may cause the model to output suboptimal responses\n", __func__); - chat_templates = common_chat_templates_init(model, "chatml"); - } - - std::string & mmproj_path = params_base.mmproj.path; - if (!mmproj_path.empty()) { - mtmd_helper_log_set(common_log_default_callback, nullptr); - - mtmd_context_params mparams = mtmd_context_params_default(); - mparams.use_gpu = params_base.mmproj_use_gpu; - mparams.print_timings = false; - mparams.n_threads = params_base.cpuparams.n_threads; - mparams.flash_attn_type = params_base.flash_attn_type; - mparams.warmup = params_base.warmup; - mparams.image_min_tokens = params_base.image_min_tokens; - mparams.image_max_tokens = params_base.image_max_tokens; - mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams); - if (mctx == nullptr) { - SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str()); - return false; - } - SRV_INF("loaded multimodal model, '%s'\n", mmproj_path.c_str()); - - if (params_base.ctx_shift) { - params_base.ctx_shift = false; - SRV_WRN("%s\n", "ctx_shift is not supported by multimodal, it will be disabled"); - } - - if (params_base.n_cache_reuse) { - params_base.n_cache_reuse = 0; - SRV_WRN("%s\n", "cache_reuse is not supported by multimodal, it will be disabled"); - } - - if (params_base.has_speculative()) { - SRV_ERR("%s\n", "err: speculative decode is not supported by multimodal"); - return false; - } - } - - if (!llama_memory_can_shift(llama_get_memory(ctx))) { - if (params_base.ctx_shift) { - params_base.ctx_shift = false; - SRV_WRN("%s\n", "ctx_shift is not supported by this context, it will be disabled"); - } - - if (params_base.n_cache_reuse) { - params_base.n_cache_reuse = 0; - SRV_WRN("%s\n", "cache_reuse is not supported by this context, it will be disabled"); - } - } - - return true; - } - - // initialize slots and server-related data - void init() { - // wiring up server queues - queue_tasks.on_new_task([this](server_task && task) { - process_single_task(std::move(task)); - }); - queue_tasks.on_update_slots([this]() { - update_slots(); - }); - - // Necessary similarity of prompt for slot selection - slot_prompt_similarity = params_base.slot_prompt_similarity; - - // setup slots - SRV_INF("initializing slots, n_slots = %d\n", params_base.n_parallel); - - const int n_ctx_train = llama_model_n_ctx_train(model); - - int n_ctx_slot = llama_n_ctx_seq(ctx); - if (n_ctx_slot > n_ctx_train) { - SRV_WRN("the slot context (%d) exceeds the training context of the model (%d) - capping\n", n_ctx_slot, n_ctx_train); - n_ctx_slot = n_ctx_train; - } - - for (int i = 0; i < params_base.n_parallel; i++) { - server_slot slot; - - slot.id = i; - slot.ctx = ctx; - slot.n_ctx = n_ctx_slot; - slot.mctx = mctx; - slot.prompt.tokens.has_mtmd = mctx != nullptr; - - if (model_dft) { - slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1); - - // TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK] - slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft); - if (slot.ctx_dft == nullptr) { - SRV_ERR("%s", "failed to create draft context\n"); - return; - } - - slot.spec = common_speculative_init(slot.ctx, slot.ctx_dft); - if (slot.spec == nullptr) { - SRV_ERR("%s", "failed to create speculator\n"); - return; - } - for (auto & pair : params_base.speculative.replacements) { - common_speculative_add_replacement_tgt_dft(slot.spec, pair.first.c_str(), pair.second.c_str()); - } - } - - SLT_INF(slot, "new slot, n_ctx = %d\n", slot.n_ctx); - - slot.callback_on_release = [this](int) { - 
queue_tasks.pop_deferred_task(); - }; - - slot.reset(); - - slots.push_back(std::move(slot)); - } - - { - const char * LLAMA_SERVER_SLOTS_DEBUG = getenv("LLAMA_SERVER_SLOTS_DEBUG"); - slots_debug = LLAMA_SERVER_SLOTS_DEBUG ? atoi(LLAMA_SERVER_SLOTS_DEBUG) : 0; - - if (slots_debug) { - SRV_WRN("slots debug = %d\n", slots_debug); - } - } - - // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens - // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used) - { - const int32_t n_batch = llama_n_batch(ctx); - batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1); - } - - metrics.init(); - - if (params_base.cache_ram_mib != 0) { - if (params_base.cache_ram_mib < 0) { - SRV_WRN("prompt cache is enabled, size limit: %s\n", "no limit"); - } else { - SRV_WRN("prompt cache is enabled, size limit: %d MiB\n", params_base.cache_ram_mib); - } - SRV_WRN("%s", "use `--cache-ram 0` to disable the prompt cache\n"); - - prompt_cache = std::make_unique(params_base.cache_ram_mib, n_ctx); - } else { - SRV_WRN("%s", "prompt cache is disabled - use `--cache-ram N` to enable it\n"); - } - SRV_WRN("%s", "for more info see https://github.com/ggml-org/llama.cpp/pull/16391\n"); - - if (!params_base.model_alias.empty()) { - // user explicitly specified model name - model_name = params_base.model_alias; - } else if (!params_base.model.name.empty()) { - // use model name in registry format (for models in cache) - model_name = params_base.model.name; - } else { - // fallback: derive model name from file name - auto model_path = std::filesystem::path(params_base.model.path); - model_name = model_path.filename().string(); - } - - // thinking is enabled if: - // 1. It's not explicitly disabled (reasoning_budget == 0) - // 2. The chat template supports it - const bool enable_thinking = params_base.use_jinja && params_base.reasoning_budget != 0 && common_chat_templates_support_enable_thinking(chat_templates.get()); - SRV_INF("thinking = %d\n", enable_thinking); - - oai_parser_opt = { - /* use_jinja */ params_base.use_jinja, - /* prefill_assistant */ params_base.prefill_assistant, - /* reasoning_format */ params_base.reasoning_format, - /* chat_template_kwargs */ params_base.default_template_kwargs, - /* common_chat_templates */ chat_templates.get(), - /* allow_image */ mctx ? mtmd_support_vision(mctx) : false, - /* allow_audio */ mctx ? 
mtmd_support_audio (mctx) : false, - /* enable_thinking */ enable_thinking, - /* media_path */ params_base.media_path, - }; - - // print sample chat example to make it clear which template is used - LOG_INF("%s: chat template, chat_template: %s, example_format: '%s'\n", __func__, - common_chat_templates_source(chat_templates.get()), - common_chat_format_example(chat_templates.get(), params_base.use_jinja, params_base.default_template_kwargs).c_str()); - } - - server_slot * get_slot_by_id(int id) { - for (server_slot & slot : slots) { - if (slot.id == id) { - return &slot; - } - } - - return nullptr; - } - - server_slot * get_available_slot(const server_task & task) { - server_slot * ret = nullptr; - - bool update_cache = false; - - // find the slot that has at least n% prompt similarity - if (ret == nullptr && slot_prompt_similarity != 0.0f) { - float sim_best = 0; - - for (server_slot & slot : slots) { - // skip the slot if it is not available - if (slot.is_processing()) { - continue; - } - - const auto & tokens = slot.prompt.tokens; - - // skip the slot if it does not contains cached tokens - if (tokens.empty()) { - continue; - } - - // fraction of the Longest Common Prefix length with respect to the input prompt length - const float sim_cur = float(tokens.get_common_prefix(task.tokens)) / task.tokens.size(); - - // select the current slot if the criteria match - if (sim_cur > sim_best && sim_cur > slot_prompt_similarity) { - sim_best = sim_cur; - - ret = &slot; - } - } - - if (ret != nullptr) { - const float f_keep = (sim_best*task.tokens.size()) / ret->prompt.tokens.size(); - - SLT_INF(*ret, "selected slot by LCP similarity, sim_best = %.3f (> %.3f thold), f_keep = %.3f\n", - sim_best, slot_prompt_similarity, f_keep); - - // if we are about to lose a large portion of the existing context - save it in the prompt cache - if (f_keep < 0.5f) { - update_cache = true; - } - } - } - - // find the slot that has been least recently used - if (ret == nullptr) { - int64_t t_last = -1; - - for (server_slot & slot : slots) { - // skip the slot if it is not available - if (slot.is_processing()) { - continue; - } - - // select the current slot if the criteria match - if (!ret || slot.t_last_used <= t_last) { - t_last = slot.t_last_used; - ret = &slot; - } - } - - if (ret != nullptr) { - SLT_INF(*ret, "selected slot by LRU, t_last = %" PRId64 "\n", t_last); - - update_cache = true; - } - } - - if (ret) { - const auto & tokens = ret->prompt.tokens; - - update_cache = update_cache && prompt_cache; - - // cache prompts only for completion tasks - update_cache = update_cache && task.type == SERVER_TASK_TYPE_COMPLETION; - - // don't update the cache if the slot's context is empty - update_cache = update_cache && tokens.size() > 0; - - // TODO: mtmd does not support prompt cache - update_cache = update_cache && (ret->mctx == nullptr); - - if (update_cache) { - SRV_WRN("%s", "updating prompt cache\n"); - - const int64_t t_start = ggml_time_us(); - - ret->prompt_save(*prompt_cache); - - if (!ret->prompt_load(*prompt_cache, task.tokens)) { - clear_slot(*ret); - } - - prompt_cache->update(); - - SRV_WRN("prompt cache update took %.2f ms\n", (ggml_time_us() - t_start) / 1000.0); - } - } - - return ret; - } - - void clear_slot(server_slot & slot) const { - GGML_ASSERT(!slot.is_processing()); - - SLT_WRN(slot, "clearing slot with %zu tokens\n", slot.prompt.tokens.size()); - - llama_memory_seq_rm(llama_get_memory(ctx), slot.id, -1, -1); - slot.prompt.tokens.clear(); - } - - // return true if at least one slot has 
been cleared - // TODO: improve logic - // - smarter decision which slot to clear (LRU or longest prompt?) - // - move slot to level 2 cache instead of removing? - // - instead of purging, try to store and resume later? - bool try_clear_idle_slots() { - bool res = false; - - if (!params_base.kv_unified) { - return res; - } - - for (auto & slot : slots) { - if (slot.is_processing()) { - continue; - } - - if (slot.prompt.n_tokens() > 0) { - SRV_WRN("purging slot %d with %zu tokens\n", slot.id, slot.prompt.tokens.size()); - - clear_slot(slot); - - res = true; - - // clear slots one by one - break; - } - } - - return res; - } - - bool launch_slot_with_task(server_slot & slot, server_task && task) { - slot.reset(); - - if (!are_lora_equal(task.params.lora, slot.lora)) { - // if lora has changed, check to see if the cache should be cleared - if (lora_should_clear_cache(slot.lora, task.params.lora)) { - SLT_INF(slot, "clearing cache for lora change. %zu loras -> %zu loras\n", slot.lora.size(), task.params.lora.size()); - slot.prompt.tokens.clear(); - } else { - SLT_INF(slot, "keeping cache for alora. %zu target loras\n", task.params.lora.size()); - } - slot.lora = task.params.lora; - } - - // if using alora, make sure it's only a single one requested and active - size_t alora_invocation_start = task.tokens.size(); - if (lora_all_alora(slot.lora)) { - const auto & enabled_ids = lora_get_enabled_ids(slot.lora); - // TODO: This will error out if a user requests two aloras, but only - // provides the activation string for one. We could, instead search - // for all requested alora activation strings and then either keep - // only the last one, or reject if multiple are found. - if (enabled_ids.size() != 1) { - send_error(task, "Cannot run multiple aLoRAs in a single request", ERROR_TYPE_INVALID_REQUEST); - return false; - } - const auto & lora = slot.lora[enabled_ids[0]].ptr; - - // get the pointer and count for the invocation tokens - const uint64_t n_invocation_tokens = llama_adapter_get_alora_n_invocation_tokens(lora); - const llama_token * invocation_tokens = llama_adapter_get_alora_invocation_tokens (lora); - - // scan backwards through the prompt tokens to find the last - // occurrence of the invocation sequence - int match_idx = static_cast(n_invocation_tokens) - 1; - for (int i = task.tokens.size() - 1; i >= 0; --i) { - // the token in this position matches the next token to find in - // the invocation sequence - if (task.tokens[i] == invocation_tokens[match_idx]) { - // if it's a full match, we've found the start - if (match_idx == 0) { - alora_invocation_start = i; - break; - } - // otherwise, check the next token in the sequence - --match_idx; - } else { - // no match in this position, so start looking over again - match_idx = static_cast(n_invocation_tokens) - 1; - } - } - - // if the activation string is not found, disable the alora - if (alora_invocation_start == task.tokens.size()) { - SLT_DBG(slot, "alora %zu requested, but not found. 
deactivating\n", enabled_ids[0]); - slot.lora[enabled_ids[0]].scale = 0.0f; - } else { - SLT_DBG(slot, "alora %zu activated starting at %zu\n", enabled_ids[0], alora_invocation_start); - slot.alora_invocation_start = alora_invocation_start; - } - } - - if (!task.tokens.validate(ctx)) { - send_error(task, "Prompt contains invalid tokens", ERROR_TYPE_INVALID_REQUEST); - return false; - } - - SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str()); - - // initialize samplers - { - if (slot.smpl != nullptr) { - common_sampler_free(slot.smpl); - } - - slot.smpl = common_sampler_init(model, task.params.sampling); - if (slot.smpl == nullptr) { - // for now, the only error that may happen here is invalid grammar - send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST); - return false; - } - - SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl).c_str()); - } - - // initialize draft batch - // TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK] - if (slot.ctx_dft) { - llama_batch_free(slot.batch_spec); - - slot.batch_spec = llama_batch_init(task.params.speculative.n_max + 1, 0, 1); - } - - slot.task = std::make_unique(std::move(task)); - - slot.state = slot.is_child() - ? SLOT_STATE_WAIT_OTHER // wait for the parent to process prompt - : SLOT_STATE_STARTED; - - SLT_INF(slot, "%s", "processing task\n"); - - return true; - } - - bool process_token(completion_token_output & result, server_slot & slot) { - // remember which tokens were sampled - used for repetition penalties during sampling - const std::string token_str = result.text_to_send; - slot.sampled = result.tok; - - slot.generated_text += token_str; - if (slot.task->params.return_tokens) { - slot.generated_tokens.push_back(result.tok); - } - slot.has_next_token = true; - - // check if there is incomplete UTF-8 character at the end - bool incomplete = validate_utf8(slot.generated_text) < slot.generated_text.size(); - - // search stop word and delete it - if (!incomplete) { - size_t pos = std::min(slot.n_sent_text, slot.generated_text.size()); - - const std::string str_test = slot.generated_text.substr(pos); - bool send_text = true; - - size_t stop_pos = slot.find_stopping_strings(str_test, token_str.size(), true); - if (stop_pos != std::string::npos) { - slot.generated_text.erase( - slot.generated_text.begin() + pos + stop_pos, - slot.generated_text.end()); - pos = std::min(slot.n_sent_text, slot.generated_text.size()); - } else if (slot.has_next_token && !llama_vocab_is_eog(vocab, result.tok) ) { - stop_pos = slot.find_stopping_strings(str_test, token_str.size(), false); - send_text = stop_pos == std::string::npos; - } - - // check if there is any token to predict - if (send_text) { - // no send the stop word in the response - result.text_to_send = slot.generated_text.substr(pos, std::string::npos); - slot.n_sent_text += result.text_to_send.size(); - // add the token to slot queue and cache - } else { - result.text_to_send = ""; - } - - slot.add_token(result); - if (slot.task->params.stream) { - send_partial_response(slot, result, false); - } - } - - if (incomplete) { - slot.has_next_token = true; - } - - // if context shifting is disabled, make sure that we don't run out of context - if (!params_base.ctx_shift && slot.prompt.n_tokens() + 1 >= slot.n_ctx) { - slot.truncated = true; - slot.stop = STOP_TYPE_LIMIT; - slot.has_next_token = false; - - SLT_DBG(slot, "stopped due to running out of context capacity, prompt.n_tokens() = %d, task.n_tokens = %d, n_decoded = %d, n_ctx = 
%d\n", - slot.prompt.n_tokens(), slot.task->n_tokens(), slot.n_decoded, slot.n_ctx); - } - - // check the limits - if (slot.n_decoded > 0 && slot.has_next_token && !slot.has_budget(params_base)) { - slot.stop = STOP_TYPE_LIMIT; - slot.has_next_token = false; - - SLT_DBG(slot, "stopped by limit, n_decoded = %d, n_predict = %d\n", slot.n_decoded, slot.task->params.n_predict); - } - - if (slot.has_new_line) { - // require that each new line has a whitespace prefix (i.e. indentation) of at least slot.params.n_indent - if (slot.task->params.n_indent > 0) { - // check the current indentation - // TODO: improve by not doing it more than once for each new line - if (slot.last_nl_pos > 0) { - size_t pos = slot.last_nl_pos; - - int n_indent = 0; - while (pos < slot.generated_text.size() && (slot.generated_text[pos] == ' ' || slot.generated_text[pos] == '\t')) { - n_indent++; - pos++; - } - - if (pos < slot.generated_text.size() && n_indent < slot.task->params.n_indent) { - slot.stop = STOP_TYPE_LIMIT; - slot.has_next_token = false; - - // cut the last line - slot.generated_text.erase(pos, std::string::npos); - - SLT_DBG(slot, "stopped by indentation limit, n_decoded = %d, n_indent = %d\n", slot.n_decoded, n_indent); - } - } - - // find the next new line - { - const size_t pos = slot.generated_text.find('\n', slot.last_nl_pos); - - if (pos != std::string::npos) { - slot.last_nl_pos = pos + 1; - } - } - } - } - - // check if there is a new line in the generated text - if (result.text_to_send.find('\n') != std::string::npos) { - slot.has_new_line = true; - - // if we have seen a new line, we stop after a certain time limit, but only upon another new line - if (slot.task->params.t_max_predict_ms > 0 && (ggml_time_us() - slot.t_start_generation > 1000.0f*slot.task->params.t_max_predict_ms)) { - slot.stop = STOP_TYPE_LIMIT; - slot.has_next_token = false; - - SLT_DBG(slot, "stopped by time limit, n_decoded = %d, t_max_predict_ms = %d ms\n", slot.n_decoded, (int) slot.task->params.t_max_predict_ms); - } - } - - if (llama_vocab_is_eog(vocab, result.tok)) { - slot.stop = STOP_TYPE_EOS; - slot.has_next_token = false; - - SLT_DBG(slot, "%s", "stopped by EOS\n"); - } - - SLT_DBG(slot, "n_decoded = %d, n_remaining = %d, next token: %5d '%s'\n", slot.n_decoded, slot.n_remaining, result.tok, token_str.c_str()); - - return slot.has_next_token; // continue - } - - void populate_token_probs(const server_slot & slot, completion_token_output & result, bool post_sampling, bool special, int idx) const { - size_t n_probs = slot.task->params.sampling.n_probs; - size_t n_vocab = llama_vocab_n_tokens(vocab); - - if (post_sampling) { - const auto * cur_p = common_sampler_get_candidates(slot.smpl, true); - const size_t max_probs = cur_p->size; - - // set probability for sampled token - for (size_t i = 0; i < max_probs; i++) { - if (cur_p->data[i].id == result.tok) { - result.prob = cur_p->data[i].p; - break; - } - } - - // set probability for top n_probs tokens - result.probs.reserve(max_probs); - for (size_t i = 0; i < std::min(max_probs, n_probs); i++) { - result.probs.push_back({ - cur_p->data[i].id, - common_token_to_piece(ctx, cur_p->data[i].id, special), - cur_p->data[i].p - }); - } - } else { - // TODO: optimize this with min-p optimization - std::vector cur = get_token_probabilities(ctx, idx); - - // set probability for sampled token - for (size_t i = 0; i < n_vocab; i++) { - // set probability for sampled token - if (cur[i].id == result.tok) { - result.prob = cur[i].p; - break; - } - } - - // set probability for top 
n_probs tokens - result.probs.reserve(n_probs); - for (size_t i = 0; i < std::min(n_vocab, n_probs); i++) { - result.probs.push_back({ - cur[i].id, - common_token_to_piece(ctx, cur[i].id, special), - cur[i].p - }); - } - } - } - - void send_error(const server_task & task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { - send_error(task.id, error, type); - } - - void send_error(const server_slot & slot, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER) { - send_error(slot.task->id, error, type, slot.task->n_tokens(), slot.n_ctx); - } - - void send_error(const int id_task, const std::string & error, const enum error_type type = ERROR_TYPE_SERVER, const int32_t n_prompt_tokens = 0, const int32_t n_ctx = 0) { - SRV_ERR("task id = %d, error: %s\n", id_task, error.c_str()); - - if (type == ERROR_TYPE_EXCEED_CONTEXT_SIZE) { - GGML_ASSERT(n_ctx > 0 && n_prompt_tokens > 0); - } - - auto res = std::make_unique(); - res->id = id_task; - res->err_type = type; - res->err_msg = error; - res->n_prompt_tokens = n_prompt_tokens; - res->n_ctx = n_ctx; - - queue_results.send(std::move(res)); - } - - // if multimodal is enabled, send an error and return false - bool check_no_mtmd(const int id_task) { - if (mctx) { - send_error(id_task, "This feature is not supported by multimodal", ERROR_TYPE_NOT_SUPPORTED); - return false; - } - return true; - } - - void send_partial_response(server_slot & slot, const completion_token_output & tkn, bool is_progress) { - auto res = std::make_unique(); - - res->id = slot.task->id; - res->index = slot.task->index; - - if (is_progress) { - res->is_progress = true; - res->progress.total = slot.task->n_tokens(); - res->progress.cache = slot.n_prompt_tokens_cache; - res->progress.processed = slot.prompt.tokens.size(); - res->progress.time_ms = (ggml_time_us() - slot.t_start_process_prompt) / 1000; - } else { - res->content = tkn.text_to_send; - res->tokens = { tkn.tok }; - } - - res->n_decoded = slot.n_decoded; - res->n_prompt_tokens = slot.task->n_tokens(); - res->post_sampling_probs = slot.task->params.post_sampling_probs; - - res->verbose = slot.task->params.verbose; - res->res_type = slot.task->params.res_type; - res->oaicompat_model = slot.task->params.oaicompat_model; - res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id; - - // populate res.probs_output - if (slot.task->params.sampling.n_probs > 0) { - res->prob_output = tkn; // copy the token probs - } - - // populate timings if this is final response or timings_per_token is enabled - if (slot.stop != STOP_TYPE_NONE || slot.task->params.timings_per_token) { - res->timings = slot.get_timings(); - } - - queue_results.send(std::move(res)); - } - - void send_final_response(server_slot & slot) { - auto res = std::make_unique(); - - res->id = slot.task->id; - res->id_slot = slot.id; - - res->index = slot.task->index; - // in stream mode, content and tokens are already in last partial chunk - if (slot.task->params.stream) { - res->content = ""; - res->tokens = llama_tokens{}; - } else { - res->content = std::move(slot.generated_text); - res->tokens = std::move(slot.generated_tokens); - } - res->timings = slot.get_timings(); - res->prompt = slot.task->tokens.detokenize(ctx, true); - res->response_fields = std::move(slot.task->params.response_fields); - - res->truncated = slot.truncated; - res->n_decoded = slot.n_decoded; - res->n_prompt_tokens = slot.task->n_tokens(); - res->n_tokens_cached = slot.prompt.n_tokens(); - res->has_new_line = slot.has_new_line; - 
res->stopping_word = slot.stopping_word; - res->stop = slot.stop; - res->post_sampling_probs = slot.task->params.post_sampling_probs; - - res->verbose = slot.task->params.verbose; - res->stream = slot.task->params.stream; - res->include_usage = slot.task->params.include_usage; - res->res_type = slot.task->params.res_type; - res->oaicompat_model = slot.task->params.oaicompat_model; - res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id; - - // populate res.probs_output - if (slot.task->params.sampling.n_probs > 0) { - if (!slot.task->params.stream && slot.stop == STOP_TYPE_WORD) { - const llama_tokens stop_word_toks = common_tokenize(ctx, slot.stopping_word, false); - - size_t safe_offset = std::min(slot.generated_token_probs.size(), stop_word_toks.size()); - res->probs_output = std::vector( - slot.generated_token_probs.begin(), - slot.generated_token_probs.end() - safe_offset); - } else { - res->probs_output = std::vector( - slot.generated_token_probs.begin(), - slot.generated_token_probs.end()); - } - } - - res->generation_params = slot.task->params; // copy the parameters - - queue_results.send(std::move(res)); - } - - void send_embedding(const server_slot & slot, const llama_batch & batch) { - auto res = std::make_unique(); - res->id = slot.task->id; - res->index = slot.task->index; - res->n_tokens = slot.task->n_tokens(); - res->res_type = slot.task->params.res_type; - - const int n_embd = llama_model_n_embd(model); - - std::vector embd_res(n_embd, 0.0f); - - for (int i = 0; i < batch.n_tokens; ++i) { - if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { - continue; - } - - const float * embd = nullptr; - if (llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE) { - embd = llama_get_embeddings_ith(ctx, i); - } else { - embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); - } - - if (embd == nullptr) { - SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]); - - res->embedding.push_back(std::vector(n_embd, 0.0f)); - continue; - } - - // normalize only when there is pooling - if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) { - common_embd_normalize(embd, embd_res.data(), n_embd, slot.task->params.embd_normalize); - res->embedding.push_back(embd_res); - break; - } - - res->embedding.emplace_back(embd, embd + n_embd); - } - - SLT_DBG(slot, "%s", "sending embeddings\n"); - - queue_results.send(std::move(res)); - } - - void send_rerank(const server_slot & slot, const llama_batch & batch) { - auto res = std::make_unique(); - res->id = slot.task->id; - res->index = slot.task->index; - res->n_tokens = slot.task->n_tokens(); - - for (int i = 0; i < batch.n_tokens; ++i) { - if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { - continue; - } - - const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); - if (embd == NULL) { - embd = llama_get_embeddings_ith(ctx, i); - } - - if (embd == NULL) { - SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]); - - res->score = -1e6; - continue; - } - - res->score = embd[0]; - } - - SLT_DBG(slot, "sending rerank result, res.score = %f\n", res->score); - - queue_results.send(std::move(res)); - } - - // - // Functions to process the task - // - - // tokenize the input if it's set by CLI, return false on error - bool tokenize_cli_input(server_task & task) { - if (task.cli_input == nullptr) { - return true; // nothing to do - } - try { - auto & opt = oai_parser_opt; - common_chat_templates_inputs inputs; - 
inputs.messages = common_chat_msgs_parse_oaicompat(task.cli_input); - inputs.tools = {}; // TODO - inputs.tool_choice = COMMON_CHAT_TOOL_CHOICE_NONE; - inputs.json_schema = ""; // TODO - inputs.grammar = ""; // TODO - inputs.use_jinja = opt.use_jinja; - inputs.parallel_tool_calls = false; - inputs.add_generation_prompt = true; - inputs.reasoning_format = opt.reasoning_format; - inputs.enable_thinking = opt.enable_thinking; - - // Apply chat template to the list of messages - auto chat_params = common_chat_templates_apply(opt.tmpls, inputs); - - // tokenize the resulting prompt - auto & prompt = chat_params.prompt; - if (mctx != nullptr) { - task.tokens = process_mtmd_prompt(mctx, prompt, task.cli_files); - } else { - task.tokens = std::move(tokenize_input_prompts(vocab, mctx, prompt, true, true)[0]); - } - task.cli_input.clear(); - task.cli_files.clear(); - } catch (const std::exception & e) { - send_error(task, std::string("Failed to format input: ") + e.what(), ERROR_TYPE_INVALID_REQUEST); - return false; - } - return true; - } - - void process_single_task(server_task && task) { - switch (task.type) { - case SERVER_TASK_TYPE_COMPLETION: - case SERVER_TASK_TYPE_INFILL: - case SERVER_TASK_TYPE_EMBEDDING: - case SERVER_TASK_TYPE_RERANK: - { - if (!tokenize_cli_input(task)) { - break; - } - - const int id_slot = task.id_slot; - - server_slot * slot = id_slot != -1 ? get_slot_by_id(id_slot) : get_available_slot(task); - - if (slot == nullptr) { - // if no slot is available, we defer this task for processing later - SRV_DBG("no slot is available, defer task, id_task = %d\n", task.id); - queue_tasks.defer(std::move(task)); - break; - } - - if (slot->is_processing()) { - // if requested slot is unavailable, we defer this task for processing later - SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); - queue_tasks.defer(std::move(task)); - break; - } - - if (!launch_slot_with_task(*slot, std::move(task))) { - SRV_ERR("failed to launch slot with task, id_task = %d\n", task.id); - break; - } - } break; - case SERVER_TASK_TYPE_CANCEL: - { - // release slot linked with the task id - for (auto & slot : slots) { - if (slot.task && slot.task->id == task.id_target) { - slot.release(); - break; - } - } - } break; - case SERVER_TASK_TYPE_NEXT_RESPONSE: - { - // do nothing - } break; - case SERVER_TASK_TYPE_METRICS: - { - json slots_data = json::array(); - - int n_idle_slots = 0; - int n_processing_slots = 0; - - for (server_slot & slot : slots) { - json slot_data = slot.to_json(slots_debug == 0); - - if (slot.is_processing()) { - n_processing_slots++; - } else { - n_idle_slots++; - } - - slots_data.push_back(slot_data); - } - SRV_DBG("n_idle_slots = %d, n_processing_slots = %d\n", n_idle_slots, n_processing_slots); - - auto res = std::make_unique(); - res->id = task.id; - res->slots_data = std::move(slots_data); - res->n_idle_slots = n_idle_slots; - res->n_processing_slots = n_processing_slots; - res->n_tasks_deferred = queue_tasks.queue_tasks_deferred_size(); - res->t_start = metrics.t_start; - - res->n_prompt_tokens_processed_total = metrics.n_prompt_tokens_processed_total; - res->t_prompt_processing_total = metrics.t_prompt_processing_total; - res->n_tokens_predicted_total = metrics.n_tokens_predicted_total; - res->t_tokens_generation_total = metrics.t_tokens_generation_total; - - res->n_tokens_max = metrics.n_tokens_max; - - res->n_prompt_tokens_processed = metrics.n_prompt_tokens_processed; - res->t_prompt_processing = metrics.t_prompt_processing; - res->n_tokens_predicted = 
metrics.n_tokens_predicted; - res->t_tokens_generation = metrics.t_tokens_generation; - - res->n_decode_total = metrics.n_decode_total; - res->n_busy_slots_total = metrics.n_busy_slots_total; - - if (task.metrics_reset_bucket) { - metrics.reset_bucket(); - } - queue_results.send(std::move(res)); - } break; - case SERVER_TASK_TYPE_SLOT_SAVE: - { - if (!check_no_mtmd(task.id)) { - break; - } - - int id_slot = task.slot_action.slot_id; - server_slot * slot = get_slot_by_id(id_slot); - if (slot == nullptr) { - send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); - break; - } - if (slot->is_processing()) { - // if requested slot is unavailable, we defer this task for processing later - SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); - queue_tasks.defer(std::move(task)); - break; - } - - const size_t token_count = slot->prompt.tokens.size(); - const int64_t t_start = ggml_time_us(); - - std::string filename = task.slot_action.filename; - std::string filepath = task.slot_action.filepath; - - const llama_tokens & tokens = slot->prompt.tokens.get_text_tokens(); - const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id, tokens.data(), token_count); - - const int64_t t_end = ggml_time_us(); - const double t_save_ms = (t_end - t_start) / 1000.0; - - auto res = std::make_unique(); - res->id = task.id; - res->id_slot = id_slot; - res->filename = filename; - res->is_save = true; - res->n_tokens = token_count; - res->n_bytes = nwrite; - res->t_ms = t_save_ms; - queue_results.send(std::move(res)); - } break; - case SERVER_TASK_TYPE_SLOT_RESTORE: - { - if (!check_no_mtmd(task.id)) break; - int id_slot = task.slot_action.slot_id; - server_slot * slot = get_slot_by_id(id_slot); - if (slot == nullptr) { - send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); - break; - } - if (slot->is_processing()) { - // if requested slot is unavailable, we defer this task for processing later - SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); - queue_tasks.defer(std::move(task)); - break; - } - - const int64_t t_start = ggml_time_us(); - - std::string filename = task.slot_action.filename; - std::string filepath = task.slot_action.filepath; - - llama_tokens tokens; - tokens.resize(slot->n_ctx); - size_t token_count = 0; - size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id, tokens.data(), tokens.size(), &token_count); - if (nread == 0) { - slot->prompt.tokens.clear(); // KV may already been invalidated? 
- send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST); - break; - } - tokens.resize(token_count); - slot->prompt.tokens.clear(); - slot->prompt.tokens.insert(tokens); - - const int64_t t_end = ggml_time_us(); - const double t_restore_ms = (t_end - t_start) / 1000.0; - - auto res = std::make_unique(); - res->id = task.id; - res->id_slot = id_slot; - res->filename = filename; - res->is_save = false; - res->n_tokens = token_count; - res->n_bytes = nread; - res->t_ms = t_restore_ms; - queue_results.send(std::move(res)); - } break; - case SERVER_TASK_TYPE_SLOT_ERASE: - { - if (!check_no_mtmd(task.id)) { - break; - } - int id_slot = task.slot_action.slot_id; - server_slot * slot = get_slot_by_id(id_slot); - if (slot == nullptr) { - send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST); - break; - } - if (slot->is_processing()) { - // if requested slot is unavailable, we defer this task for processing later - SRV_DBG("requested slot is unavailable, defer task, id_task = %d\n", task.id); - queue_tasks.defer(std::move(task)); - break; - } - - // Erase token cache - const size_t n_erased = slot->prompt.tokens.size(); - - clear_slot(*slot); - - auto res = std::make_unique(); - res->id = task.id; - res->id_slot = id_slot; - res->n_erased = n_erased; - queue_results.send(std::move(res)); - } break; - case SERVER_TASK_TYPE_SET_LORA: - { - params_base.lora_adapters = std::move(task.set_lora); - auto res = std::make_unique(); - res->id = task.id; - queue_results.send(std::move(res)); - } break; - } - } - - void update_slots() { - // check if all slots are idle - { - bool all_idle = true; - - for (auto & slot : slots) { - if (slot.is_processing()) { - all_idle = false; - break; - } - } - - if (all_idle) { - SRV_INF("%s", "all slots are idle\n"); - - return; - } - } - - { - SRV_DBG("%s", "posting NEXT_RESPONSE\n"); - - server_task task(SERVER_TASK_TYPE_NEXT_RESPONSE); - task.id = queue_tasks.get_new_id(); - queue_tasks.post(std::move(task)); - } - - // apply context-shift if needed - // TODO: simplify and improve - for (server_slot & slot : slots) { - if (slot.state == SLOT_STATE_GENERATING && slot.prompt.n_tokens() + 1 >= slot.n_ctx) { - if (!params_base.ctx_shift) { - // this check is redundant (for good) - // we should never get here, because generation should already stopped in process_token() - send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER); - slot.release(); - continue; - } - - if (mctx) { - // we should never reach this because params_base.ctx_shift is automatically disabled if mmproj is loaded - // we don't support ctx_shift because an image chunk may contains multiple tokens - GGML_ABORT("not supported by multimodal"); - } - - if (slot.is_parent() || slot.is_child()) { - send_error(slot, "context shift cannot be used for shared prompt", ERROR_TYPE_SERVER); - slot.release(); - continue; - } - - // Shift context - int n_keep = slot.task->params.n_keep < 0 ? slot.task->n_tokens() : slot.task->params.n_keep; - - if (add_bos_token) { - n_keep += 1; - } - - n_keep = std::min(slot.n_ctx - 4, n_keep); - - const int n_left = slot.prompt.n_tokens() - n_keep; - const int n_discard = slot.task->params.n_discard ? 
slot.task->params.n_discard : (n_left / 2); - - SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard); - - llama_memory_seq_rm (llama_get_memory(ctx), slot.id, n_keep , n_keep + n_discard); - llama_memory_seq_add(llama_get_memory(ctx), slot.id, n_keep + n_discard, slot.prompt.n_tokens(), -n_discard); - - // add generated tokens to cache - // ref: https://github.com/ggml-org/llama.cpp/pull/16818#discussion_r2473269481 - { - GGML_ASSERT(!slot.prompt.tokens.has_mtmd); - - llama_tokens new_tokens = slot.prompt.tokens.get_text_tokens(); // copy - for (size_t i = n_keep + n_discard; i < new_tokens.size(); i++) { - new_tokens[i - n_discard] = new_tokens[i]; - } - - new_tokens.resize(slot.prompt.tokens.size() - n_discard); - - slot.prompt.tokens.clear(); - slot.prompt.tokens.insert(new_tokens); - } - - slot.truncated = true; - } - } - - // start populating the batch for this iteration - common_batch_clear(batch); - - // track if given slot can be batched with slots already in the batch - server_slot * slot_batched = nullptr; - - auto accept_special_token = [&](server_slot & slot, llama_token token) { - return params_base.special || - slot.task->params.sampling.preserved_tokens.find(token) != slot.task->params.sampling.preserved_tokens.end(); - }; - - // first, add sampled tokens from any ongoing sequences - for (auto & slot : slots) { - if (slot.state != SLOT_STATE_GENERATING) { - continue; - } - - // check if we can batch this slot with the previous one - if (!slot_batched) { - slot_batched = &slot; - } else if (!slot_batched->can_batch_with(slot)) { - continue; - } - - // generate draft tokens in speculative decoding mode - // TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK] - // perform the speculative drafting for all sequences at the same time in a single batch - int n_draft_max = slot.get_n_draft_max(); - if (n_draft_max > 0) { - if (mctx) { - // we should never reach this, as speculative is automatically disabled if mmproj is loaded - GGML_ABORT("not supported by multimodal"); - } - - struct common_speculative_params params_spec; - params_spec.n_draft = n_draft_max; - params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max; - params_spec.p_min = slot.task->params.speculative.p_min; - const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens(); - llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled); - - // add the sampled token to the batch - slot.i_batch_dft.push_back(batch.n_tokens); - common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true); - slot.prompt.tokens.push_back(slot.sampled); - - if (slot.task->params.speculative.n_min > (int) draft.size()) { - SLT_DBG(slot, "ignoring small draft: %d < %d\n", (int) draft.size(), slot.task->params.speculative.n_min); - // fallback to normal decoding - slot.i_batch = slot.i_batch_dft[0]; - slot.drafted.clear(); - slot.i_batch_dft.clear(); - } else { - // keep track of total number of drafted tokens tested - slot.n_draft_total += draft.size(); - - // add all drafted tokens to the batch - for (size_t i = 0; i < draft.size(); i++) { - slot.i_batch_dft.push_back(batch.n_tokens); - common_batch_add(batch, draft[i], slot.prompt.tokens.pos_next(), { slot.id }, true); - slot.prompt.tokens.push_back(draft[i]); - } - slot.drafted = std::move(draft); - } - } else { - // no speculative decoding - slot.i_batch = 
batch.n_tokens; - - common_batch_add(batch, slot.sampled, slot.prompt.tokens.pos_next(), { slot.id }, true); - - slot.prompt.tokens.push_back(slot.sampled); - - SLT_DBG(slot, "slot decode token, n_ctx = %d, n_tokens = %d, truncated = %d\n", - slot.n_ctx, slot.prompt.n_tokens(), slot.truncated); - } - } - - // process in chunks of params.n_batch - int32_t n_batch = llama_n_batch(ctx); - int32_t n_ubatch = llama_n_ubatch(ctx); - - float alora_scale = -1.0f; - size_t alora_disabled_id = 0; - - // next, batch any pending prompts without exceeding n_batch - if (params_base.cont_batching || batch.n_tokens == 0) { - for (auto & slot : slots) { - if (!slot.is_processing()) { - continue; - } - - // check if we can batch this slot with the previous one - if (slot_batched && !slot_batched->can_batch_with(slot)) { - continue; - } - - // this slot still has a prompt to be processed - if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) { - const auto & input_tokens = slot.task->tokens; - - // TODO: maybe move branch to outside of this loop in the future - if (slot.state == SLOT_STATE_STARTED) { - slot.t_start_process_prompt = ggml_time_us(); - slot.t_start_generation = 0; - - slot.state = SLOT_STATE_PROCESSING_PROMPT; - - SLT_INF(slot, "new prompt, n_ctx_slot = %d, n_keep = %d, task.n_tokens = %d\n", - slot.n_ctx, slot.task->params.n_keep, slot.task->n_tokens()); - - // print prompt tokens (for debugging) - /*if (1) { - // first 16 tokens (avoid flooding logs) - for (int i = 0; i < std::min(16, input_tokens.size()); i++) { - SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str()); - } - } else { - // all - for (int i = 0; i < (int) input_tokens.size(); i++) { - SLT_DBG(slot, "prompt token %3d: %6d '%s'\n", i, input_tokens[i], common_token_to_piece(ctx, input_tokens[i]).c_str()); - } - }*/ - - // keep track how many tokens we can reuse from the previous state - int n_past = 0; - - // empty prompt passed -> release the slot and send empty response - if (input_tokens.empty()) { - SLT_WRN(slot, "%s", "empty prompt - releasing slot\n"); - - slot.print_timings(); - send_final_response(slot); - slot.release(); - - continue; - } - - // TODO: support memory-less logits computation - if (slot.need_logits() && !llama_get_memory(ctx)) { - send_error(slot, "the current context does not logits computation. skipping", ERROR_TYPE_SERVER); - slot.release(); - continue; - } - - if (!slot.can_split()) { - if (slot.task->n_tokens() > n_ubatch) { - send_error(slot, "input is too large to process. increase the physical batch size", ERROR_TYPE_SERVER); - slot.release(); - continue; - } - - if (slot.task->n_tokens() > slot.n_ctx) { - send_error(slot, "input is larger than the max context size. 
skipping", ERROR_TYPE_EXCEED_CONTEXT_SIZE); - slot.release(); - continue; - } - } else { - if (slot.task->n_tokens() >= slot.n_ctx) { - send_error(slot, "the request exceeds the available context size, try increasing it", ERROR_TYPE_EXCEED_CONTEXT_SIZE); - slot.release(); - continue; - } - - if (slot.task->params.cache_prompt) { - // reuse any previously computed tokens that are common with the new prompt - n_past = slot.prompt.tokens.get_common_prefix(input_tokens); - - // if there is an alora invoked, don't cache after the invocation start - if (slot.alora_invocation_start > 0) { - SLT_DBG(slot, "only caching to alora invocation start (n_past = %d, alora_invocation_start = %d)\n", n_past, slot.alora_invocation_start); - n_past = std::min(n_past, slot.alora_invocation_start - 1); - } - - const auto n_cache_reuse = slot.task->params.n_cache_reuse; - - const bool can_cache_reuse = - llama_memory_can_shift(llama_get_memory(ctx)) && - !slot.prompt.tokens.has_mtmd; - - if (!can_cache_reuse && n_cache_reuse > 0) { - SLT_WRN(slot, "cache reuse is not supported - ignoring n_cache_reuse = %d\n", n_cache_reuse); - } - - // reuse chunks from the cached prompt by shifting their KV cache in the new position - if (can_cache_reuse && n_cache_reuse > 0) { - GGML_ASSERT(!slot.prompt.tokens.has_mtmd); - - size_t head_c = n_past; // cache - size_t head_p = n_past; // current prompt - - if (mctx) { - // we should never reach this - GGML_ABORT("not supported by multimodal"); - } - - SLT_DBG(slot, "trying to reuse chunks with size > %d, n_past = %d\n", n_cache_reuse, n_past); - - while (head_c < slot.prompt.tokens.size() && - head_p < input_tokens.size()) { - - size_t n_match = 0; - while (head_c + n_match < slot.prompt.tokens.size() && - head_p + n_match < input_tokens.size() && - slot.prompt.tokens[head_c + n_match] == input_tokens[head_p + n_match]) { - n_match++; - } - - if (n_match >= (size_t) n_cache_reuse) { - SLT_INF(slot, "reusing chunk with size %zu, shifting KV cache [%zu, %zu) -> [%zu, %zu)\n", n_match, head_c, head_c + n_match, head_p, head_p + n_match); - //for (size_t i = head_p; i < head_p + n_match; i++) { - // SLT_DBG(slot, "cache token %3zu: %6d '%s'\n", i, prompt_tokens[i], common_token_to_piece(ctx, prompt_tokens[i]).c_str()); - //} - - const int64_t kv_shift = (int64_t) head_p - (int64_t) head_c; - - llama_memory_seq_rm (llama_get_memory(ctx), slot.id, head_p, head_c); - llama_memory_seq_add(llama_get_memory(ctx), slot.id, head_c, head_c + n_match, kv_shift); - - for (size_t i = 0; i < n_match; i++) { - slot.prompt.tokens.set_token(head_p + i, slot.prompt.tokens[head_c + i]); - n_past++; - } - - head_c += n_match; - head_p += n_match; - } else { - head_c += 1; - } - } - - SLT_DBG(slot, "after context reuse, new n_past = %d\n", n_past); - } - } else { - // if we don't cache the prompt, we have to remove all previous tokens - n_past = 0; - } - - // note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1 - const auto n_swa = std::max(1, llama_model_n_swa(model)); - - // the largest pos_min required for a checkpoint to be useful - const auto pos_min_thold = std::max(0, n_past - n_swa); - - // note: disallow with mtmd contexts for now - // https://github.com/ggml-org/llama.cpp/issues/17043 - if (!mctx && n_past > 0 && n_past < slot.prompt.n_tokens()) { - const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); - if (pos_min == -1) { - SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", n_past, (int) 
slot.prompt.tokens.size(), slot.id, pos_min); - GGML_ABORT("pos_min == -1, but n_past > 0 - should not happen: https://github.com/ggml-org/llama.cpp/pull/13833#discussion_r2116181237"); - } - - // when the prompt prefix does not match, print the tokens around the mismatch - // this is useful for debugging prompt caching - if (slots_debug) { - const int np0 = std::max(n_past - 4, 0); - const int np1 = std::min(n_past + 6, std::min(slot.prompt.tokens.size(), slot.task->tokens.size())); - - std::stringstream ss0; - std::stringstream ss1; - - std::stringstream st0; - std::stringstream st1; - - ss0 << "old: ... "; - ss1 << "new: ... "; - - for (int i = np0; i < np1; i++) { - if (i == n_past) { - ss0 << " | "; - ss1 << " | "; - } - - { - const auto token = slot.prompt.tokens[i]; - const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx, token) : "[mtmd]"; - ss0 << piece; - st0 << std::setw(8) << token; - } - - { - const auto token = slot.task->tokens[i]; - const auto piece = token != LLAMA_TOKEN_NULL ? common_token_to_piece(ctx, token) : "[mtmd]"; - ss1 << piece; - st1 << std::setw(8) << token; - } - } - - SLT_WRN(slot, "%s\n", ss0.str().c_str()); - SLT_WRN(slot, "%s\n", ss1.str().c_str()); - - SLT_WRN(slot, "%s\n", st0.str().c_str()); - SLT_WRN(slot, "%s\n", st1.str().c_str()); - } - - if (pos_min > pos_min_thold) { - // TODO: support can be added in the future when corresponding vision models get released - GGML_ASSERT(!slot.prompt.tokens.has_mtmd); - - SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa); - - // search for a context checkpoint - const auto it = std::find_if( - slot.prompt.checkpoints.rbegin(), - slot.prompt.checkpoints.rend(), - [&](const auto & cur) { - // guarantee that a checkpoint will result in at least one token being processed [TAG_PROMPT_LOGITS] - return cur.pos_min < pos_min_thold; - } - ); - - bool do_reset = it == slot.prompt.checkpoints.rend(); - - if (!do_reset) { - // restore the context checkpoint - const size_t checkpoint_size = it->data.size(); - const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - - if (n != checkpoint_size) { - SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024); - do_reset = true; - //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint"); - } else { - n_past = std::min(n_past, std::max(it->pos_min + 1, it->pos_max)); - SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024); - } - } - - if (do_reset) { - SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n", - "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055"); - n_past = 0; - } - } - } - - { - // erase any checkpoints with pos_min > pos_min_thold - for (auto it = slot.prompt.checkpoints.begin(); it != slot.prompt.checkpoints.end();) { - const auto & cur = *it; - if (cur.pos_min > pos_min_thold) { - SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024); - it = slot.prompt.checkpoints.erase(it); - } else { 
- ++it; - } - } - } - } - - // [TAG_PROMPT_LOGITS] - if (n_past == slot.task->n_tokens() && n_past > 0) { - SLT_WRN(slot, "need to evaluate at least 1 token for each active slot (n_past = %d, task.n_tokens() = %d)\n", n_past, slot.task->n_tokens()); - n_past--; - SLT_WRN(slot, "n_past was set to %d\n", n_past); - } - - slot.n_prompt_tokens_cache = n_past; - slot.n_prompt_tokens_processed = 0; - - slot.prompt.tokens.keep_first(n_past); - } - - if (!slot.can_split()) { - // cannot fit the prompt in the current batch - will try next iter - if (batch.n_tokens + slot.task->n_tokens() > n_batch) { - continue; - } - } - - // truncate any tokens that are beyond n_past for this slot - const llama_pos p0 = slot.prompt.tokens.pos_next(); - - SLT_INF(slot, "n_tokens = %d, memory_seq_rm [%d, end)\n", slot.prompt.n_tokens(), p0); - - if (!llama_memory_seq_rm(llama_get_memory(ctx), slot.id, p0, -1)) { - SLT_WRN(slot, "failed to truncate tokens with position >= %d - clearing the memory\n", p0); - - clear_slot(slot); - - // there is no common part left - slot.n_prompt_tokens_cache = 0; - } - - // check if we should process the image - if (slot.prompt.n_tokens() < slot.task->n_tokens() && input_tokens[slot.prompt.n_tokens()] == LLAMA_TOKEN_NULL) { - // process the image - size_t n_tokens_out = 0; - int32_t res = input_tokens.process_chunk(ctx, mctx, slot.prompt.n_tokens(), slot.prompt.tokens.pos_next(), slot.id, n_tokens_out); - if (res != 0) { - SLT_ERR(slot, "failed to process image, res = %d\n", res); - send_error(slot, "failed to process image", ERROR_TYPE_SERVER); - slot.release(); - continue; - } - - slot.n_prompt_tokens_processed += n_tokens_out; - - // add the image chunk to cache - { - const auto & chunk = input_tokens.find_chunk(slot.prompt.n_tokens()); - slot.prompt.tokens.push_back(chunk.get()); // copy - } - } - - // If using an alora, there may be uncached tokens that come - // before the invocation sequence. When this happens, the - // tokens before the invocation sequence need to be - // processed without the adapter in a separate batch, then - // the adapter needs to be enabled for the remaining tokens. - if (lora_all_alora(slot.lora) && slot.alora_invocation_start - 1 > slot.prompt.n_tokens()) { - SLT_DBG(slot, "processing pre-alora tokens without the adapter (n_tokens = %d, alora_invocation_start = %d)\n", slot.prompt.n_tokens(), slot.alora_invocation_start); - const auto & enabled_loras = lora_get_enabled_ids(slot.lora); - GGML_ASSERT(enabled_loras.size() == 1); - alora_scale = slot.lora[enabled_loras[0]].scale; - slot.lora[enabled_loras[0]].scale = 0.0f; - alora_disabled_id = enabled_loras[0]; - } - - bool do_checkpoint = params_base.n_ctx_checkpoints > 0; - - // make checkpoints only for completion tasks - do_checkpoint = do_checkpoint && slot.task->type == SERVER_TASK_TYPE_COMPLETION; - - // make a checkpoint of the parts of the memory that cannot be rolled back. 
- // checkpoints are created only if: - // - the model uses SWA and we are not using `swa_full` - // - the model architecture is marked as recurrent or hybrid - // - // TODO: try to make this conditional on the context or the memory module, instead of the model type - do_checkpoint = do_checkpoint && ( - llama_model_is_recurrent(model) || - llama_model_is_hybrid(model) || - (llama_model_n_swa(model) > 0 && !params_base.swa_full) - ); - - // add prompt tokens for processing in the current batch - while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.n_tokens < n_batch) { - // get next token to process - llama_token cur_tok = input_tokens[slot.prompt.n_tokens()]; - if (cur_tok == LLAMA_TOKEN_NULL) { - break; // end of text chunk - } - - // if this is an alora request with pre-invocation - // tokens that are not cached, we need to stop filling - // this batch at those pre-invocation tokens. - if (alora_scale > 0 && slot.prompt.n_tokens() == slot.alora_invocation_start - 1) { - SLT_DBG(slot, "stop prompt batch filling at (n_tokens = %d, alora_invocation_start = %d)\n", slot.prompt.n_tokens(), slot.alora_invocation_start); - break; - } - - // embedding requires all tokens in the batch to be output - common_batch_add(batch, - cur_tok, - slot.prompt.tokens.pos_next(), - { slot.id }, - slot.need_embd()); - slot.prompt.tokens.push_back(cur_tok); - - slot.n_prompt_tokens_processed++; - - // process the last few tokens of the prompt separately in order to allow for a checkpoint to be created. - if (do_checkpoint && slot.task->n_tokens() - slot.prompt.n_tokens() == 64) { - break; - } - } - - // SLT_INF(slot, "new slot.prompt.tokens: %s\n", slot.slot.prompt.tokens.str().c_str()); - - SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens()); - - // entire prompt has been processed - if (slot.prompt.n_tokens() == slot.task->n_tokens()) { - slot.state = SLOT_STATE_DONE_PROMPT; - - GGML_ASSERT(batch.n_tokens > 0); - - common_sampler_reset(slot.smpl); - - // Process all prompt tokens through sampler system - for (int i = 0; i < slot.task->n_tokens(); ++i) { - llama_token id = input_tokens[i]; - if (id != LLAMA_TOKEN_NULL) { - common_sampler_accept(slot.smpl, id, false); - } - } - - // extract the logits only for the last token - batch.logits[batch.n_tokens - 1] = true; - - slot.n_decoded = 0; - slot.i_batch = batch.n_tokens - 1; - - SLT_INF(slot, "prompt done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens); - - const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); - const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx), slot.id); - - // no need for empty or small checkpoints - do_checkpoint = do_checkpoint && (pos_min >= 0 && pos_max >= 64); - - // no need to create checkpoints that are too close together - do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64); - - if (do_checkpoint) { - while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) { - // make room for the new checkpoint, if needed - const auto & cur = slot.prompt.checkpoints.front(); - - SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", - cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); - - slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin()); - } - - const size_t 
checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - - auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{ - /*.pos_min = */ pos_min, - /*.pos_max = */ pos_max, - /*.data = */ std::vector(checkpoint_size), - }); - - llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - - SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", - (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); - } - } - } - - if (!slot_batched) { - slot_batched = &slot; - } - - if (batch.n_tokens >= n_batch) { - break; - } - } - } - - if (batch.n_tokens == 0) { - SRV_WRN("%s", "no tokens to decode\n"); - return; - } - - SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens); - - if (slot_batched) { - // apply lora, only need to do it once per batch - common_set_adapter_lora(ctx, slot_batched->lora); - - // if the lora is temporarily disabled for an alora, re-enable it - // for next time - if (alora_scale > 0.0f) { - SRV_DBG("re-enabling alora with scale %f\n", alora_scale); - slot_batched->lora[alora_disabled_id].scale = alora_scale; - } - - llama_set_embeddings(ctx, slot_batched->need_embd()); - } - - int32_t i_next = 0; - - // process the created batch of tokens - for (int32_t i = 0; i < batch.n_tokens; i = i_next) { - const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i); - - llama_batch batch_view = { - n_tokens, - batch.token + i, - nullptr, - batch.pos + i, - batch.n_seq_id + i, - batch.seq_id + i, - batch.logits + i, - }; - - const int ret = llama_decode(ctx, batch_view); - - metrics.on_decoded(slots); - - if (ret != 0) { - { - std::string err; - - if (n_batch == 1 && ret == 1) { - // TODO: try to terminate only the largest active slot/sequence and continue with the rest - // need to remove the tokens from the current batch too - err = "Context size has been exceeded."; - } - - if (ret == -1) { - err = "Invalid input batch."; - } - - if (ret < -1) { - // TODO: update slot state based on llama_memory_seq_pos_min() and llama_memory_seq_pos_max() - err = "Compute error."; - } - - // TODO: handle ret == 2 (abort) when we start aborting - - if (!err.empty()) { - SRV_ERR("%s i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret); - - for (auto & slot : slots) { - if (slot.is_processing()) { - send_error(slot, err); - slot.release(); - - // note: it's complicated to keep track of how much of the current batch has been - // processed before the error occurred, so we simply clear the entire context - clear_slot(slot); - } - } - - break; - } - } - - // retry with half the batch size to try to find a free slot in the KV cache - if (!try_clear_idle_slots()) { - n_batch /= 2; - } - - SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret); - - continue; // continue loop of n_batch - } - - // move the head of the batch forward with the number of tokens we just processed - i_next = i + n_tokens; - - // on successful decode, restore the original batch size - n_batch = llama_n_batch(ctx); - - // technically, measuring the time here excludes the sampling time for the last batch - // but on the other hand, we don't want to do too many system calls to measure the time, so it's ok - const int64_t t_current = ggml_time_us(); - - for (auto & slot : slots) { - // may need to 
copy state to other slots - if (slot.state == SLOT_STATE_DONE_PROMPT && slot.is_parent()) { - std::vector child_slots; - for (auto & other : slots) { - if (other.state == SLOT_STATE_WAIT_OTHER && slot.task->id == other.task->id_parent) { - child_slots.push_back(&other); - } - } - - // we can only proceed if all child slots are having the correct tasks - if (child_slots.size() == slot.task->n_children) { - // copy state to the child slots - for (auto & child : child_slots) { - SLT_INF(slot, "copying state to child %d\n", child->id); - slot.copy_state_to(*child); - child->state = SLOT_STATE_DONE_PROMPT; - } - } - } - - // optionally send prompt processing progress - if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) { - if (slot.task->params.stream && slot.task->params.return_progress) { - send_partial_response(slot, {}, true); - } - } - - if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) { - continue; // continue loop of slots - } - - if (slot.state == SLOT_STATE_DONE_PROMPT) { - if (slot.task->type == SERVER_TASK_TYPE_EMBEDDING) { - // prompt evaluated for embedding - send_embedding(slot, batch_view); - slot.release(); - slot.i_batch = -1; - continue; // continue loop of slots - } - - if (slot.task->type == SERVER_TASK_TYPE_RERANK) { - send_rerank(slot, batch_view); - slot.release(); - slot.i_batch = -1; - continue; // continue loop of slots - } - - // prompt evaluated for next-token prediction - slot.state = SLOT_STATE_GENERATING; - } else if (slot.state != SLOT_STATE_GENERATING) { - continue; // continue loop of slots - } - - if (slot.i_batch_dft.size() > 0) { - continue; // sample using speculative decoding - } - - const int tok_idx = slot.i_batch - i; - - llama_token id = common_sampler_sample(slot.smpl, ctx, tok_idx); - - slot.i_batch = -1; - - common_sampler_accept(slot.smpl, id, true); - - slot.n_decoded += 1; - - if (slot.n_decoded == 1) { - slot.t_start_generation = t_current; - slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3; - metrics.on_prompt_eval(slot); - } - - slot.t_token_generation = std::max(1, t_current - slot.t_start_generation) / 1e3; - - completion_token_output result; - result.tok = id; - result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok)); - result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs - - if (slot.task->params.sampling.n_probs > 0) { - populate_token_probs(slot, result, slot.task->params.post_sampling_probs, params_base.special, tok_idx); - } - - if (!process_token(result, slot)) { - // release slot because of stop condition - slot.print_timings(); - send_final_response(slot); - metrics.on_prediction(slot); - slot.release(); - - continue; - } - } - - // speculative decoding - main model sample and accept - for (auto & slot : slots) { - if (slot.state != SLOT_STATE_GENERATING || slot.i_batch_dft.empty()) { - continue; - } - - size_t n_draft = slot.drafted.size(); - - // the accepted tokens from the speculation - const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, slot.i_batch_dft, slot.drafted); - slot.i_batch_dft.clear(); - slot.drafted.clear(); - - slot.n_decoded += ids.size(); - - slot.t_token_generation = std::max(1, t_current - slot.t_start_generation) / 1e3; - - // update how many tokens out of those tested were accepted - slot.n_draft_accepted += ids.size() - 1; - - // rollback to the state before sampling the draft tokens - 
slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft); - - // add accepted tokens to the prompt - slot.prompt.tokens.insert({ids.begin(), ids.end() - 1}); - slot.sampled = ids.back(); // last accepted token - - llama_memory_seq_rm(llama_get_memory(ctx), slot.id, slot.prompt.n_tokens(), -1); - - for (size_t i = 0; i < ids.size(); ++i) { - completion_token_output result; - - result.tok = ids[i]; - result.text_to_send = common_token_to_piece(ctx, result.tok, accept_special_token(slot, result.tok)); - result.prob = 1.0f; // set later - - // TODO: set result.probs - - if (!process_token(result, slot)) { - slot.print_timings(); - send_final_response(slot); - metrics.on_prediction(slot); - slot.release(); - - break; - } - } - - SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) slot.drafted.size(), slot.prompt.n_tokens()); - } - } - - SRV_DBG("%s", "run slots completed\n"); - } - - json model_meta() const { - return json { - {"vocab_type", llama_vocab_type (vocab)}, - {"n_vocab", llama_vocab_n_tokens (vocab)}, - {"n_ctx_train", llama_model_n_ctx_train(model)}, - {"n_embd", llama_model_n_embd (model)}, - {"n_params", llama_model_n_params (model)}, - {"size", llama_model_size (model)}, - }; - } - - int get_slot_n_ctx() { - return slots.back().n_ctx; - } - - server_response_reader get_response_reader() { - return server_response_reader(queue_tasks, queue_results, HTTP_POLLING_SECONDS); - } -}; - -// -// server_context (public API) -// - -server_context::server_context() : impl(new server_context_impl()) {} -server_context::~server_context() = default; - -void server_context::init() { - impl->init(); -} - -bool server_context::load_model(const common_params & params) { - return impl->load_model(params); -} - -void server_context::start_loop() { - impl->queue_tasks.start_loop(); -} - -void server_context::terminate() { - impl->queue_tasks.terminate(); -} - -llama_context * server_context::get_llama_context() const { - return impl->ctx; -} - -server_response_reader server_context::get_response_reader() { - return impl->get_response_reader(); -} - -server_context_info server_context::get_info() const { - return server_context_info { - /* build_info */ build_info, - /* model_name */ impl->model_name, - /* has_inp_image */ impl->oai_parser_opt.allow_image, - /* has_inp_audio */ impl->oai_parser_opt.allow_audio, - }; -} - - - -// generator-like API for HTTP response generation -struct server_res_generator : server_http_res { - server_response_reader rd; - server_res_generator(server_context_impl & ctx_server) - : rd(ctx_server.queue_tasks, ctx_server.queue_results, HTTP_POLLING_SECONDS) {} - void ok(const json & response_data) { - status = 200; - data = safe_json_to_str(response_data); - } - void error(const json & error_data) { - status = json_value(error_data, "code", 500); - data = safe_json_to_str({{ "error", error_data }}); - } -}; - - - -// -// server_routes -// - -static std::unique_ptr handle_completions_impl( - server_context_impl & ctx_server, - server_task_type type, - const json & data, - const std::vector & files, - const std::function & should_stop, - task_response_type res_type) { - GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL); - - auto res = std::make_unique(ctx_server); - auto completion_id = gen_chatcmplid(); - auto & rd = res->rd; - - try { - std::vector tasks; - - const auto & prompt = data.at("prompt"); - // TODO: this log can become very long, put it behind a flag or think about a more compact 
format - //SRV_DBG("Prompt: %s\n", prompt.is_string() ? prompt.get().c_str() : prompt.dump(2).c_str()); - - // process prompt - std::vector inputs; - - if (res_type != TASK_RESPONSE_TYPE_NONE && ctx_server.mctx != nullptr) { - // This is the case used by OAI compatible chat path with MTMD. TODO It can be moved to the path below. - inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get(), files)); - } else { - // Everything else, including multimodal completions. - inputs = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true); - } - tasks.reserve(inputs.size()); - int idx = 0; - for (size_t i = 0; i < inputs.size(); i++) { - server_task task = server_task(type); - - task.id = ctx_server.queue_tasks.get_new_id(); - task.index = idx++; - - task.tokens = std::move(inputs[i]); - task.params = server_task::params_from_json_cmpl( - ctx_server.ctx, - ctx_server.params_base, - data); - task.id_slot = json_value(data, "id_slot", -1); - - // OAI-compat - task.params.res_type = res_type; - task.params.oaicompat_cmpl_id = completion_id; - task.params.oaicompat_model = ctx_server.model_name; - - if (task.params.n_cmpl > 1) { - task.n_children = task.params.n_cmpl - 1; - for (size_t j = 0; j < task.n_children; j++) { - server_task child = task.create_child( - task.id, - ctx_server.queue_tasks.get_new_id(), - idx++); - tasks.push_back(std::move(child)); - } - } - - tasks.push_back(std::move(task)); - } - - rd.post_tasks(std::move(tasks)); - } catch (const std::exception & e) { - res->error(format_error_response(e.what(), ERROR_TYPE_INVALID_REQUEST)); - return res; - } - - bool stream = json_value(data, "stream", false); - - if (!stream) { - // non-stream, wait for the results - auto all_results = rd.wait_for_all(should_stop); - if (all_results.is_terminated) { - return res; // connection is closed - } else if (all_results.error) { - res->error(all_results.error->to_json()); - return res; - } else { - json arr = json::array(); - for (auto & res : all_results.results) { - GGML_ASSERT(dynamic_cast(res.get()) != nullptr); - arr.push_back(res->to_json()); - } - GGML_ASSERT(!arr.empty() && "empty results"); - if (arr.size() == 1) { - // if single request, return single object instead of array - res->ok(arr[0]); - } else if (res_type == TASK_RESPONSE_TYPE_OAI_CHAT || res_type == TASK_RESPONSE_TYPE_OAI_CMPL) { - // if multiple results in OAI format, we need to re-format them - json & choices = arr[0]["choices"]; - for (size_t i = 1; i < arr.size(); i++) { - choices.push_back(std::move(arr[i]["choices"][0])); - } - res->ok(arr[0]); - } else { - // multi-results, non-OAI compat - res->ok(arr); - } - } - } else { - // in streaming mode, the first error must be treated as non-stream response - // this is to match the OAI API behavior - // ref: https://github.com/ggml-org/llama.cpp/pull/16486#discussion_r2419657309 - server_task_result_ptr first_result = rd.next(should_stop); - if (first_result == nullptr) { - return res; // connection is closed - } else if (first_result->is_error()) { - res->error(first_result->to_json()); - return res; - } else { - GGML_ASSERT( - dynamic_cast(first_result.get()) != nullptr - || dynamic_cast(first_result.get()) != nullptr - ); - } - - // next responses are streamed - // to be sent immediately - json first_result_json = first_result->to_json(); - if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) { - res->data = format_anthropic_sse(first_result_json); - } else { - res->data = format_oai_sse(first_result_json); - } - res->status = 200; - res->content_type 
= "text/event-stream"; - res->next = [res_this = res.get(), res_type, &should_stop](std::string & output) -> bool { - static auto format_error = [](task_response_type res_type, const json & res_json) { - if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) { - return format_anthropic_sse({ - {"event", "error"}, - {"data", res_json}, - }); - } else { - return format_oai_sse(json {{ "error", res_json }}); - } - }; - - try { - if (should_stop()) { - SRV_DBG("%s", "stopping streaming due to should_stop condition\n"); - return false; // should_stop condition met - } - - if (!res_this->data.empty()) { - // flush the first chunk - output = std::move(res_this->data); - res_this->data.clear(); - return true; - } - - server_response_reader & rd = res_this->rd; - - // check if there is more data - if (!rd.has_next()) { - if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) { - // Anthropic doesn't send [DONE], message_stop was already sent - output = ""; - } else if (res_type != TASK_RESPONSE_TYPE_NONE) { - output = "data: [DONE]\n\n"; - } else { - output = ""; - } - SRV_DBG("%s", "all results received, terminating stream\n"); - return false; // no more data, terminate - } - - // receive subsequent results - auto result = rd.next(should_stop); - if (result == nullptr) { - SRV_DBG("%s", "stopping streaming due to should_stop condition\n"); - return false; // should_stop condition met - } - - // send the results - if (result->is_error()) { - json res_json = result->to_json(); - output = format_error(res_type, res_json); - SRV_DBG("%s", "error received during streaming, terminating stream\n"); - return false; // terminate on error - } else { - GGML_ASSERT( - dynamic_cast(result.get()) != nullptr - || dynamic_cast(result.get()) != nullptr - ); - json res_json = result->to_json(); - if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) { - output = format_anthropic_sse(res_json); - } else { - output = format_oai_sse(res_json); - } - } - - // has next data, continue - return true; - - } catch (const std::exception & e) { - json error_json = format_error_response(e.what(), ERROR_TYPE_SERVER); - output = format_error(res_type, error_json); - - // terminate on exception - return false; - } - }; - } - - return res; -} - -void server_routes::init_routes() { - this->get_health = [this](const server_http_req &) { - // error and loading states are handled by middleware - auto res = std::make_unique(ctx_server); - res->ok({{"status", "ok"}}); - return res; - }; - - this->get_metrics = [this](const server_http_req &) { - auto res = std::make_unique(ctx_server); - if (!params.endpoint_metrics) { - res->error(format_error_response("This server does not support metrics endpoint. 
Start it with `--metrics`", ERROR_TYPE_NOT_SUPPORTED)); - return res; - } - - // request slots data using task queue - // TODO: use server_response_reader - int task_id = ctx_server.queue_tasks.get_new_id(); - { - server_task task(SERVER_TASK_TYPE_METRICS); - task.id = task_id; - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task), true); // high-priority task - } - - // get the result - server_task_result_ptr result = ctx_server.queue_results.recv(task_id); - ctx_server.queue_results.remove_waiting_task_id(task_id); - - if (result->is_error()) { - res->error(result->to_json()); - return res; - } - - // TODO: get rid of this dynamic_cast - auto res_task = dynamic_cast(result.get()); - GGML_ASSERT(res_task != nullptr); - - // metrics definition: https://prometheus.io/docs/practices/naming/#metric-names - json all_metrics_def = json { - {"counter", {{ - {"name", "prompt_tokens_total"}, - {"help", "Number of prompt tokens processed."}, - {"value", (uint64_t) res_task->n_prompt_tokens_processed_total} - }, { - {"name", "prompt_seconds_total"}, - {"help", "Prompt process time"}, - {"value", (uint64_t) res_task->t_prompt_processing_total / 1.e3} - }, { - {"name", "tokens_predicted_total"}, - {"help", "Number of generation tokens processed."}, - {"value", (uint64_t) res_task->n_tokens_predicted_total} - }, { - {"name", "tokens_predicted_seconds_total"}, - {"help", "Predict process time"}, - {"value", (uint64_t) res_task->t_tokens_generation_total / 1.e3} - }, { - {"name", "n_decode_total"}, - {"help", "Total number of llama_decode() calls"}, - {"value", res_task->n_decode_total} - }, { - {"name", "n_tokens_max"}, - {"help", "Largest observed n_tokens."}, - {"value", res_task->n_tokens_max} - }, { - {"name", "n_busy_slots_per_decode"}, - {"help", "Average number of busy slots per llama_decode() call"}, - {"value", (float) res_task->n_busy_slots_total / std::max((float) res_task->n_decode_total, 1.f)} - }}}, - {"gauge", {{ - {"name", "prompt_tokens_seconds"}, - {"help", "Average prompt throughput in tokens/s."}, - {"value", res_task->n_prompt_tokens_processed ? 1.e3 / res_task->t_prompt_processing * res_task->n_prompt_tokens_processed : 0.} - },{ - {"name", "predicted_tokens_seconds"}, - {"help", "Average generation throughput in tokens/s."}, - {"value", res_task->n_tokens_predicted ? 
1.e3 / res_task->t_tokens_generation * res_task->n_tokens_predicted : 0.} - },{ - {"name", "requests_processing"}, - {"help", "Number of requests processing."}, - {"value", (uint64_t) res_task->n_processing_slots} - },{ - {"name", "requests_deferred"}, - {"help", "Number of requests deferred."}, - {"value", (uint64_t) res_task->n_tasks_deferred} - }}} - }; - - std::stringstream prometheus; - - for (const auto & el : all_metrics_def.items()) { - const auto & type = el.key(); - const auto & metrics_def = el.value(); - - for (const auto & metric_def : metrics_def) { - const std::string name = metric_def.at("name"); - const std::string help = metric_def.at("help"); - - auto value = json_value(metric_def, "value", 0.); - prometheus << "# HELP llamacpp:" << name << " " << help << "\n" - << "# TYPE llamacpp:" << name << " " << type << "\n" - << "llamacpp:" << name << " " << value << "\n"; - } - } - - res->headers["Process-Start-Time-Unix"] = std::to_string(res_task->t_start); - res->content_type = "text/plain; version=0.0.4"; - res->status = 200; - res->data = prometheus.str(); - return res; - }; - - this->get_slots = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); - if (!params.endpoint_slots) { - res->error(format_error_response("This server does not support slots endpoint. Start it with `--slots`", ERROR_TYPE_NOT_SUPPORTED)); - return res; - } - - // request slots data using task queue - int task_id = ctx_server.queue_tasks.get_new_id(); - { - server_task task(SERVER_TASK_TYPE_METRICS); - task.id = task_id; - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task), true); // high-priority task - } - - // get the result - server_task_result_ptr result = ctx_server.queue_results.recv(task_id); - ctx_server.queue_results.remove_waiting_task_id(task_id); - - if (result->is_error()) { - res->error(result->to_json()); - return res; - } - - // TODO: get rid of this dynamic_cast - auto res_task = dynamic_cast(result.get()); - GGML_ASSERT(res_task != nullptr); - - // optionally return "fail_on_no_slot" error - if (!req.get_param("fail_on_no_slot").empty()) { - if (res_task->n_idle_slots == 0) { - res->error(format_error_response("no slot available", ERROR_TYPE_UNAVAILABLE)); - return res; - } - } - - res->ok(res_task->slots_data); - return res; - }; - - this->post_slots = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); - if (params.slot_save_path.empty()) { - res->error(format_error_response("This server does not support slots action. 
Start it with `--slot-save-path`", ERROR_TYPE_NOT_SUPPORTED)); - return res; - } - - std::string id_slot_str = req.get_param("id_slot"); - int id_slot; - - try { - id_slot = std::stoi(id_slot_str); - } catch (const std::exception &) { - res->error(format_error_response("Invalid slot ID", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - - std::string action = req.get_param("action"); - - if (action == "save") { - return handle_slots_save(req, id_slot); - } else if (action == "restore") { - return handle_slots_restore(req, id_slot); - } else if (action == "erase") { - return handle_slots_erase(req, id_slot); - } else { - res->error(format_error_response("Invalid action", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - }; - - this->get_props = [this](const server_http_req &) { - auto res = std::make_unique(ctx_server); - json default_generation_settings_for_props; - - { - task_params params; - - params.sampling = ctx_server.params_base.sampling; - - default_generation_settings_for_props = json { - {"params", params.to_json(true)}, - {"n_ctx", ctx_server.get_slot_n_ctx()}, - }; - } - - // this endpoint is publicly available, please only return what is safe to be exposed - json data = { - { "default_generation_settings", default_generation_settings_for_props }, - { "total_slots", ctx_server.params_base.n_parallel }, - { "model_alias", ctx_server.model_name }, - { "model_path", ctx_server.params_base.model.path }, - { "modalities", json { - {"vision", ctx_server.oai_parser_opt.allow_image}, - {"audio", ctx_server.oai_parser_opt.allow_audio}, - } }, - { "endpoint_slots", params.endpoint_slots }, - { "endpoint_props", params.endpoint_props }, - { "endpoint_metrics", params.endpoint_metrics }, - { "webui", params.webui }, - { "chat_template", common_chat_templates_source(ctx_server.chat_templates.get()) }, - { "bos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_bos(ctx_server.vocab), /* special= */ true)}, - { "eos_token", common_token_to_piece(ctx_server.ctx, llama_vocab_eos(ctx_server.vocab), /* special= */ true)}, - { "build_info", build_info }, - }; - if (ctx_server.params_base.use_jinja) { - if (auto tool_use_src = common_chat_templates_source(ctx_server.chat_templates.get(), "tool_use")) { - data["chat_template_tool_use"] = tool_use_src; - } - } - - res->ok(data); - return res; - }; - - this->post_props = [this](const server_http_req &) { - auto res = std::make_unique(ctx_server); - if (!params.endpoint_props) { - res->error(format_error_response("This server does not support changing global properties. Start it with `--props`", ERROR_TYPE_NOT_SUPPORTED)); - return res; - } - // update any props here - - res->ok({{ "success", true }}); - return res; - }; - - this->get_api_show = [this](const server_http_req &) { - auto res = std::make_unique(ctx_server); - bool has_mtmd = ctx_server.mctx != nullptr; - json data = { - { - "template", common_chat_templates_source(ctx_server.chat_templates.get()), - }, - { - "model_info", { - { "llama.context_length", ctx_server.get_slot_n_ctx() }, - } - }, - {"modelfile", ""}, - {"parameters", ""}, - {"template", common_chat_templates_source(ctx_server.chat_templates.get())}, - {"details", { - {"parent_model", ""}, - {"format", "gguf"}, - {"family", ""}, - {"families", {""}}, - {"parameter_size", ""}, - {"quantization_level", ""} - }}, - {"model_info", ""}, - {"capabilities", has_mtmd ? 
json({"completion","multimodal"}) : json({"completion"})} - }; - - res->ok(data); - return res; - }; - - this->post_infill = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); - // check model compatibility - std::string err; - if (llama_vocab_fim_pre(ctx_server.vocab) == LLAMA_TOKEN_NULL) { - err += "prefix token is missing. "; - } - if (llama_vocab_fim_suf(ctx_server.vocab) == LLAMA_TOKEN_NULL) { - err += "suffix token is missing. "; - } - if (llama_vocab_fim_mid(ctx_server.vocab) == LLAMA_TOKEN_NULL) { - err += "middle token is missing. "; - } - if (!err.empty()) { - res->error(format_error_response(string_format("Infill is not supported by this model: %s", err.c_str()), ERROR_TYPE_NOT_SUPPORTED)); - return res; - } - - // validate input - json data = json::parse(req.body); - if (data.contains("prompt") && !data.at("prompt").is_string()) { - // prompt is optional - res->error(format_error_response("\"prompt\" must be a string", ERROR_TYPE_INVALID_REQUEST)); - } - - if (!data.contains("input_prefix")) { - res->error(format_error_response("\"input_prefix\" is required", ERROR_TYPE_INVALID_REQUEST)); - } - - if (!data.contains("input_suffix")) { - res->error(format_error_response("\"input_suffix\" is required", ERROR_TYPE_INVALID_REQUEST)); - } - - if (data.contains("input_extra") && !data.at("input_extra").is_array()) { - // input_extra is optional - res->error(format_error_response("\"input_extra\" must be an array of {\"filename\": string, \"text\": string}", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - - json input_extra = json_value(data, "input_extra", json::array()); - for (const auto & chunk : input_extra) { - // { "text": string, "filename": string } - if (!chunk.contains("text") || !chunk.at("text").is_string()) { - res->error(format_error_response("extra_context chunk must contain a \"text\" field with a string value", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - // filename is optional - if (chunk.contains("filename") && !chunk.at("filename").is_string()) { - res->error(format_error_response("extra_context chunk's \"filename\" field must be a string", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - } - data["input_extra"] = input_extra; // default to empty array if it's not exist - - std::string prompt = json_value(data, "prompt", std::string()); - std::vector tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, false, true); - SRV_DBG("creating infill tasks, n_prompts = %d\n", (int) tokenized_prompts.size()); - data["prompt"] = format_prompt_infill( - ctx_server.vocab, - data.at("input_prefix"), - data.at("input_suffix"), - data.at("input_extra"), - ctx_server.params_base.n_batch, - ctx_server.params_base.n_predict, - ctx_server.get_slot_n_ctx(), - ctx_server.params_base.spm_infill, - tokenized_prompts[0].get_text_tokens() // TODO: this could maybe be multimodal. 
- ); - - std::vector files; // dummy - return handle_completions_impl( - ctx_server, - SERVER_TASK_TYPE_INFILL, - data, - files, - req.should_stop, - TASK_RESPONSE_TYPE_NONE); // infill is not OAI compatible - }; - - this->post_completions = [this](const server_http_req & req) { - std::vector files; // dummy - const json body = json::parse(req.body); - return handle_completions_impl( - ctx_server, - SERVER_TASK_TYPE_COMPLETION, - body, - files, - req.should_stop, - TASK_RESPONSE_TYPE_NONE); - }; - - this->post_completions_oai = [this](const server_http_req & req) { - std::vector files; // dummy - const json body = json::parse(req.body); - return handle_completions_impl( - ctx_server, - SERVER_TASK_TYPE_COMPLETION, - body, - files, - req.should_stop, - TASK_RESPONSE_TYPE_OAI_CMPL); - }; - - this->post_chat_completions = [this](const server_http_req & req) { - std::vector files; - json body = json::parse(req.body); - json body_parsed = oaicompat_chat_params_parse( - body, - ctx_server.oai_parser_opt, - files); - return handle_completions_impl( - ctx_server, - SERVER_TASK_TYPE_COMPLETION, - body_parsed, - files, - req.should_stop, - TASK_RESPONSE_TYPE_OAI_CHAT); - }; - - this->post_anthropic_messages = [this](const server_http_req & req) { - std::vector files; - json body = convert_anthropic_to_oai(json::parse(req.body)); - json body_parsed = oaicompat_chat_params_parse( - body, - ctx_server.oai_parser_opt, - files); - return handle_completions_impl( - ctx_server, - SERVER_TASK_TYPE_COMPLETION, - body_parsed, - files, - req.should_stop, - TASK_RESPONSE_TYPE_ANTHROPIC); - }; - - this->post_anthropic_count_tokens = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); - std::vector files; - json body = convert_anthropic_to_oai(json::parse(req.body)); - json body_parsed = oaicompat_chat_params_parse( - body, - ctx_server.oai_parser_opt, - files); - - json prompt = body_parsed.at("prompt"); - llama_tokens tokens = tokenize_mixed(ctx_server.vocab, prompt, true, true); - - res->ok({{"input_tokens", static_cast(tokens.size())}}); - return res; - }; - - // same with handle_chat_completions, but without inference part - this->post_apply_template = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); - std::vector files; // dummy, unused - json body = json::parse(req.body); - json data = oaicompat_chat_params_parse( - body, - ctx_server.oai_parser_opt, - files); - res->ok({{ "prompt", std::move(data.at("prompt")) }}); - return res; - }; - - this->get_models = [this](const server_http_req &) { - auto res = std::make_unique(ctx_server); - json model_meta = nullptr; - if (is_ready()) { - model_meta = ctx_server.model_meta(); - } - bool has_mtmd = ctx_server.mctx != nullptr; - json models = { - {"models", { - { - {"name", ctx_server.model_name}, - {"model", ctx_server.model_name}, - {"modified_at", ""}, - {"size", ""}, - {"digest", ""}, // dummy value, llama.cpp does not support managing model file's hash - {"type", "model"}, - {"description", ""}, - {"tags", {""}}, - {"capabilities", has_mtmd ? 
json({"completion","multimodal"}) : json({"completion"})}, - {"parameters", ""}, - {"details", { - {"parent_model", ""}, - {"format", "gguf"}, - {"family", ""}, - {"families", {""}}, - {"parameter_size", ""}, - {"quantization_level", ""} - }} - } - }}, - {"object", "list"}, - {"data", { - { - {"id", ctx_server.model_name}, - {"object", "model"}, - {"created", std::time(0)}, - {"owned_by", "llamacpp"}, - {"meta", model_meta}, - }, - }} - }; - - res->ok(models); - return res; - }; - - this->post_tokenize = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); - const json body = json::parse(req.body); - json tokens_response = json::array(); - if (body.count("content") != 0) { - const bool add_special = json_value(body, "add_special", false); - const bool parse_special = json_value(body, "parse_special", true); - const bool with_pieces = json_value(body, "with_pieces", false); - - llama_tokens tokens = tokenize_mixed(ctx_server.vocab, body.at("content"), add_special, parse_special); - - if (with_pieces) { - for (const auto& token : tokens) { - std::string piece = common_token_to_piece(ctx_server.ctx, token); - json piece_json; - - // Check if the piece is valid UTF-8 - if (is_valid_utf8(piece)) { - piece_json = piece; - } else { - // If not valid UTF-8, store as array of byte values - piece_json = json::array(); - for (unsigned char c : piece) { - piece_json.push_back(static_cast(c)); - } - } - - tokens_response.push_back({ - {"id", token}, - {"piece", piece_json} - }); - } - } else { - tokens_response = tokens; - } - } - - res->ok(json{{"tokens", std::move(tokens_response)}}); - return res; - }; - - this->post_detokenize = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); - const json body = json::parse(req.body); - - std::string content; - if (body.count("tokens") != 0) { - const llama_tokens tokens = body.at("tokens"); - content = tokens_to_str(ctx_server.ctx, tokens); - } - - res->ok(json{{"content", std::move(content)}}); - return res; - }; - - this->post_embeddings = [this](const server_http_req & req) { - return handle_embeddings_impl(req, TASK_RESPONSE_TYPE_NONE); - }; - - this->post_embeddings_oai = [this](const server_http_req & req) { - return handle_embeddings_impl(req, TASK_RESPONSE_TYPE_OAI_EMBD); - }; - - this->post_rerank = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); - if (!ctx_server.params_base.embedding || ctx_server.params_base.pooling_type != LLAMA_POOLING_TYPE_RANK) { - res->error(format_error_response("This server does not support reranking. 
Start it with `--reranking`", ERROR_TYPE_NOT_SUPPORTED)); - return res; - } - - const json body = json::parse(req.body); - - // if true, use TEI API format, otherwise use Jina API format - // Jina: https://jina.ai/reranker/ - // TEI: https://huggingface.github.io/text-embeddings-inference/#/Text%20Embeddings%20Inference/rerank - bool is_tei_format = body.contains("texts"); - - json query; - if (body.count("query") == 1) { - query = body.at("query"); - if (!query.is_string()) { - res->error(format_error_response("\"query\" must be a string", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - } else { - res->error(format_error_response("\"query\" must be provided", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - - std::vector documents = json_value(body, "documents", - json_value(body, "texts", std::vector())); - if (documents.empty()) { - res->error(format_error_response("\"documents\" must be a non-empty string array", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - - int top_n = json_value(body, "top_n", (int)documents.size()); - - // create and queue the task - json responses = json::array(); - server_response_reader rd = ctx_server.get_response_reader(); - { - std::vector tasks; - tasks.reserve(documents.size()); - for (size_t i = 0; i < documents.size(); i++) { - auto tmp = format_prompt_rerank(ctx_server.model, ctx_server.vocab, ctx_server.mctx, query, documents[i]); - server_task task = server_task(SERVER_TASK_TYPE_RERANK); - task.id = ctx_server.queue_tasks.get_new_id(); - task.index = i; - task.tokens = std::move(tmp); - tasks.push_back(std::move(task)); - } - rd.post_tasks(std::move(tasks)); - } - - // wait for the results - auto all_results = rd.wait_for_all(req.should_stop); - - // collect results - if (all_results.is_terminated) { - return res; // connection is closed - } else if (all_results.error) { - res->error(all_results.error->to_json()); - return res; - } else { - for (auto & res : all_results.results) { - GGML_ASSERT(dynamic_cast(res.get()) != nullptr); - responses.push_back(res->to_json()); - } - } - - // write JSON response - json root = format_response_rerank( - body, - ctx_server.model_name, - responses, - is_tei_format, - documents, - top_n); - - res->ok(root); - return res; - }; - - this->get_lora_adapters = [this](const server_http_req &) { - auto res = std::make_unique(ctx_server); - json result = json::array(); - const auto & loras = ctx_server.params_base.lora_adapters; - for (size_t i = 0; i < loras.size(); ++i) { - auto & lora = loras[i]; - json entry = { - {"id", i}, - {"path", lora.path}, - {"scale", lora.scale}, - {"task_name", lora.task_name}, - {"prompt_prefix", lora.prompt_prefix}, - }; - std::string alora_invocation_string = ""; - const uint64_t n_alora_tokens = llama_adapter_get_alora_n_invocation_tokens(lora.ptr); - std::vector alora_invocation_tokens; - if (n_alora_tokens) { - const llama_token * alora_tokens = llama_adapter_get_alora_invocation_tokens(lora.ptr); - for (uint64_t i = 0; i < n_alora_tokens; ++i) { - alora_invocation_string += common_token_to_piece(ctx_server.ctx, alora_tokens[i]); - alora_invocation_tokens.push_back(alora_tokens[i]); - } - entry["alora_invocation_string"] = alora_invocation_string; - entry["alora_invocation_tokens"] = alora_invocation_tokens; - } - result.push_back(std::move(entry)); - } - res->ok(result); - return res; - }; - - this->post_lora_adapters = [this](const server_http_req & req) { - auto res = std::make_unique(ctx_server); - const json body = json::parse(req.body); - if (!body.is_array()) { - 
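            // note: the request body is expected to be a JSON array of adapter updates (each entry
            // pairing an adapter id with a new scale); parse_lora_request() below maps that array onto
            // params_base.lora_adapters before the SET_LORA task is queued.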
res->error(format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - - int task_id = ctx_server.queue_tasks.get_new_id(); - { - server_task task(SERVER_TASK_TYPE_SET_LORA); - task.id = task_id; - task.set_lora = parse_lora_request(ctx_server.params_base.lora_adapters, body); - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task)); - } - - // get the result - server_task_result_ptr result = ctx_server.queue_results.recv(task_id); - ctx_server.queue_results.remove_waiting_task_id(task_id); - - if (result->is_error()) { - res->error(result->to_json()); - return res; - } - - GGML_ASSERT(dynamic_cast(result.get()) != nullptr); - res->ok(result->to_json()); - return res; - }; -} - -std::unique_ptr server_routes::handle_slots_save(const server_http_req & req, int id_slot) { - auto res = std::make_unique(ctx_server); - const json request_data = json::parse(req.body); - std::string filename = request_data.at("filename"); - if (!fs_validate_filename(filename)) { - res->error(format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - std::string filepath = params.slot_save_path + filename; - - int task_id = ctx_server.queue_tasks.get_new_id(); - { - server_task task(SERVER_TASK_TYPE_SLOT_SAVE); - task.id = task_id; - task.slot_action.slot_id = id_slot; - task.slot_action.filename = filename; - task.slot_action.filepath = filepath; - - // TODO: use server_response_reader - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task)); - } - - server_task_result_ptr result = ctx_server.queue_results.recv(task_id); - ctx_server.queue_results.remove_waiting_task_id(task_id); - - if (result->is_error()) { - res->error(result->to_json()); - return res; - } - - res->ok(result->to_json()); - return res; -} - -std::unique_ptr server_routes::handle_slots_restore(const server_http_req & req, int id_slot) { - auto res = std::make_unique(ctx_server); - const json request_data = json::parse(req.body); - std::string filename = request_data.at("filename"); - if (!fs_validate_filename(filename)) { - res->error(format_error_response("Invalid filename", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - std::string filepath = params.slot_save_path + filename; - - int task_id = ctx_server.queue_tasks.get_new_id(); - { - server_task task(SERVER_TASK_TYPE_SLOT_RESTORE); - task.id = task_id; - task.slot_action.slot_id = id_slot; - task.slot_action.filename = filename; - task.slot_action.filepath = filepath; - - // TODO: use server_response_reader - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task)); - } - - server_task_result_ptr result = ctx_server.queue_results.recv(task_id); - ctx_server.queue_results.remove_waiting_task_id(task_id); - - if (result->is_error()) { - res->error(result->to_json()); - return res; - } - - GGML_ASSERT(dynamic_cast(result.get()) != nullptr); - res->ok(result->to_json()); - return res; -} - -std::unique_ptr server_routes::handle_slots_erase(const server_http_req &, int id_slot) { - auto res = std::make_unique(ctx_server); - int task_id = ctx_server.queue_tasks.get_new_id(); - { - server_task task(SERVER_TASK_TYPE_SLOT_ERASE); - task.id = task_id; - task.slot_action.slot_id = id_slot; - - // TODO: use server_response_reader - ctx_server.queue_results.add_waiting_task_id(task_id); - ctx_server.queue_tasks.post(std::move(task)); - } - - server_task_result_ptr result = 
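        // recv() blocks this HTTP worker thread until the SLOT_ERASE task posted above has been
        // processed by the main loop; the waiting-task id is removed right afterwards.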
ctx_server.queue_results.recv(task_id); - ctx_server.queue_results.remove_waiting_task_id(task_id); - - if (result->is_error()) { - res->error(result->to_json()); - return res; - } - - GGML_ASSERT(dynamic_cast(result.get()) != nullptr); - res->ok(result->to_json()); - return res; -} - -std::unique_ptr server_routes::handle_embeddings_impl(const server_http_req & req, task_response_type res_type) { - auto res = std::make_unique(ctx_server); - if (!ctx_server.params_base.embedding) { - res->error(format_error_response("This server does not support embeddings. Start it with `--embeddings`", ERROR_TYPE_NOT_SUPPORTED)); - return res; - } - - if (res_type != TASK_RESPONSE_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) { - res->error(format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - - const json body = json::parse(req.body); - - // for the shape of input/content, see tokenize_input_prompts() - json prompt; - if (body.count("input") != 0) { - prompt = body.at("input"); - } else if (body.contains("content")) { - res_type = TASK_RESPONSE_TYPE_NONE; // "content" field is not OAI compatible - prompt = body.at("content"); - } else { - res->error(format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - - bool use_base64 = false; - if (body.count("encoding_format") != 0) { - const std::string& format = body.at("encoding_format"); - if (format == "base64") { - use_base64 = true; - } else if (format != "float") { - res->error(format_error_response("The format to return the embeddings in. Can be either float or base64", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - } - - auto tokenized_prompts = tokenize_input_prompts(ctx_server.vocab, ctx_server.mctx, prompt, true, true); - for (const auto & tokens : tokenized_prompts) { - // this check is necessary for models that do not add BOS token to the input - if (tokens.empty()) { - res->error(format_error_response("Input content cannot be empty", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - } - - int embd_normalize = 2; // default to Euclidean/L2 norm - if (body.count("embd_normalize") != 0) { - embd_normalize = body.at("embd_normalize"); - if (llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) { - SRV_DBG("embd_normalize is not supported by pooling type %d, ignoring it\n", llama_pooling_type(ctx_server.ctx)); - } - } - - // create and queue the task - json responses = json::array(); - server_response_reader rd = ctx_server.get_response_reader(); - { - std::vector tasks; - for (size_t i = 0; i < tokenized_prompts.size(); i++) { - server_task task = server_task(SERVER_TASK_TYPE_EMBEDDING); - - task.id = ctx_server.queue_tasks.get_new_id(); - task.index = i; - task.tokens = std::move(tokenized_prompts[i]); - - // OAI-compat - task.params.res_type = res_type; - task.params.embd_normalize = embd_normalize; - - tasks.push_back(std::move(task)); - } - rd.post_tasks(std::move(tasks)); - } - - // wait for the results - auto all_results = rd.wait_for_all(req.should_stop); - - // collect results - if (all_results.is_terminated) { - return res; // connection is closed - } else if (all_results.error) { - res->error(all_results.error->to_json()); - return res; - } else { - for (auto & res : all_results.results) { - GGML_ASSERT(dynamic_cast(res.get()) != nullptr); - responses.push_back(res->to_json()); - } - } - - // write JSON response - json root = res_type == 
TASK_RESPONSE_TYPE_OAI_EMBD - ? format_embeddings_response_oaicompat(body, ctx_server.model_name, responses, use_base64) - : json(responses); - res->ok(root); - return res; -} diff --git a/llamacpp/native/src/server/server-context.h b/llamacpp/native/src/server/server-context.h deleted file mode 100644 index 230b25952..000000000 --- a/llamacpp/native/src/server/server-context.h +++ /dev/null @@ -1,93 +0,0 @@ -#include "server-http.h" -#include "server-task.h" -#include "server-queue.h" - -#include - -#include -#include - -struct server_context_impl; // private implementation - -struct server_context_info { - std::string build_info; - std::string model_name; - bool has_inp_image; - bool has_inp_audio; -}; - -struct server_context { - std::unique_ptr impl; - - server_context(); - ~server_context(); - - // initialize slots and server-related data - void init(); - - // load the model and initialize llama_context - // returns true on success - bool load_model(const common_params & params); - - // this function will block main thread until termination - void start_loop(); - - // terminate main loop (will unblock start_loop) - void terminate(); - - // get the underlaying llama_context - llama_context * get_llama_context() const; - - // get a new response reader, used by CLI application - server_response_reader get_response_reader(); - - // get server info - // used by CLI application - server_context_info get_info() const; -}; - - -// forward declarations -struct server_res_generator; - -struct server_routes { - server_routes(const common_params & params, server_context & ctx_server, std::function is_ready = []() { return true; }) - : params(params), ctx_server(*ctx_server.impl), is_ready(is_ready) { - init_routes(); - } - - void init_routes(); - // handlers using lambda function, so that they can capture `this` without `std::bind` - server_http_context::handler_t get_health; - server_http_context::handler_t get_metrics; - server_http_context::handler_t get_slots; - server_http_context::handler_t post_slots; - server_http_context::handler_t get_props; - server_http_context::handler_t post_props; - server_http_context::handler_t get_api_show; - server_http_context::handler_t post_infill; - server_http_context::handler_t post_completions; - server_http_context::handler_t post_completions_oai; - server_http_context::handler_t post_chat_completions; - server_http_context::handler_t post_anthropic_messages; - server_http_context::handler_t post_anthropic_count_tokens; - server_http_context::handler_t post_apply_template; - server_http_context::handler_t get_models; - server_http_context::handler_t post_tokenize; - server_http_context::handler_t post_detokenize; - server_http_context::handler_t post_embeddings; - server_http_context::handler_t post_embeddings_oai; - server_http_context::handler_t post_rerank; - server_http_context::handler_t get_lora_adapters; - server_http_context::handler_t post_lora_adapters; -private: - // TODO: move these outside of server_routes? 
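    // private helpers shared by the route lambdas above; each builds and returns a complete HTTP
    // response object, so the public handlers can simply forward whatever these produce.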
- std::unique_ptr handle_slots_save(const server_http_req & req, int id_slot); - std::unique_ptr handle_slots_restore(const server_http_req & req, int id_slot); - std::unique_ptr handle_slots_erase(const server_http_req &, int id_slot); - std::unique_ptr handle_embeddings_impl(const server_http_req & req, task_response_type res_type); - - const common_params & params; - server_context_impl & ctx_server; - std::function is_ready; -}; diff --git a/llamacpp/native/src/server/server-http.cpp b/llamacpp/native/src/server/server-http.cpp deleted file mode 100644 index 77e54d192..000000000 --- a/llamacpp/native/src/server/server-http.cpp +++ /dev/null @@ -1,380 +0,0 @@ -#include "common.h" -#include "server-http.h" -#include "server-common.h" - -#include - -#include -#include -#include - -// -// HTTP implementation using cpp-httplib -// - -class server_http_context::Impl { -public: - std::unique_ptr srv; -}; - -server_http_context::server_http_context() - : pimpl(std::make_unique()) -{} - -server_http_context::~server_http_context() = default; - -static void log_server_request(const httplib::Request & req, const httplib::Response & res) { - // skip GH copilot requests when using default port - if (req.path == "/v1/health") { - return; - } - - // reminder: this function is not covered by httplib's exception handler; if someone does more complicated stuff, think about wrapping it in try-catch - - SRV_INF("request: %s %s %s %d\n", req.method.c_str(), req.path.c_str(), req.remote_addr.c_str(), res.status); - - SRV_DBG("request: %s\n", req.body.c_str()); - SRV_DBG("response: %s\n", res.body.c_str()); -} - -bool server_http_context::init(const common_params & params) { - path_prefix = params.api_prefix; - port = params.port; - hostname = params.hostname; - - auto & srv = pimpl->srv; - -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - if (params.ssl_file_key != "" && params.ssl_file_cert != "") { - LOG_INF("Running with SSL: key = %s, cert = %s\n", params.ssl_file_key.c_str(), params.ssl_file_cert.c_str()); - srv.reset( - new httplib::SSLServer(params.ssl_file_cert.c_str(), params.ssl_file_key.c_str()) - ); - } else { - LOG_INF("Running without SSL\n"); - srv.reset(new httplib::Server()); - } -#else - if (params.ssl_file_key != "" && params.ssl_file_cert != "") { - LOG_ERR("Server is built without SSL support\n"); - return false; - } - srv.reset(new httplib::Server()); -#endif - - srv->set_default_headers({{"Server", "llama.cpp"}}); - srv->set_logger(log_server_request); - srv->set_exception_handler([](const httplib::Request &, httplib::Response & res, const std::exception_ptr & ep) { - // this is fail-safe; exceptions should already handled by `ex_wrapper` - - std::string message; - try { - std::rethrow_exception(ep); - } catch (const std::exception & e) { - message = e.what(); - } catch (...) 
{ - message = "Unknown Exception"; - } - - res.status = 500; - res.set_content(message, "text/plain"); - LOG_ERR("got exception: %s\n", message.c_str()); - }); - - srv->set_error_handler([](const httplib::Request &, httplib::Response & res) { - if (res.status == 404) { - res.set_content( - safe_json_to_str(json { - {"error", { - {"message", "File Not Found"}, - {"type", "not_found_error"}, - {"code", 404} - }} - }), - "application/json; charset=utf-8" - ); - } - // for other error codes, we skip processing here because it's already done by res->error() - }); - - // set timeouts and change hostname and port - srv->set_read_timeout (params.timeout_read); - srv->set_write_timeout(params.timeout_write); - - if (params.api_keys.size() == 1) { - auto key = params.api_keys[0]; - std::string substr = key.substr(std::max((int)(key.length() - 4), 0)); - LOG_INF("%s: api_keys: ****%s\n", __func__, substr.c_str()); - } else if (params.api_keys.size() > 1) { - LOG_INF("%s: api_keys: %zu keys loaded\n", __func__, params.api_keys.size()); - } - - // - // Middlewares - // - - auto middleware_validate_api_key = [api_keys = params.api_keys](const httplib::Request & req, httplib::Response & res) { - static const std::unordered_set public_endpoints = { - "/health", - "/v1/health", - "/models", - "/v1/models", - "/api/tags" - }; - - // If API key is not set, skip validation - if (api_keys.empty()) { - return true; - } - - // If path is public or is static file, skip validation - if (public_endpoints.find(req.path) != public_endpoints.end() || req.path == "/") { - return true; - } - - // Check for API key in the Authorization header - std::string req_api_key = req.get_header_value("Authorization"); - if (req_api_key.empty()) { - // retry with anthropic header - req_api_key = req.get_header_value("X-Api-Key"); - } - - // remove the "Bearer " prefix if needed - std::string prefix = "Bearer "; - if (req_api_key.substr(0, prefix.size()) == prefix) { - req_api_key = req_api_key.substr(prefix.size()); - } - - // validate the API key - if (std::find(api_keys.begin(), api_keys.end(), req_api_key) != api_keys.end()) { - return true; // API key is valid - } - - // API key is invalid or not provided - res.status = 401; - res.set_content( - safe_json_to_str(json { - {"error", { - {"message", "Invalid API Key"}, - {"type", "authentication_error"}, - {"code", 401} - }} - }), - "application/json; charset=utf-8" - ); - - LOG_WRN("Unauthorized: Invalid API Key\n"); - - return false; - }; - - auto middleware_server_state = [this](const httplib::Request & req, httplib::Response & res) { - bool ready = is_ready.load(); - if (!ready) { - res.status = 503; - res.set_content( - safe_json_to_str(json { - {"error", { - {"message", "Loading model"}, - {"type", "unavailable_error"}, - {"code", 503} - }} - }), - "application/json; charset=utf-8" - ); - return false; - } - return true; - }; - - // register server middlewares - srv->set_pre_routing_handler([middleware_validate_api_key, middleware_server_state](const httplib::Request & req, httplib::Response & res) { - res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); - // If this is OPTIONS request, skip validation because browsers don't include Authorization header - if (req.method == "OPTIONS") { - res.set_header("Access-Control-Allow-Credentials", "true"); - res.set_header("Access-Control-Allow-Methods", "GET, POST"); - res.set_header("Access-Control-Allow-Headers", "*"); - res.set_content("", "text/html"); // blank response, no data - return 
httplib::Server::HandlerResponse::Handled; // skip further processing - } - if (!middleware_server_state(req, res)) { - return httplib::Server::HandlerResponse::Handled; - } - if (!middleware_validate_api_key(req, res)) { - return httplib::Server::HandlerResponse::Handled; - } - return httplib::Server::HandlerResponse::Unhandled; - }); - - int n_threads_http = params.n_threads_http; - if (n_threads_http < 1) { - // +2 threads for monitoring endpoints - n_threads_http = std::max(params.n_parallel + 2, (int32_t) std::thread::hardware_concurrency() - 1); - } - LOG_INF("%s: using %d threads for HTTP server\n", __func__, n_threads_http); - srv->new_task_queue = [n_threads_http] { return new httplib::ThreadPool(n_threads_http); }; - - // - // Web UI setup - // - - if (!params.webui) { - LOG_INF("Web UI is disabled\n"); - } else { - // register static assets routes - if (!params.public_path.empty()) { - // Set the base directory for serving static files - bool is_found = srv->set_mount_point(params.api_prefix + "/", params.public_path); - if (!is_found) { - LOG_ERR("%s: static assets path not found: %s\n", __func__, params.public_path.c_str()); - return 1; - } - } else { - // using embedded static index.html - srv->Get(params.api_prefix + "/", [](const httplib::Request & req, httplib::Response & res) { - if (req.get_header_value("Accept-Encoding").find("gzip") == std::string::npos) { - res.set_content("Error: gzip is not supported by this browser", "text/plain"); - } else { - res.set_header("Content-Encoding", "gzip"); - // COEP and COOP headers, required by pyodide (python interpreter) - res.set_header("Cross-Origin-Embedder-Policy", "require-corp"); - res.set_header("Cross-Origin-Opener-Policy", "same-origin"); - } - return false; - }); - } - } - return true; -} - -bool server_http_context::start() { - // Bind and listen - - auto & srv = pimpl->srv; - bool was_bound = false; - bool is_sock = false; - if (string_ends_with(std::string(hostname), ".sock")) { - is_sock = true; - LOG_INF("%s: setting address family to AF_UNIX\n", __func__); - srv->set_address_family(AF_UNIX); - // bind_to_port requires a second arg, any value other than 0 should - // simply get ignored - was_bound = srv->bind_to_port(hostname, 8080); - } else { - LOG_INF("%s: binding port with default address family\n", __func__); - // bind HTTP listen port - if (port == 0) { - int bound_port = srv->bind_to_any_port(hostname); - was_bound = (bound_port >= 0); - if (was_bound) { - port = bound_port; - } - } else { - was_bound = srv->bind_to_port(hostname, port); - } - } - - if (!was_bound) { - LOG_ERR("%s: couldn't bind HTTP server socket, hostname: %s, port: %d\n", __func__, hostname.c_str(), port); - return false; - } - - // run the HTTP server in a thread - thread = std::thread([this]() { pimpl->srv->listen_after_bind(); }); - srv->wait_until_ready(); - - listening_address = is_sock ? 
string_format("unix://%s", hostname.c_str()) - : string_format("http://%s:%d", hostname.c_str(), port); - return true; -} - -void server_http_context::stop() const { - if (pimpl->srv) { - pimpl->srv->stop(); - } -} - -static void set_headers(httplib::Response & res, const std::map & headers) { - for (const auto & [key, value] : headers) { - res.set_header(key, value); - } -} - -static std::map get_params(const httplib::Request & req) { - std::map params; - for (const auto & [key, value] : req.params) { - params[key] = value; - } - for (const auto & [key, value] : req.path_params) { - params[key] = value; - } - return params; -} - -static std::map get_headers(const httplib::Request & req) { - std::map headers; - for (const auto & [key, value] : req.headers) { - headers[key] = value; - } - return headers; -} - -static void process_handler_response(server_http_res_ptr & response, httplib::Response & res) { - if (response->is_stream()) { - res.status = response->status; - set_headers(res, response->headers); - std::string content_type = response->content_type; - // convert to shared_ptr as both chunked_content_provider() and on_complete() need to use it - std::shared_ptr r_ptr = std::move(response); - const auto chunked_content_provider = [response = r_ptr](size_t, httplib::DataSink & sink) -> bool { - std::string chunk; - bool has_next = response->next(chunk); - if (!chunk.empty()) { - // TODO: maybe handle sink.write unsuccessful? for now, we rely on is_connection_closed() - sink.write(chunk.data(), chunk.size()); - SRV_DBG("http: streamed chunk: %s\n", chunk.c_str()); - } - if (!has_next) { - sink.done(); - SRV_DBG("%s", "http: stream ended\n"); - } - return has_next; - }; - const auto on_complete = [response = r_ptr](bool) mutable { - response.reset(); // trigger the destruction of the response object - }; - res.set_chunked_content_provider(content_type, chunked_content_provider, on_complete); - } else { - res.status = response->status; - set_headers(res, response->headers); - res.set_content(response->data, response->content_type); - } -} - -void server_http_context::get(const std::string & path, const server_http_context::handler_t & handler) const { - pimpl->srv->Get(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) { - server_http_res_ptr response = handler(server_http_req{ - get_params(req), - get_headers(req), - req.path, - req.body, - req.is_connection_closed - }); - process_handler_response(response, res); - }); -} - -void server_http_context::post(const std::string & path, const server_http_context::handler_t & handler) const { - pimpl->srv->Post(path_prefix + path, [handler](const httplib::Request & req, httplib::Response & res) { - server_http_res_ptr response = handler(server_http_req{ - get_params(req), - get_headers(req), - req.path, - req.body, - req.is_connection_closed - }); - process_handler_response(response, res); - }); -} - diff --git a/llamacpp/native/src/server/server-http.h b/llamacpp/native/src/server/server-http.h deleted file mode 100644 index 24c0b4011..000000000 --- a/llamacpp/native/src/server/server-http.h +++ /dev/null @@ -1,78 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include - -struct common_params; - -// generator-like API for HTTP response generation -// this object response with one of the 2 modes: -// 1) normal response: `data` contains the full response body -// 2) streaming response: each call to next(output) generates the next chunk -// when next(output) returns false, no more data after the 
current chunk -// note: some chunks can be empty, in which case no data is sent for that chunk -struct server_http_res { - std::string content_type = "application/json; charset=utf-8"; - int status = 200; - std::string data; - std::map headers; - - // TODO: move this to a virtual function once we have proper polymorphism support - std::function next = nullptr; - bool is_stream() const { - return next != nullptr; - } - - virtual ~server_http_res() = default; -}; - -// unique pointer, used by set_chunked_content_provider -// httplib requires the stream provider to be stored in heap -using server_http_res_ptr = std::unique_ptr; - -struct server_http_req { - std::map params; // path_params + query_params - std::map headers; // reserved for future use - std::string path; // reserved for future use - std::string body; - const std::function & should_stop; - - std::string get_param(const std::string & key, const std::string & def = "") const { - auto it = params.find(key); - if (it != params.end()) { - return it->second; - } - return def; - } -}; - -struct server_http_context { - class Impl; - std::unique_ptr pimpl; - - std::thread thread; // server thread - std::atomic is_ready = false; - - std::string path_prefix; - std::string hostname; - int port; - - server_http_context(); - ~server_http_context(); - - bool init(const common_params & params); - bool start(); - void stop() const; - - // note: the handler should never throw exceptions - using handler_t = std::function; - - void get(const std::string & path, const handler_t & handler) const; - void post(const std::string & path, const handler_t & handler) const; - - // for debugging - std::string listening_address; -}; diff --git a/llamacpp/native/src/server/server-http.patch b/llamacpp/native/src/server/server-http.patch deleted file mode 100644 index 900dae89b..000000000 --- a/llamacpp/native/src/server/server-http.patch +++ /dev/null @@ -1,61 +0,0 @@ -diff --git a/llamacpp/native/src/server/server-http.cpp b/llamacpp/native/src/server/server-http.cpp -index 62250571..77e54d19 100644 ---- a/llamacpp/native/src/server/server-http.cpp -+++ b/llamacpp/native/src/server/server-http.cpp -@@ -8,10 +8,6 @@ - #include - #include - --// auto generated files (see README.md for details) --#include "index.html.gz.hpp" --#include "loading.html.hpp" -- - // - // HTTP implementation using cpp-httplib - // -@@ -175,26 +171,17 @@ bool server_http_context::init(const common_params & params) { - auto middleware_server_state = [this](const httplib::Request & req, httplib::Response & res) { - bool ready = is_ready.load(); - if (!ready) { -- auto tmp = string_split(req.path, '.'); -- if (req.path == "/" || tmp.back() == "html") { -- res.set_content(reinterpret_cast(loading_html), loading_html_len, "text/html; charset=utf-8"); -- res.status = 503; -- } else if (req.path == "/models" || req.path == "/v1/models" || req.path == "/api/tags") { -- // allow the models endpoint to be accessed during loading -- return true; -- } else { -- res.status = 503; -- res.set_content( -- safe_json_to_str(json { -- {"error", { -- {"message", "Loading model"}, -- {"type", "unavailable_error"}, -- {"code", 503} -- }} -- }), -- "application/json; charset=utf-8" -- ); -- } -+ res.status = 503; -+ res.set_content( -+ safe_json_to_str(json { -+ {"error", { -+ {"message", "Loading model"}, -+ {"type", "unavailable_error"}, -+ {"code", 503} -+ }} -+ }), -+ "application/json; charset=utf-8" -+ ); - return false; - } - return true; -@@ -253,7 +240,6 @@ bool server_http_context::init(const 
common_params & params) { - // COEP and COOP headers, required by pyodide (python interpreter) - res.set_header("Cross-Origin-Embedder-Policy", "require-corp"); - res.set_header("Cross-Origin-Opener-Policy", "same-origin"); -- res.set_content(reinterpret_cast(index_html_gz), index_html_gz_len, "text/html; charset=utf-8"); - } - return false; - }); diff --git a/llamacpp/native/src/server/server-models.cpp b/llamacpp/native/src/server/server-models.cpp deleted file mode 100644 index 6c618a673..000000000 --- a/llamacpp/native/src/server/server-models.cpp +++ /dev/null @@ -1,1109 +0,0 @@ -#include "server-common.h" -#include "server-models.h" - -#include "preset.h" -#include "download.h" - -#include // TODO: remove this once we use HTTP client from download.h -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#ifdef _WIN32 -#include -#else -#include -#include -#include -#include -#endif - -#if defined(__APPLE__) && defined(__MACH__) -// macOS: use _NSGetExecutablePath to get the executable path -#include -#include -#endif - -#define CMD_EXIT "exit" - -// address for child process, this is needed because router may run on 0.0.0.0 -// ref: https://github.com/ggml-org/llama.cpp/issues/17862 -#define CHILD_ADDR "127.0.0.1" - -static std::filesystem::path get_server_exec_path() { -#if defined(_WIN32) - wchar_t buf[32768] = { 0 }; // Large buffer to handle long paths - DWORD len = GetModuleFileNameW(nullptr, buf, _countof(buf)); - if (len == 0 || len >= _countof(buf)) { - throw std::runtime_error("GetModuleFileNameW failed or path too long"); - } - return std::filesystem::path(buf); -#elif defined(__APPLE__) && defined(__MACH__) - char small_path[PATH_MAX]; - uint32_t size = sizeof(small_path); - - if (_NSGetExecutablePath(small_path, &size) == 0) { - // resolve any symlinks to get absolute path - try { - return std::filesystem::canonical(std::filesystem::path(small_path)); - } catch (...) { - return std::filesystem::path(small_path); - } - } else { - // buffer was too small, allocate required size and call again - std::vector buf(size); - if (_NSGetExecutablePath(buf.data(), &size) == 0) { - try { - return std::filesystem::canonical(std::filesystem::path(buf.data())); - } catch (...) 
{ - return std::filesystem::path(buf.data()); - } - } - throw std::runtime_error("_NSGetExecutablePath failed after buffer resize"); - } -#else - char path[FILENAME_MAX]; - ssize_t count = readlink("/proc/self/exe", path, FILENAME_MAX); - if (count <= 0) { - throw std::runtime_error("failed to resolve /proc/self/exe"); - } - return std::filesystem::path(std::string(path, count)); -#endif -} - -struct local_model { - std::string name; - std::string path; - std::string path_mmproj; -}; - -static std::vector list_local_models(const std::string & dir) { - if (!std::filesystem::exists(dir) || !std::filesystem::is_directory(dir)) { - throw std::runtime_error(string_format("error: '%s' does not exist or is not a directory\n", dir.c_str())); - } - - std::vector models; - auto scan_subdir = [&models](const std::string & subdir_path, const std::string & name) { - auto files = fs_list(subdir_path, false); - common_file_info model_file; - common_file_info first_shard_file; - common_file_info mmproj_file; - for (const auto & file : files) { - if (string_ends_with(file.name, ".gguf")) { - if (file.name.find("mmproj") != std::string::npos) { - mmproj_file = file; - } else if (file.name.find("-00001-of-") != std::string::npos) { - first_shard_file = file; - } else { - model_file = file; - } - } - } - // single file model - local_model model{ - /* name */ name, - /* path */ first_shard_file.path.empty() ? model_file.path : first_shard_file.path, - /* path_mmproj */ mmproj_file.path // can be empty - }; - if (!model.path.empty()) { - models.push_back(model); - } - }; - - auto files = fs_list(dir, true); - for (const auto & file : files) { - if (file.is_dir) { - scan_subdir(file.path, file.name); - } else if (string_ends_with(file.name, ".gguf")) { - // single file model - std::string name = file.name; - string_replace_all(name, ".gguf", ""); - local_model model{ - /* name */ name, - /* path */ file.path, - /* path_mmproj */ "" - }; - models.push_back(model); - } - } - return models; -} - -// -// server_presets -// - - -server_presets::server_presets(int argc, char ** argv, common_params & base_params, const std::string & presets_path) - : ctx_params(common_params_parser_init(base_params, LLAMA_EXAMPLE_SERVER)) { - if (!presets_path.empty()) { - presets = common_presets_load(presets_path, ctx_params); - SRV_INF("Loaded %zu presets from %s\n", presets.size(), presets_path.c_str()); - } - - // populate reserved args (will be appended by the router) - for (auto & opt : ctx_params.options) { - if (opt.env == nullptr) { - continue; - } - std::string env = opt.env; - if (env == "LLAMA_ARG_PORT" || - env == "LLAMA_ARG_HOST" || - env == "LLAMA_ARG_ALIAS" || - env == "LLAMA_ARG_API_KEY" || - env == "LLAMA_ARG_MODELS_DIR" || - env == "LLAMA_ARG_MODELS_MAX" || - env == "LLAMA_ARG_MODELS_PRESET" || - env == "LLAMA_ARG_MODEL" || - env == "LLAMA_ARG_MMPROJ" || - env == "LLAMA_ARG_HF_REPO" || - env == "LLAMA_ARG_NO_MODELS_AUTOLOAD") { - control_args[env] = opt; - } - } - - // read base args from router's argv - common_params_parse(argc, argv, LLAMA_EXAMPLE_SERVER, base_args); - - // remove any router-controlled args from base_args - for (const auto & cargs : control_args) { - auto it = base_args.find(cargs.second); - if (it != base_args.end()) { - base_args.erase(it); - } - } -} - -common_preset server_presets::get_preset(const std::string & name) { - auto it = presets.find(name); - if (it != presets.end()) { - return it->second; - } - return common_preset(); -} - -void server_presets::render_args(server_model_meta & meta) 
{ - common_preset preset = meta.preset; // copy - // merging 3 kinds of args: - // 1. model-specific args (from preset) - // force removing control args if any - for (auto & cargs : control_args) { - if (preset.options.find(cargs.second) != preset.options.end()) { - SRV_WRN("Preset '%s' contains reserved arg '%s', removing it\n", preset.name.c_str(), cargs.second.args[0]); - preset.options.erase(cargs.second); - } - } - // 2. base args (from router) - // inherit from base args - for (const auto & [arg, value] : base_args) { - preset.options[arg] = value; - } - // 3. control args (from router) - // set control values - preset.options[control_args["LLAMA_ARG_HOST"]] = CHILD_ADDR; - preset.options[control_args["LLAMA_ARG_PORT"]] = std::to_string(meta.port); - preset.options[control_args["LLAMA_ARG_ALIAS"]] = meta.name; - if (meta.in_cache) { - preset.options[control_args["LLAMA_ARG_HF_REPO"]] = meta.name; - } else { - preset.options[control_args["LLAMA_ARG_MODEL"]] = meta.path; - if (!meta.path_mmproj.empty()) { - preset.options[control_args["LLAMA_ARG_MMPROJ"]] = meta.path_mmproj; - } - } - meta.args = preset.to_args(); - // add back the binary path at the front - meta.args.insert(meta.args.begin(), get_server_exec_path().string()); -} - -// -// server_models -// - -server_models::server_models( - const common_params & params, - int argc, - char ** argv, - char ** envp) : base_params(params), presets(argc, argv, base_params, params.models_preset) { - for (int i = 0; i < argc; i++) { - base_args.push_back(std::string(argv[i])); - } - for (char ** env = envp; *env != nullptr; env++) { - base_env.push_back(std::string(*env)); - } - GGML_ASSERT(!base_args.empty()); - // set binary path - try { - base_args[0] = get_server_exec_path().string(); - } catch (const std::exception & e) { - LOG_WRN("failed to get server executable path: %s\n", e.what()); - LOG_WRN("using original argv[0] as fallback: %s\n", base_args[0].c_str()); - } - load_models(); -} - -void server_models::add_model(server_model_meta && meta) { - if (mapping.find(meta.name) != mapping.end()) { - throw std::runtime_error(string_format("model '%s' appears multiple times", meta.name.c_str())); - } - presets.render_args(meta); // populate meta.args - std::string name = meta.name; - mapping[name] = instance_t{ - /* subproc */ std::make_shared(), - /* th */ std::thread(), - /* meta */ std::move(meta) - }; -} - -static std::vector list_custom_path_models(server_presets & presets) { - // detect any custom-path models in presets - std::vector custom_models; - for (auto & [model_name, preset] : presets.presets) { - local_model model; - model.name = model_name; - std::vector to_erase; - for (auto & [arg, value] : preset.options) { - std::string env(arg.env ? arg.env : ""); - if (env == "LLAMA_ARG_MODEL") { - model.path = value; - to_erase.push_back(arg); - } - if (env == "LLAMA_ARG_MMPROJ") { - model.path_mmproj = value; - to_erase.push_back(arg); - } - } - for (auto & arg : to_erase) { - preset.options.erase(arg); - } - if (!model.name.empty() && !model.path.empty()) { - custom_models.push_back(model); - } - } - return custom_models; -} - -// TODO: allow refreshing cached model list -void server_models::load_models() { - // loading models from 3 sources: - // 1. 
cached models - auto cached_models = common_list_cached_models(); - for (const auto & model : cached_models) { - server_model_meta meta{ - /* preset */ presets.get_preset(model.to_string()), - /* name */ model.to_string(), - /* path */ model.manifest_path, - /* path_mmproj */ "", // auto-detected when loading - /* in_cache */ true, - /* port */ 0, - /* status */ SERVER_MODEL_STATUS_UNLOADED, - /* last_used */ 0, - /* args */ std::vector(), - /* exit_code */ 0 - }; - add_model(std::move(meta)); - } - // 2. local models specificed via --models-dir - if (!base_params.models_dir.empty()) { - auto local_models = list_local_models(base_params.models_dir); - for (const auto & model : local_models) { - if (mapping.find(model.name) != mapping.end()) { - // already exists in cached models, skip - continue; - } - server_model_meta meta{ - /* preset */ presets.get_preset(model.name), - /* name */ model.name, - /* path */ model.path, - /* path_mmproj */ model.path_mmproj, - /* in_cache */ false, - /* port */ 0, - /* status */ SERVER_MODEL_STATUS_UNLOADED, - /* last_used */ 0, - /* args */ std::vector(), - /* exit_code */ 0 - }; - add_model(std::move(meta)); - } - } - // 3. custom-path models specified in presets - auto custom_models = list_custom_path_models(presets); - for (const auto & model : custom_models) { - server_model_meta meta{ - /* preset */ presets.get_preset(model.name), - /* name */ model.name, - /* path */ model.path, - /* path_mmproj */ model.path_mmproj, - /* in_cache */ false, - /* port */ 0, - /* status */ SERVER_MODEL_STATUS_UNLOADED, - /* last_used */ 0, - /* args */ std::vector(), - /* exit_code */ 0 - }; - add_model(std::move(meta)); - } - // log available models - SRV_INF("Available models (%zu) (*: custom preset)\n", mapping.size()); - for (const auto & [name, inst] : mapping) { - SRV_INF(" %c %s\n", inst.meta.preset.name.empty() ? 
' ' : '*', name.c_str()); - } -} - -void server_models::update_meta(const std::string & name, const server_model_meta & meta) { - std::lock_guard lk(mutex); - auto it = mapping.find(name); - if (it != mapping.end()) { - it->second.meta = meta; - } - cv.notify_all(); // notify wait_until_loaded -} - -bool server_models::has_model(const std::string & name) { - std::lock_guard lk(mutex); - return mapping.find(name) != mapping.end(); -} - -std::optional server_models::get_meta(const std::string & name) { - std::lock_guard lk(mutex); - auto it = mapping.find(name); - if (it != mapping.end()) { - return it->second.meta; - } - return std::nullopt; -} - -static int get_free_port() { -#ifdef _WIN32 - WSADATA wsaData; - if (WSAStartup(MAKEWORD(2, 2), &wsaData) != 0) { - return -1; - } - typedef SOCKET native_socket_t; -#define INVALID_SOCKET_VAL INVALID_SOCKET -#define CLOSE_SOCKET(s) closesocket(s) -#else - typedef int native_socket_t; -#define INVALID_SOCKET_VAL -1 -#define CLOSE_SOCKET(s) close(s) -#endif - - native_socket_t sock = socket(AF_INET, SOCK_STREAM, 0); - if (sock == INVALID_SOCKET_VAL) { -#ifdef _WIN32 - WSACleanup(); -#endif - return -1; - } - - struct sockaddr_in serv_addr; - std::memset(&serv_addr, 0, sizeof(serv_addr)); - serv_addr.sin_family = AF_INET; - serv_addr.sin_addr.s_addr = htonl(INADDR_ANY); - serv_addr.sin_port = htons(0); - - if (bind(sock, (struct sockaddr*)&serv_addr, sizeof(serv_addr)) != 0) { - CLOSE_SOCKET(sock); -#ifdef _WIN32 - WSACleanup(); -#endif - return -1; - } - -#ifdef _WIN32 - int namelen = sizeof(serv_addr); -#else - socklen_t namelen = sizeof(serv_addr); -#endif - if (getsockname(sock, (struct sockaddr*)&serv_addr, &namelen) != 0) { - CLOSE_SOCKET(sock); -#ifdef _WIN32 - WSACleanup(); -#endif - return -1; - } - - int port = ntohs(serv_addr.sin_port); - - CLOSE_SOCKET(sock); -#ifdef _WIN32 - WSACleanup(); -#endif - - return port; -} - -// helper to convert vector to char ** -// pointers are only valid as long as the original vector is valid -static std::vector to_char_ptr_array(const std::vector & vec) { - std::vector result; - result.reserve(vec.size() + 1); - for (const auto & s : vec) { - result.push_back(const_cast(s.c_str())); - } - result.push_back(nullptr); - return result; -} - -std::vector server_models::get_all_meta() { - std::lock_guard lk(mutex); - std::vector result; - result.reserve(mapping.size()); - for (const auto & [name, inst] : mapping) { - result.push_back(inst.meta); - } - return result; -} - -void server_models::unload_lru() { - if (base_params.models_max <= 0) { - return; // no limit - } - // remove one of the servers if we passed the models_max (least recently used - LRU) - std::string lru_model_name = ""; - int64_t lru_last_used = ggml_time_ms(); - size_t count_active = 0; - { - std::lock_guard lk(mutex); - for (const auto & m : mapping) { - if (m.second.meta.is_active()) { - count_active++; - if (m.second.meta.last_used < lru_last_used) { - lru_model_name = m.first; - lru_last_used = m.second.meta.last_used; - } - } - } - } - if (!lru_model_name.empty() && count_active >= (size_t)base_params.models_max) { - SRV_INF("models_max limit reached, removing LRU name=%s\n", lru_model_name.c_str()); - unload(lru_model_name); - } -} - -void server_models::load(const std::string & name) { - if (!has_model(name)) { - throw std::runtime_error("model name=" + name + " is not found"); - } - unload_lru(); - - std::lock_guard lk(mutex); - - auto meta = mapping[name].meta; - if (meta.status != SERVER_MODEL_STATUS_UNLOADED) { - SRV_INF("model %s 
is not ready\n", name.c_str()); - return; - } - - // prepare new instance info - instance_t inst; - inst.meta = meta; - inst.meta.port = get_free_port(); - inst.meta.status = SERVER_MODEL_STATUS_LOADING; - inst.meta.last_used = ggml_time_ms(); - - if (inst.meta.port <= 0) { - throw std::runtime_error("failed to get a port number"); - } - - inst.subproc = std::make_shared(); - { - SRV_INF("spawning server instance with name=%s on port %d\n", inst.meta.name.c_str(), inst.meta.port); - - presets.render_args(inst.meta); // update meta.args - - std::vector child_args = inst.meta.args; // copy - std::vector child_env = base_env; // copy - child_env.push_back("LLAMA_SERVER_ROUTER_PORT=" + std::to_string(base_params.port)); - - SRV_INF("%s", "spawning server instance with args:\n"); - for (const auto & arg : child_args) { - SRV_INF(" %s\n", arg.c_str()); - } - inst.meta.args = child_args; // save for debugging - - std::vector argv = to_char_ptr_array(child_args); - std::vector envp = to_char_ptr_array(child_env); - - int options = subprocess_option_no_window | subprocess_option_combined_stdout_stderr; - int result = subprocess_create_ex(argv.data(), options, envp.data(), inst.subproc.get()); - if (result != 0) { - throw std::runtime_error("failed to spawn server instance"); - } - - inst.stdin_file = subprocess_stdin(inst.subproc.get()); - } - - // start a thread to manage the child process - // captured variables are guaranteed to be destroyed only after the thread is joined - inst.th = std::thread([this, name, child_proc = inst.subproc, port = inst.meta.port]() { - // read stdout/stderr and forward to main server log - FILE * p_stdout_stderr = subprocess_stdout(child_proc.get()); - if (p_stdout_stderr) { - char buffer[4096]; - while (fgets(buffer, sizeof(buffer), p_stdout_stderr) != nullptr) { - LOG("[%5d] %s", port, buffer); - } - } else { - SRV_ERR("failed to get stdout/stderr of child process for name=%s\n", name.c_str()); - } - // we reach here when the child process exits - int exit_code = 0; - subprocess_join(child_proc.get(), &exit_code); - subprocess_destroy(child_proc.get()); - // update PID and status - { - std::lock_guard lk(mutex); - auto it = mapping.find(name); - if (it != mapping.end()) { - auto & meta = it->second.meta; - meta.exit_code = exit_code; - meta.status = SERVER_MODEL_STATUS_UNLOADED; - } - cv.notify_all(); - } - SRV_INF("instance name=%s exited with status %d\n", name.c_str(), exit_code); - }); - - // clean up old process/thread if exists - { - auto & old_instance = mapping[name]; - // old process should have exited already, but just in case, we clean it up here - if (subprocess_alive(old_instance.subproc.get())) { - SRV_WRN("old process for model name=%s is still alive, this is unexpected\n", name.c_str()); - subprocess_terminate(old_instance.subproc.get()); // force kill - } - if (old_instance.th.joinable()) { - old_instance.th.join(); - } - } - - mapping[name] = std::move(inst); - cv.notify_all(); -} - -static void interrupt_subprocess(FILE * stdin_file) { - // because subprocess.h does not provide a way to send SIGINT, - // we will send a command to the child process to exit gracefully - if (stdin_file) { - fprintf(stdin_file, "%s\n", CMD_EXIT); - fflush(stdin_file); - } -} - -void server_models::unload(const std::string & name) { - std::lock_guard lk(mutex); - auto it = mapping.find(name); - if (it != mapping.end()) { - if (it->second.meta.is_active()) { - SRV_INF("unloading model instance name=%s\n", name.c_str()); - interrupt_subprocess(it->second.stdin_file); - 
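            // interrupt_subprocess() writes CMD_EXIT to the child's stdin because subprocess.h cannot
            // deliver SIGINT; the stdin-monitoring thread created in setup_child_server() picks the
            // command up and runs the child's shutdown handler.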
// status change will be handled by the managing thread - } else { - SRV_WRN("model instance name=%s is not loaded\n", name.c_str()); - } - } -} - -void server_models::unload_all() { - std::vector to_join; - { - std::lock_guard lk(mutex); - for (auto & [name, inst] : mapping) { - if (inst.meta.is_active()) { - SRV_INF("unloading model instance name=%s\n", name.c_str()); - interrupt_subprocess(inst.stdin_file); - // status change will be handled by the managing thread - } - // moving the thread to join list to avoid deadlock - to_join.push_back(std::move(inst.th)); - } - } - for (auto & th : to_join) { - if (th.joinable()) { - th.join(); - } - } -} - -void server_models::update_status(const std::string & name, server_model_status status) { - // for now, we only allow updating to LOADED status - if (status != SERVER_MODEL_STATUS_LOADED) { - throw std::runtime_error("invalid status value"); - } - auto meta = get_meta(name); - if (meta.has_value()) { - meta->status = status; - update_meta(name, meta.value()); - } -} - -void server_models::wait_until_loaded(const std::string & name) { - std::unique_lock lk(mutex); - cv.wait(lk, [this, &name]() { - auto it = mapping.find(name); - if (it != mapping.end()) { - return it->second.meta.status != SERVER_MODEL_STATUS_LOADING; - } - return false; - }); -} - -bool server_models::ensure_model_loaded(const std::string & name) { - auto meta = get_meta(name); - if (!meta.has_value()) { - throw std::runtime_error("model name=" + name + " is not found"); - } - if (meta->status == SERVER_MODEL_STATUS_LOADED) { - return false; // already loaded - } - if (meta->status == SERVER_MODEL_STATUS_UNLOADED) { - SRV_INF("model name=%s is not loaded, loading...\n", name.c_str()); - load(name); - } - - SRV_INF("waiting until model name=%s is fully loaded...\n", name.c_str()); - wait_until_loaded(name); - - // check final status - meta = get_meta(name); - if (!meta.has_value() || meta->is_failed()) { - throw std::runtime_error("model name=" + name + " failed to load"); - } - - return true; -} - -server_http_res_ptr server_models::proxy_request(const server_http_req & req, const std::string & method, const std::string & name, bool update_last_used) { - auto meta = get_meta(name); - if (!meta.has_value()) { - throw std::runtime_error("model name=" + name + " is not found"); - } - if (meta->status != SERVER_MODEL_STATUS_LOADED) { - throw std::invalid_argument("model name=" + name + " is not loaded"); - } - if (update_last_used) { - std::unique_lock lk(mutex); - mapping[name].meta.last_used = ggml_time_ms(); - } - SRV_INF("proxying request to model %s on port %d\n", name.c_str(), meta->port); - auto proxy = std::make_unique( - method, - CHILD_ADDR, - meta->port, - req.path, - req.headers, - req.body, - req.should_stop); - return proxy; -} - -std::thread server_models::setup_child_server(const common_params & base_params, int router_port, const std::string & name, std::function & shutdown_handler) { - // send a notification to the router server that a model instance is ready - // TODO @ngxson : use HTTP client from libcommon - httplib::Client cli(base_params.hostname, router_port); - cli.set_connection_timeout(0, 200000); // 200 milliseconds - - httplib::Request req; - req.method = "POST"; - req.path = "/models/status"; - req.set_header("Content-Type", "application/json"); - if (!base_params.api_keys.empty()) { - req.set_header("Authorization", "Bearer " + base_params.api_keys[0]); - } - - json body; - body["model"] = name; - body["value"] = 
server_model_status_to_string(SERVER_MODEL_STATUS_LOADED); - req.body = body.dump(); - - SRV_INF("notifying router server (port=%d) that model %s is ready\n", router_port, name.c_str()); - auto result = cli.send(std::move(req)); - if (result.error() != httplib::Error::Success) { - auto err_str = httplib::to_string(result.error()); - SRV_ERR("failed to notify router server: %s\n", err_str.c_str()); - exit(1); // force exit - } - - // setup thread for monitoring stdin - return std::thread([shutdown_handler]() { - // wait for EOF on stdin - SRV_INF("%s", "child server monitoring thread started, waiting for EOF on stdin...\n"); - bool eof = false; - while (true) { - std::string line; - if (!std::getline(std::cin, line)) { - // EOF detected, that means the router server is unexpectedly exit or killed - eof = true; - break; - } - if (line.find(CMD_EXIT) != std::string::npos) { - SRV_INF("%s", "exit command received, exiting...\n"); - shutdown_handler(0); - break; - } - } - if (eof) { - SRV_INF("%s", "EOF on stdin detected, forcing shutdown...\n"); - exit(1); - } - }); -} - - - -// -// server_models_routes -// - -static void res_ok(std::unique_ptr & res, const json & response_data) { - res->status = 200; - res->data = safe_json_to_str(response_data); -} - -static void res_err(std::unique_ptr & res, const json & error_data) { - res->status = json_value(error_data, "code", 500); - res->data = safe_json_to_str({{ "error", error_data }}); -} - -static bool router_validate_model(const std::string & name, server_models & models, bool models_autoload, std::unique_ptr & res) { - if (name.empty()) { - res_err(res, format_error_response("model name is missing from the request", ERROR_TYPE_INVALID_REQUEST)); - return false; - } - auto meta = models.get_meta(name); - if (!meta.has_value()) { - res_err(res, format_error_response("model not found", ERROR_TYPE_INVALID_REQUEST)); - return false; - } - if (models_autoload) { - models.ensure_model_loaded(name); - } else { - if (meta->status != SERVER_MODEL_STATUS_LOADED) { - res_err(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST)); - return false; - } - } - return true; -} - -static bool is_autoload(const common_params & params, const server_http_req & req) { - std::string autoload = req.get_param("autoload"); - if (autoload.empty()) { - return params.models_autoload; - } else { - return autoload == "true" || autoload == "1"; - } -} - -void server_models_routes::init_routes() { - this->get_router_props = [this](const server_http_req & req) { - std::string name = req.get_param("model"); - if (name.empty()) { - // main instance - auto res = std::make_unique(); - res_ok(res, { - // TODO: add support for this on web UI - {"role", "router"}, - {"max_instances", 4}, // dummy value for testing - // this is a dummy response to make sure webui doesn't break - {"model_alias", "llama-server"}, - {"model_path", "none"}, - {"default_generation_settings", { - {"params", json{}}, - {"n_ctx", 0}, - }}, - }); - return res; - } - return proxy_get(req); - }; - - this->proxy_get = [this](const server_http_req & req) { - std::string method = "GET"; - std::string name = req.get_param("model"); - bool autoload = is_autoload(params, req); - auto error_res = std::make_unique(); - if (!router_validate_model(name, models, autoload, error_res)) { - return error_res; - } - return models.proxy_request(req, method, name, false); - }; - - this->proxy_post = [this](const server_http_req & req) { - std::string method = "POST"; - json body = json::parse(req.body); - 
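// POST proxy requests carry the target model inside the JSON request body rather than as a
// query parameter; an illustrative (assumed) payload shape:
//   { "model": "<name>", ... request-specific fields ... }
// only POST requests bump last_used (see the proxy_request call below), so GET traffic does
// not influence LRU-based unloading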
std::string name = json_value(body, "model", std::string()); - bool autoload = is_autoload(params, req); - auto error_res = std::make_unique(); - if (!router_validate_model(name, models, autoload, error_res)) { - return error_res; - } - return models.proxy_request(req, method, name, true); // update last usage for POST request only - }; - - this->post_router_models_load = [this](const server_http_req & req) { - auto res = std::make_unique(); - json body = json::parse(req.body); - std::string name = json_value(body, "model", std::string()); - auto model = models.get_meta(name); - if (!model.has_value()) { - res_err(res, format_error_response("model is not found", ERROR_TYPE_NOT_FOUND)); - return res; - } - if (model->status == SERVER_MODEL_STATUS_LOADED) { - res_err(res, format_error_response("model is already loaded", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - models.load(name); - res_ok(res, {{"success", true}}); - return res; - }; - - // used by child process to notify the router about status change - // TODO @ngxson : maybe implement authentication for this endpoint in the future - this->post_router_models_status = [this](const server_http_req & req) { - auto res = std::make_unique(); - json body = json::parse(req.body); - std::string model = json_value(body, "model", std::string()); - std::string value = json_value(body, "value", std::string()); - models.update_status(model, server_model_status_from_string(value)); - res_ok(res, {{"success", true}}); - return res; - }; - - this->get_router_models = [this](const server_http_req &) { - auto res = std::make_unique(); - json models_json = json::array(); - auto all_models = models.get_all_meta(); - std::time_t t = std::time(0); - for (const auto & meta : all_models) { - json status { - {"value", server_model_status_to_string(meta.status)}, - {"args", meta.args}, - }; - if (!meta.preset.name.empty()) { - status["preset"] = meta.preset.to_ini(); - } - if (meta.is_failed()) { - status["exit_code"] = meta.exit_code; - status["failed"] = true; - } - models_json.push_back(json { - {"id", meta.name}, - {"object", "model"}, // for OAI-compat - {"owned_by", "llamacpp"}, // for OAI-compat - {"created", t}, // for OAI-compat - {"in_cache", meta.in_cache}, - {"path", meta.path}, - {"status", status}, - // TODO: add other fields, may require reading GGUF metadata - }); - } - res_ok(res, { - {"data", models_json}, - {"object", "list"}, - }); - return res; - }; - - this->post_router_models_unload = [this](const server_http_req & req) { - auto res = std::make_unique(); - json body = json::parse(req.body); - std::string name = json_value(body, "model", std::string()); - auto model = models.get_meta(name); - if (!model.has_value()) { - res_err(res, format_error_response("model is not found", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - if (model->status != SERVER_MODEL_STATUS_LOADED) { - res_err(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST)); - return res; - } - models.unload(name); - res_ok(res, {{"success", true}}); - return res; - }; -} - - - -// -// server_http_proxy -// - -// simple implementation of a pipe -// used for streaming data between threads -template -struct pipe_t { - std::mutex mutex; - std::condition_variable cv; - std::queue queue; - std::atomic writer_closed{false}; - std::atomic reader_closed{false}; - void close_write() { - writer_closed.store(true, std::memory_order_relaxed); - cv.notify_all(); - } - void close_read() { - reader_closed.store(true, std::memory_order_relaxed); - 
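// wake any reader blocked in read() so it can observe the updated close flags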
cv.notify_all(); - } - bool read(T & output, const std::function & should_stop) { - std::unique_lock lk(mutex); - constexpr auto poll_interval = std::chrono::milliseconds(500); - while (true) { - if (!queue.empty()) { - output = std::move(queue.front()); - queue.pop(); - return true; - } - if (writer_closed.load()) { - return false; // clean EOF - } - if (should_stop()) { - close_read(); // signal broken pipe to writer - return false; // cancelled / reader no longer alive - } - cv.wait_for(lk, poll_interval); - } - } - bool write(T && data) { - std::lock_guard lk(mutex); - if (reader_closed.load()) { - return false; // broken pipe - } - queue.push(std::move(data)); - cv.notify_one(); - return true; - } -}; - -static std::string to_lower_copy(const std::string & value) { - std::string lowered(value.size(), '\0'); - std::transform(value.begin(), value.end(), lowered.begin(), [](unsigned char c) { return std::tolower(c); }); - return lowered; -} - -static bool should_strip_proxy_header(const std::string & header_name) { - // Headers that get duplicated when router forwards child responses - if (header_name == "server" || - header_name == "transfer-encoding" || - header_name == "content-length" || // quick fix for https://github.com/ggml-org/llama.cpp/issues/17710 - header_name == "keep-alive") { - return true; - } - - // Router injects CORS, child also sends them: duplicate - if (header_name.rfind("access-control-", 0) == 0) { - return true; - } - - return false; -} - -server_http_proxy::server_http_proxy( - const std::string & method, - const std::string & host, - int port, - const std::string & path, - const std::map & headers, - const std::string & body, - const std::function should_stop) { - // shared between reader and writer threads - auto cli = std::make_shared(host, port); - auto pipe = std::make_shared>(); - - // setup Client - cli->set_connection_timeout(0, 200000); // 200 milliseconds - this->status = 500; // to be overwritten upon response - this->cleanup = [pipe]() { - pipe->close_read(); - pipe->close_write(); - }; - - // wire up the receive end of the pipe - this->next = [pipe, should_stop](std::string & out) -> bool { - msg_t msg; - bool has_next = pipe->read(msg, should_stop); - if (!msg.data.empty()) { - out = std::move(msg.data); - } - return has_next; // false if EOF or pipe broken - }; - - // wire up the HTTP client - // note: do NOT capture `this` pointer, as it may be destroyed before the thread ends - httplib::ResponseHandler response_handler = [pipe, cli](const httplib::Response & response) { - msg_t msg; - msg.status = response.status; - for (const auto & [key, value] : response.headers) { - const auto lowered = to_lower_copy(key); - if (should_strip_proxy_header(lowered)) { - continue; - } - if (lowered == "content-type") { - msg.content_type = value; - continue; - } - msg.headers[key] = value; - } - return pipe->write(std::move(msg)); // send headers first - }; - httplib::ContentReceiverWithProgress content_receiver = [pipe](const char * data, size_t data_length, size_t, size_t) { - // send data chunks - // returns false if pipe is closed / broken (signal to stop receiving) - return pipe->write({{}, 0, std::string(data, data_length), ""}); - }; - - // prepare the request to destination server - httplib::Request req; - { - req.method = method; - req.path = path; - for (const auto & [key, value] : headers) { - req.set_header(key, value); - } - req.body = body; - req.response_handler = response_handler; - req.content_receiver = content_receiver; - } - - // start the 
proxy thread - SRV_DBG("start proxy thread %s %s\n", req.method.c_str(), req.path.c_str()); - this->thread = std::thread([cli, pipe, req]() { - auto result = cli->send(std::move(req)); - if (result.error() != httplib::Error::Success) { - auto err_str = httplib::to_string(result.error()); - SRV_ERR("http client error: %s\n", err_str.c_str()); - pipe->write({{}, 500, "", ""}); // header - pipe->write({{}, 0, "proxy error: " + err_str, ""}); // body - } - pipe->close_write(); // signal EOF to reader - SRV_DBG("%s", "client request thread ended\n"); - }); - this->thread.detach(); - - // wait for the first chunk (headers) - { - msg_t header; - if (pipe->read(header, should_stop)) { - SRV_DBG("%s", "received response headers\n"); - this->status = header.status; - this->headers = std::move(header.headers); - if (!header.content_type.empty()) { - this->content_type = std::move(header.content_type); - } - } else { - SRV_DBG("%s", "no response headers received (request cancelled?)\n"); - } - } -} diff --git a/llamacpp/native/src/server/server-models.h b/llamacpp/native/src/server/server-models.h deleted file mode 100644 index 9cdbbad9b..000000000 --- a/llamacpp/native/src/server/server-models.h +++ /dev/null @@ -1,196 +0,0 @@ -#pragma once - -#include "common.h" -#include "preset.h" -#include "server-http.h" - -#include -#include -#include -#include - -/** - * state diagram: - * - * UNLOADED ──► LOADING ──► LOADED - * ▲ │ │ - * └───failed───┘ │ - * ▲ │ - * └────────unloaded─────────┘ - */ -enum server_model_status { - // TODO: also add downloading state when the logic is added - SERVER_MODEL_STATUS_UNLOADED, - SERVER_MODEL_STATUS_LOADING, - SERVER_MODEL_STATUS_LOADED -}; - -static server_model_status server_model_status_from_string(const std::string & status_str) { - if (status_str == "unloaded") { - return SERVER_MODEL_STATUS_UNLOADED; - } - if (status_str == "loading") { - return SERVER_MODEL_STATUS_LOADING; - } - if (status_str == "loaded") { - return SERVER_MODEL_STATUS_LOADED; - } - throw std::runtime_error("invalid server model status"); -} - -static std::string server_model_status_to_string(server_model_status status) { - switch (status) { - case SERVER_MODEL_STATUS_UNLOADED: return "unloaded"; - case SERVER_MODEL_STATUS_LOADING: return "loading"; - case SERVER_MODEL_STATUS_LOADED: return "loaded"; - default: return "unknown"; - } -} - -struct server_model_meta { - common_preset preset; - std::string name; - std::string path; - std::string path_mmproj; // only available if in_cache=false - bool in_cache = false; // if true, use -hf; use -m otherwise - int port = 0; - server_model_status status = SERVER_MODEL_STATUS_UNLOADED; - int64_t last_used = 0; // for LRU unloading - std::vector args; // args passed to the model instance, will be populated by render_args() - int exit_code = 0; // exit code of the model instance process (only valid if status == FAILED) - - bool is_active() const { - return status == SERVER_MODEL_STATUS_LOADED || status == SERVER_MODEL_STATUS_LOADING; - } - - bool is_failed() const { - return status == SERVER_MODEL_STATUS_UNLOADED && exit_code != 0; - } -}; - -// the server_presets struct holds the presets read from presets.ini -// as well as base args from the router server -struct server_presets { - common_presets presets; - common_params_context ctx_params; - std::map base_args; - std::map control_args; // args reserved for server control - - server_presets(int argc, char ** argv, common_params & base_params, const std::string & models_dir); - common_preset 
get_preset(const std::string & name); - void render_args(server_model_meta & meta); -}; - -struct subprocess_s; - -struct server_models { -private: - struct instance_t { - std::shared_ptr subproc; // shared between main thread and monitoring thread - std::thread th; - server_model_meta meta; - FILE * stdin_file = nullptr; - }; - - std::mutex mutex; - std::condition_variable cv; - std::map mapping; - - common_params base_params; - std::vector base_args; - std::vector base_env; - - server_presets presets; - - void update_meta(const std::string & name, const server_model_meta & meta); - - // unload least recently used models if the limit is reached - void unload_lru(); - - // not thread-safe, caller must hold mutex - void add_model(server_model_meta && meta); - -public: - server_models(const common_params & params, int argc, char ** argv, char ** envp); - - void load_models(); - - // check if a model instance exists - bool has_model(const std::string & name); - - // return a copy of model metadata - std::optional get_meta(const std::string & name); - - // return a copy of all model metadata - std::vector get_all_meta(); - - void load(const std::string & name); - void unload(const std::string & name); - void unload_all(); - - // update the status of a model instance - void update_status(const std::string & name, server_model_status status); - - // wait until the model instance is fully loaded - // return when the model is loaded or failed to load - void wait_until_loaded(const std::string & name); - - // load the model if not loaded, otherwise do nothing - // return false if model is already loaded; return true otherwise (meta may need to be refreshed) - bool ensure_model_loaded(const std::string & name); - - // proxy an HTTP request to the model instance - server_http_res_ptr proxy_request(const server_http_req & req, const std::string & method, const std::string & name, bool update_last_used); - - // notify the router server that a model instance is ready - // return the monitoring thread (to be joined by the caller) - static std::thread setup_child_server(const common_params & base_params, int router_port, const std::string & name, std::function & shutdown_handler); -}; - -struct server_models_routes { - common_params params; - server_models models; - server_models_routes(const common_params & params, int argc, char ** argv, char ** envp) - : params(params), models(params, argc, argv, envp) { - init_routes(); - } - - void init_routes(); - // handlers using lambda function, so that they can capture `this` without `std::bind` - server_http_context::handler_t get_router_props; - server_http_context::handler_t proxy_get; - server_http_context::handler_t proxy_post; - server_http_context::handler_t get_router_models; - server_http_context::handler_t post_router_models_load; - server_http_context::handler_t post_router_models_status; - server_http_context::handler_t post_router_models_unload; -}; - -/** - * A simple HTTP proxy that forwards requests to another server - * and relays the responses back. 
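 * The constructor issues the upstream request on a detached thread and blocks only until the
 * response headers arrive; body chunks are then pulled on demand through the next() callback,
 * and the destructor closes the internal pipe so the client thread stops forwarding data.
 * Typical use (sketch): construct it with the incoming request's method/path/headers/body,
 * then repeatedly call next(out) until it returns false, writing each chunk downstream.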
- */ -struct server_http_proxy : server_http_res { - std::function cleanup = nullptr; -public: - server_http_proxy(const std::string & method, - const std::string & host, - int port, - const std::string & path, - const std::map & headers, - const std::string & body, - const std::function should_stop); - ~server_http_proxy() { - if (cleanup) { - cleanup(); - } - } -private: - std::thread thread; - struct msg_t { - std::map headers; - int status = 0; - std::string data; - std::string content_type; - }; -}; diff --git a/llamacpp/native/src/server/server-queue.cpp b/llamacpp/native/src/server/server-queue.cpp deleted file mode 100644 index 3cceb2bbe..000000000 --- a/llamacpp/native/src/server/server-queue.cpp +++ /dev/null @@ -1,370 +0,0 @@ -#include "server-task.h" -#include "server-queue.h" - -#include "log.h" - -#include - -#define QUE_INF(fmt, ...) LOG_INF("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) -#define QUE_WRN(fmt, ...) LOG_WRN("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) -#define QUE_ERR(fmt, ...) LOG_ERR("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) -#define QUE_DBG(fmt, ...) LOG_DBG("que %12.*s: " fmt, 12, __func__, __VA_ARGS__) - -#define RES_INF(fmt, ...) LOG_INF("res %12.*s: " fmt, 12, __func__, __VA_ARGS__) -#define RES_WRN(fmt, ...) LOG_WRN("res %12.*s: " fmt, 12, __func__, __VA_ARGS__) -#define RES_ERR(fmt, ...) LOG_ERR("res %12.*s: " fmt, 12, __func__, __VA_ARGS__) -#define RES_DBG(fmt, ...) LOG_DBG("res %12.*s: " fmt, 12, __func__, __VA_ARGS__) - -// -// server_queue -// - -int server_queue::post(server_task && task, bool front) { - std::unique_lock lock(mutex_tasks); - GGML_ASSERT(task.id != -1); - // if this is cancel task make sure to clean up pending tasks - if (task.type == SERVER_TASK_TYPE_CANCEL) { - cleanup_pending_task(task.id_target); - } - const int task_id = task.id; - QUE_DBG("new task, id = %d, front = %d\n", task_id, front); - if (front) { - queue_tasks.push_front(std::move(task)); - } else { - queue_tasks.push_back(std::move(task)); - } - condition_tasks.notify_one(); - return task_id; -} - -int server_queue::post(std::vector && tasks, bool front) { - std::unique_lock lock(mutex_tasks); - for (auto & task : tasks) { - if (task.id == -1) { - task.id = id++; - } - // if this is cancel task make sure to clean up pending tasks - if (task.type == SERVER_TASK_TYPE_CANCEL) { - cleanup_pending_task(task.id_target); - } - QUE_DBG("new task, id = %d/%d, front = %d\n", task.id, (int) tasks.size(), front); - if (front) { - queue_tasks.push_front(std::move(task)); - } else { - queue_tasks.push_back(std::move(task)); - } - } - condition_tasks.notify_one(); - return 0; -} - -void server_queue::defer(server_task && task) { - std::unique_lock lock(mutex_tasks); - QUE_DBG("defer task, id = %d\n", task.id); - queue_tasks_deferred.push_back(std::move(task)); - condition_tasks.notify_one(); -} - -int server_queue::get_new_id() { - std::unique_lock lock(mutex_tasks); - int new_id = id++; - return new_id; -} - -void server_queue::on_new_task(std::function callback) { - callback_new_task = std::move(callback); -} - -void server_queue::on_update_slots(std::function callback) { - callback_update_slots = std::move(callback); -} - -void server_queue::pop_deferred_task() { - std::unique_lock lock(mutex_tasks); - if (!queue_tasks_deferred.empty()) { - queue_tasks.emplace_front(std::move(queue_tasks_deferred.front())); - queue_tasks_deferred.pop_front(); - } - condition_tasks.notify_one(); -} - -void server_queue::terminate() { - std::unique_lock lock(mutex_tasks); - running = 
false; - condition_tasks.notify_all(); -} - -void server_queue::start_loop() { - running = true; - - while (true) { - QUE_DBG("%s", "processing new tasks\n"); - - while (true) { - std::unique_lock lock(mutex_tasks); - if (!running) { - QUE_DBG("%s", "terminate\n"); - return; - } - if (queue_tasks.empty()) { - lock.unlock(); - break; - } - server_task task = std::move(queue_tasks.front()); - queue_tasks.pop_front(); - lock.unlock(); - - QUE_DBG("processing task, id = %d\n", task.id); - callback_new_task(std::move(task)); - } - - // all tasks in the current loop is processed, slots data is now ready - QUE_DBG("%s", "update slots\n"); - - callback_update_slots(); - - QUE_DBG("%s", "waiting for new tasks\n"); - { - std::unique_lock lock(mutex_tasks); - if (!running) { - QUE_DBG("%s", "terminate\n"); - return; - } - if (queue_tasks.empty()) { - condition_tasks.wait(lock, [&]{ - return (!queue_tasks.empty() || !running); - }); - } - } - } -} - -void server_queue::cleanup_pending_task(int id_target) { - // no need lock because this is called exclusively by post() - auto rm_func = [id_target](const server_task & task) { - return task.id == id_target; - }; - queue_tasks.erase( - std::remove_if(queue_tasks.begin(), queue_tasks.end(), rm_func), - queue_tasks.end()); - queue_tasks_deferred.erase( - std::remove_if(queue_tasks_deferred.begin(), queue_tasks_deferred.end(), rm_func), - queue_tasks_deferred.end()); -} - -// -// server_response -// - -void server_response::add_waiting_task_id(int id_task) { - RES_DBG("add task %d to waiting list. current waiting = %d (before add)\n", id_task, (int) waiting_task_ids.size()); - - std::unique_lock lock(mutex_results); - waiting_task_ids.insert(id_task); -} - -void server_response::add_waiting_tasks(const std::vector & tasks) { - std::unique_lock lock(mutex_results); - - for (const auto & task : tasks) { - RES_DBG("add task %d to waiting list. current waiting = %d (before add)\n", task.id, (int) waiting_task_ids.size()); - waiting_task_ids.insert(task.id); - } -} - -void server_response::remove_waiting_task_id(int id_task) { - RES_DBG("remove task %d from waiting list. current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size()); - - std::unique_lock lock(mutex_results); - waiting_task_ids.erase(id_task); - // make sure to clean up all pending results - queue_results.erase( - std::remove_if(queue_results.begin(), queue_results.end(), [id_task](const server_task_result_ptr & res) { - return res->id == id_task; - }), - queue_results.end()); -} - -void server_response::remove_waiting_task_ids(const std::unordered_set & id_tasks) { - std::unique_lock lock(mutex_results); - - for (const auto & id_task : id_tasks) { - RES_DBG("remove task %d from waiting list. 
current waiting = %d (before remove)\n", id_task, (int) waiting_task_ids.size()); - waiting_task_ids.erase(id_task); - } -} - -server_task_result_ptr server_response::recv(const std::unordered_set & id_tasks) { - while (true) { - std::unique_lock lock(mutex_results); - condition_results.wait(lock, [&]{ - if (!running) { - RES_DBG("%s : queue result stop\n", "recv"); - std::terminate(); // we cannot return here since the caller is HTTP code - } - return !queue_results.empty(); - }); - - for (size_t i = 0; i < queue_results.size(); i++) { - if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) { - server_task_result_ptr res = std::move(queue_results[i]); - queue_results.erase(queue_results.begin() + i); - return res; - } - } - } - - // should never reach here -} - -server_task_result_ptr server_response::recv_with_timeout(const std::unordered_set & id_tasks, int timeout) { - while (true) { - std::unique_lock lock(mutex_results); - - for (int i = 0; i < (int) queue_results.size(); i++) { - if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) { - server_task_result_ptr res = std::move(queue_results[i]); - queue_results.erase(queue_results.begin() + i); - return res; - } - } - - std::cv_status cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout)); - if (!running) { - RES_DBG("%s : queue result stop\n", __func__); - std::terminate(); // we cannot return here since the caller is HTTP code - } - if (cr_res == std::cv_status::timeout) { - return nullptr; - } - } - - // should never reach here -} - -server_task_result_ptr server_response::recv(int id_task) { - std::unordered_set id_tasks = {id_task}; - return recv(id_tasks); -} - -void server_response::send(server_task_result_ptr && result) { - RES_DBG("sending result for task id = %d\n", result->id); - - std::unique_lock lock(mutex_results); - for (const auto & id_task : waiting_task_ids) { - if (result->id == id_task) { - RES_DBG("task id = %d pushed to result queue\n", result->id); - - queue_results.emplace_back(std::move(result)); - condition_results.notify_all(); - return; - } - } -} - -void server_response::terminate() { - running = false; - condition_results.notify_all(); -} - -// -// server_response_reader -// - -void server_response_reader::post_task(server_task && task) { - GGML_ASSERT(id_tasks.empty() && "post_task() can only be called once per reader"); - id_tasks.insert(task.id); - states.push_back(task.create_state()); - queue_results.add_waiting_task_id(task.id); - queue_tasks.post(std::move(task)); -} - -void server_response_reader::post_tasks(std::vector && tasks) { - GGML_ASSERT(id_tasks.empty() && "post_tasks() can only be called once per reader"); - id_tasks = server_task::get_list_id(tasks); - states.reserve(tasks.size()); - for (size_t i = 0; i < tasks.size(); i++) { - states.push_back(tasks[i].create_state()); - } - queue_results.add_waiting_tasks(tasks); - queue_tasks.post(std::move(tasks)); -} - -bool server_response_reader::has_next() const { - return !cancelled && received_count < id_tasks.size(); -} - -// return nullptr if should_stop() is true before receiving a result -// note: if one error is received, it will stop further processing and return error result -server_task_result_ptr server_response_reader::next(const std::function & should_stop) { - while (true) { - server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, polling_interval_seconds); - if (result == nullptr) { - // timeout, check stop condition - if (should_stop()) { - SRV_DBG("%s", "stopping wait for next 
result due to should_stop condition\n"); - return nullptr; - } - } else { - if (result->is_error()) { - stop(); // cancel remaining tasks - SRV_DBG("%s", "received error result, stopping further processing\n"); - return result; - } - if (!states.empty()) { - // update the generation state if needed - size_t idx = result->get_index(); - GGML_ASSERT(idx < states.size()); - result->update(states[idx]); - } - if (result->is_stop()) { - received_count++; - } - return result; - } - } - - // should not reach here -} - -server_response_reader::batch_response server_response_reader::wait_for_all(const std::function & should_stop) { - batch_response batch_res; - batch_res.results.resize(id_tasks.size()); - while (has_next()) { - auto res = next(should_stop); - if (res == nullptr) { - batch_res.is_terminated = true; - return batch_res; - } - if (res->is_error()) { - batch_res.error = std::move(res); - return batch_res; - } - const size_t idx = res->get_index(); - GGML_ASSERT(idx < batch_res.results.size() && "index out of range"); - GGML_ASSERT(batch_res.results[idx] == nullptr && "duplicate result received"); - batch_res.results[idx] = std::move(res); - } - return batch_res; -} - -void server_response_reader::stop() { - queue_results.remove_waiting_task_ids(id_tasks); - if (has_next() && !cancelled) { - // if tasks is not finished yet, cancel them - cancelled = true; - std::vector cancel_tasks; - cancel_tasks.reserve(id_tasks.size()); - for (const auto & id_task : id_tasks) { - SRV_WRN("cancel task, id_task = %d\n", id_task); - server_task task(SERVER_TASK_TYPE_CANCEL); - task.id_target = id_task; - queue_results.remove_waiting_task_id(id_task); - cancel_tasks.push_back(std::move(task)); - } - // push to beginning of the queue, so it has highest priority - queue_tasks.post(std::move(cancel_tasks), true); - } else { - SRV_DBG("%s", "all tasks already finished, no need to cancel\n"); - } -} diff --git a/llamacpp/native/src/server/server-queue.h b/llamacpp/native/src/server/server-queue.h deleted file mode 100644 index 8780d7fe1..000000000 --- a/llamacpp/native/src/server/server-queue.h +++ /dev/null @@ -1,158 +0,0 @@ -#pragma once - -#include "server-task.h" - -#include -#include -#include -#include - -// struct for managing server tasks -// in most cases, use server_response_reader to post new tasks and retrieve results -struct server_queue { -private: - int id = 0; - bool running; - - // queues - std::deque queue_tasks; - std::deque queue_tasks_deferred; - - std::mutex mutex_tasks; - std::condition_variable condition_tasks; - - // callback functions - std::function callback_new_task; - std::function callback_update_slots; - -public: - // Add a new task to the end of the queue - int post(server_task && task, bool front = false); - - // multi-task version of post() - int post(std::vector && tasks, bool front = false); - - // Add a new task, but defer until one slot is available - void defer(server_task && task); - - // Get the next id for creating a new task - int get_new_id(); - - // Register function to process a new task - void on_new_task(std::function callback); - - // Register the function to be called when all slots data is ready to be processed - void on_update_slots(std::function callback); - - // Call when the state of one slot is changed, it will move one task from deferred to main queue - void pop_deferred_task(); - - // end the start_loop routine - void terminate(); - - /** - * Main loop consists of these steps: - * - Wait until a new task arrives - * - Process the task (i.e. 
maybe copy data into slot) - * - Check if multitask is finished - * - Update all slots - */ - void start_loop(); - - // for metrics - size_t queue_tasks_deferred_size() { - std::unique_lock lock(mutex_tasks); - return queue_tasks_deferred.size(); - } - -private: - void cleanup_pending_task(int id_target); -}; - -// struct for managing server responses -// in most cases, use server_response_reader to retrieve results -struct server_response { -private: - bool running = true; - - // for keeping track of all tasks waiting for the result - std::unordered_set waiting_task_ids; - - // the main result queue (using ptr for polymorphism) - std::vector queue_results; - - std::mutex mutex_results; - std::condition_variable condition_results; - -public: - // add the id_task to the list of tasks waiting for response - void add_waiting_task_id(int id_task); - - void add_waiting_tasks(const std::vector & tasks); - - // when the request is finished, we can remove task associated with it - void remove_waiting_task_id(int id_task); - - // remove multiple tasks from waiting list - void remove_waiting_task_ids(const std::unordered_set & id_tasks); - - // This function blocks the thread until there is a response for one of the id_tasks - server_task_result_ptr recv(const std::unordered_set & id_tasks); - - // same as recv(), but have timeout in seconds - // if timeout is reached, nullptr is returned - server_task_result_ptr recv_with_timeout(const std::unordered_set & id_tasks, int timeout); - - // single-task version of recv() - server_task_result_ptr recv(int id_task); - - // Send a new result to a waiting id_task - void send(server_task_result_ptr && result); - - // terminate the waiting loop - void terminate(); -}; - -// utility class to make working with server_queue and server_response easier -// it provides a generator-like API for server responses -// support pooling connection state and aggregating multiple results -struct server_response_reader { - std::unordered_set id_tasks; - server_queue & queue_tasks; - server_response & queue_results; - size_t received_count = 0; - bool cancelled = false; - int polling_interval_seconds; - - // tracking generation state and partial tool calls - // only used by streaming completions - std::vector states; - - // should_stop function will be called each polling_interval_seconds - server_response_reader(server_queue & queue_tasks, server_response & queue_results, int polling_interval_seconds) - : queue_tasks(queue_tasks), queue_results(queue_results), polling_interval_seconds(polling_interval_seconds) {} - ~server_response_reader() { - stop(); - } - - int get_new_id() { - return queue_tasks.get_new_id(); - } - void post_task(server_task && task); - void post_tasks(std::vector && tasks); - bool has_next() const; - - // return nullptr if should_stop() is true before receiving a result - // note: if one error is received, it will stop further processing and return error result - server_task_result_ptr next(const std::function & should_stop); - - struct batch_response { - bool is_terminated = false; // if true, indicates that processing was stopped before all results were received - std::vector results; - server_task_result_ptr error; // nullptr if no error - }; - // aggregate multiple results - batch_response wait_for_all(const std::function & should_stop); - - void stop(); -}; diff --git a/llamacpp/native/src/server/server-task.cpp b/llamacpp/native/src/server/server-task.cpp deleted file mode 100644 index 360826062..000000000 --- 
a/llamacpp/native/src/server/server-task.cpp +++ /dev/null @@ -1,1502 +0,0 @@ -#include "server-common.h" -#include "server-task.h" - -#include "common.h" -#include "llama.h" -#include "chat.h" -#include "sampling.h" -#include "json-schema-to-grammar.h" - -using json = nlohmann::ordered_json; - -// -// task_params -// - -json task_params::format_logit_bias(const std::vector & logit_bias) const { - json data = json::array(); - for (const auto & lb : logit_bias) { - data.push_back(json{ - {"bias", lb.bias}, - {"token", lb.token}, - }); - } - return data; -} - -json task_params::to_json(bool only_metrics) const { - std::vector samplers; - samplers.reserve(sampling.samplers.size()); - for (const auto & sampler : sampling.samplers) { - samplers.emplace_back(common_sampler_type_to_str(sampler)); - } - - json lora = json::array(); - for (size_t i = 0; i < this->lora.size(); ++i) { - lora.push_back({{"id", i}, {"scale", this->lora[i].scale}}); - } - - if (only_metrics) { - return json { - {"seed", sampling.seed}, - {"temperature", sampling.temp}, - {"dynatemp_range", sampling.dynatemp_range}, - {"dynatemp_exponent", sampling.dynatemp_exponent}, - {"top_k", sampling.top_k}, - {"top_p", sampling.top_p}, - {"min_p", sampling.min_p}, - {"top_n_sigma", sampling.top_n_sigma}, - {"xtc_probability", sampling.xtc_probability}, - {"xtc_threshold", sampling.xtc_threshold}, - {"typical_p", sampling.typ_p}, - {"repeat_last_n", sampling.penalty_last_n}, - {"repeat_penalty", sampling.penalty_repeat}, - {"presence_penalty", sampling.penalty_present}, - {"frequency_penalty", sampling.penalty_freq}, - {"dry_multiplier", sampling.dry_multiplier}, - {"dry_base", sampling.dry_base}, - {"dry_allowed_length", sampling.dry_allowed_length}, - {"dry_penalty_last_n", sampling.dry_penalty_last_n}, - {"mirostat", sampling.mirostat}, - {"mirostat_tau", sampling.mirostat_tau}, - {"mirostat_eta", sampling.mirostat_eta}, - {"max_tokens", n_predict}, - {"n_predict", n_predict}, // TODO: deduplicate? 
- {"n_keep", n_keep}, - {"n_discard", n_discard}, - {"ignore_eos", sampling.ignore_eos}, - {"stream", stream}, - {"n_probs", sampling.n_probs}, - {"min_keep", sampling.min_keep}, - {"chat_format", common_chat_format_name(oaicompat_chat_syntax.format)}, - {"reasoning_format", common_reasoning_format_name(oaicompat_chat_syntax.reasoning_format)}, - {"reasoning_in_content", oaicompat_chat_syntax.reasoning_in_content}, - {"thinking_forced_open", oaicompat_chat_syntax.thinking_forced_open}, - {"samplers", samplers}, - {"speculative.n_max", speculative.n_max}, - {"speculative.n_min", speculative.n_min}, - {"speculative.p_min", speculative.p_min}, - {"timings_per_token", timings_per_token}, - {"post_sampling_probs", post_sampling_probs}, - {"lora", lora}, - }; - } - - auto grammar_triggers = json::array(); - for (const auto & trigger : sampling.grammar_triggers) { - server_grammar_trigger ct(trigger); - grammar_triggers.push_back(ct.to_json()); - } - - return json { - {"seed", sampling.seed}, - {"temperature", sampling.temp}, - {"dynatemp_range", sampling.dynatemp_range}, - {"dynatemp_exponent", sampling.dynatemp_exponent}, - {"top_k", sampling.top_k}, - {"top_p", sampling.top_p}, - {"min_p", sampling.min_p}, - {"top_n_sigma", sampling.top_n_sigma}, - {"xtc_probability", sampling.xtc_probability}, - {"xtc_threshold", sampling.xtc_threshold}, - {"typical_p", sampling.typ_p}, - {"repeat_last_n", sampling.penalty_last_n}, - {"repeat_penalty", sampling.penalty_repeat}, - {"presence_penalty", sampling.penalty_present}, - {"frequency_penalty", sampling.penalty_freq}, - {"dry_multiplier", sampling.dry_multiplier}, - {"dry_base", sampling.dry_base}, - {"dry_allowed_length", sampling.dry_allowed_length}, - {"dry_penalty_last_n", sampling.dry_penalty_last_n}, - {"dry_sequence_breakers", sampling.dry_sequence_breakers}, - {"mirostat", sampling.mirostat}, - {"mirostat_tau", sampling.mirostat_tau}, - {"mirostat_eta", sampling.mirostat_eta}, - {"stop", antiprompt}, - {"max_tokens", n_predict}, - {"n_predict", n_predict}, // TODO: deduplicate? 
- {"n_keep", n_keep}, - {"n_discard", n_discard}, - {"ignore_eos", sampling.ignore_eos}, - {"stream", stream}, - {"logit_bias", format_logit_bias(sampling.logit_bias)}, - {"n_probs", sampling.n_probs}, - {"min_keep", sampling.min_keep}, - {"grammar", sampling.grammar}, - {"grammar_lazy", sampling.grammar_lazy}, - {"grammar_triggers", grammar_triggers}, - {"preserved_tokens", sampling.preserved_tokens}, - {"chat_format", common_chat_format_name(oaicompat_chat_syntax.format)}, - {"reasoning_format", common_reasoning_format_name(oaicompat_chat_syntax.reasoning_format)}, - {"reasoning_in_content", oaicompat_chat_syntax.reasoning_in_content}, - {"thinking_forced_open", oaicompat_chat_syntax.thinking_forced_open}, - {"samplers", samplers}, - {"speculative.n_max", speculative.n_max}, - {"speculative.n_min", speculative.n_min}, - {"speculative.p_min", speculative.p_min}, - {"timings_per_token", timings_per_token}, - {"post_sampling_probs", post_sampling_probs}, - {"lora", lora}, - }; -} - -// -// server_task -// - -task_params server_task::params_from_json_cmpl( - const llama_context * ctx, - const common_params & params_base, - const json & data) { - const llama_model * model = llama_get_model(ctx); - const llama_vocab * vocab = llama_model_get_vocab(model); - - task_params params; - - // Sampling parameter defaults are loaded from the global server context (but individual requests can still them) - task_params defaults; - defaults.sampling = params_base.sampling; - defaults.speculative = params_base.speculative; - defaults.n_keep = params_base.n_keep; - defaults.n_predict = params_base.n_predict; - defaults.n_cache_reuse = params_base.n_cache_reuse; - defaults.antiprompt = params_base.antiprompt; - - // enabling this will output extra debug information in the HTTP responses from the server - params.verbose = params_base.verbosity > 9; - params.timings_per_token = json_value(data, "timings_per_token", false); - - params.stream = json_value(data, "stream", false); - auto stream_opt = json_value(data, "stream_options", json::object()); - params.include_usage = json_value(stream_opt, "include_usage", false); - params.cache_prompt = json_value(data, "cache_prompt", true); - params.return_tokens = json_value(data, "return_tokens", false); - params.return_progress = json_value(data, "return_progress", false); - params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict)); - params.n_indent = json_value(data, "n_indent", defaults.n_indent); - params.n_keep = json_value(data, "n_keep", defaults.n_keep); - params.n_discard = json_value(data, "n_discard", defaults.n_discard); - params.n_cmpl = json_value(data, "n_cmpl", json_value(data, "n", 1)); - params.n_cache_reuse = json_value(data, "n_cache_reuse", defaults.n_cache_reuse); - //params.t_max_prompt_ms = json_value(data, "t_max_prompt_ms", defaults.t_max_prompt_ms); // TODO: implement - params.t_max_predict_ms = json_value(data, "t_max_predict_ms", defaults.t_max_predict_ms); - params.response_fields = json_value(data, "response_fields", std::vector()); - - params.sampling.top_k = json_value(data, "top_k", defaults.sampling.top_k); - params.sampling.top_p = json_value(data, "top_p", defaults.sampling.top_p); - params.sampling.min_p = json_value(data, "min_p", defaults.sampling.min_p); - params.sampling.top_n_sigma = json_value(data, "top_n_sigma", defaults.sampling.top_n_sigma); - params.sampling.xtc_probability = json_value(data, "xtc_probability", defaults.sampling.xtc_probability); - 
params.sampling.xtc_threshold = json_value(data, "xtc_threshold", defaults.sampling.xtc_threshold); - params.sampling.typ_p = json_value(data, "typical_p", defaults.sampling.typ_p); - params.sampling.temp = json_value(data, "temperature", defaults.sampling.temp); - params.sampling.dynatemp_range = json_value(data, "dynatemp_range", defaults.sampling.dynatemp_range); - params.sampling.dynatemp_exponent = json_value(data, "dynatemp_exponent", defaults.sampling.dynatemp_exponent); - params.sampling.penalty_last_n = json_value(data, "repeat_last_n", defaults.sampling.penalty_last_n); - params.sampling.penalty_repeat = json_value(data, "repeat_penalty", defaults.sampling.penalty_repeat); - params.sampling.penalty_freq = json_value(data, "frequency_penalty", defaults.sampling.penalty_freq); - params.sampling.penalty_present = json_value(data, "presence_penalty", defaults.sampling.penalty_present); - params.sampling.dry_multiplier = json_value(data, "dry_multiplier", defaults.sampling.dry_multiplier); - params.sampling.dry_base = json_value(data, "dry_base", defaults.sampling.dry_base); - params.sampling.dry_allowed_length = json_value(data, "dry_allowed_length", defaults.sampling.dry_allowed_length); - params.sampling.dry_penalty_last_n = json_value(data, "dry_penalty_last_n", defaults.sampling.dry_penalty_last_n); - params.sampling.mirostat = json_value(data, "mirostat", defaults.sampling.mirostat); - params.sampling.mirostat_tau = json_value(data, "mirostat_tau", defaults.sampling.mirostat_tau); - params.sampling.mirostat_eta = json_value(data, "mirostat_eta", defaults.sampling.mirostat_eta); - params.sampling.seed = json_value(data, "seed", defaults.sampling.seed); - params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs); - params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep); - params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs); - - params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min); - params.speculative.n_max = json_value(data, "speculative.n_max", defaults.speculative.n_max); - params.speculative.p_min = json_value(data, "speculative.p_min", defaults.speculative.p_min); - - params.speculative.n_min = std::min(params.speculative.n_max, params.speculative.n_min); - params.speculative.n_min = std::max(params.speculative.n_min, 0); - params.speculative.n_max = std::max(params.speculative.n_max, 0); - - // Use OpenAI API logprobs only if n_probs wasn't provided - if (data.contains("logprobs") && params.sampling.n_probs == defaults.sampling.n_probs){ - params.sampling.n_probs = json_value(data, "logprobs", defaults.sampling.n_probs); - } - - if (data.contains("lora")) { - if (data.at("lora").is_array()) { - params.lora = parse_lora_request(params_base.lora_adapters, data.at("lora")); - } else { - throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields"); - } - } else { - params.lora = params_base.lora_adapters; - } - - // TODO: add more sanity checks for the input parameters - - if (params.sampling.penalty_last_n < -1) { - throw std::runtime_error("Error: repeat_last_n must be >= -1"); - } - - if (params.sampling.dry_penalty_last_n < -1) { - throw std::runtime_error("Error: dry_penalty_last_n must be >= -1"); - } - - if (params.sampling.penalty_last_n == -1) { - // note: should be the slot's context and not the full context, but it's ok - params.sampling.penalty_last_n = llama_n_ctx(ctx); - } - - if 
(params.sampling.dry_penalty_last_n == -1) { - params.sampling.dry_penalty_last_n = llama_n_ctx(ctx); - } - - if (params.sampling.dry_base < 1.0f) { - params.sampling.dry_base = defaults.sampling.dry_base; - } - - // sequence breakers for DRY - { - // Currently, this is not compatible with TextGen WebUI, Koboldcpp and SillyTavern format - // Ref: https://github.com/oobabooga/text-generation-webui/blob/d1af7a41ade7bd3c3a463bfa640725edb818ebaf/extensions/openai/typing.py#L39 - - if (data.contains("dry_sequence_breakers")) { - params.sampling.dry_sequence_breakers = json_value(data, "dry_sequence_breakers", std::vector()); - if (params.sampling.dry_sequence_breakers.empty()) { - throw std::runtime_error("Error: dry_sequence_breakers must be a non-empty array of strings"); - } - } - } - - // process "json_schema" and "grammar" - if (data.contains("json_schema") && !data.contains("grammar")) { - try { - auto schema = json_value(data, "json_schema", json::object()); - SRV_DBG("JSON schema: %s\n", schema.dump(2).c_str()); - params.sampling.grammar = json_schema_to_grammar(schema); - SRV_DBG("Converted grammar: %s\n", params.sampling.grammar.c_str()); - } catch (const std::exception & e) { - throw std::runtime_error(std::string("\"json_schema\": ") + e.what()); - } - } else { - params.sampling.grammar = json_value(data, "grammar", defaults.sampling.grammar); - SRV_DBG("Grammar: %s\n", params.sampling.grammar.c_str()); - params.sampling.grammar_lazy = json_value(data, "grammar_lazy", defaults.sampling.grammar_lazy); - SRV_DBG("Grammar lazy: %s\n", params.sampling.grammar_lazy ? "true" : "false"); - } - - { - auto it = data.find("chat_format"); - if (it != data.end()) { - params.oaicompat_chat_syntax.format = static_cast(it->get()); - SRV_INF("Chat format: %s\n", common_chat_format_name(params.oaicompat_chat_syntax.format)); - } else { - params.oaicompat_chat_syntax.format = defaults.oaicompat_chat_syntax.format; - } - common_reasoning_format reasoning_format = params_base.reasoning_format; - if (data.contains("reasoning_format")) { - reasoning_format = common_reasoning_format_from_name(data.at("reasoning_format").get()); - } - params.oaicompat_chat_syntax.reasoning_format = reasoning_format; - params.oaicompat_chat_syntax.reasoning_in_content = params.stream && (reasoning_format == COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY); - params.oaicompat_chat_syntax.thinking_forced_open = json_value(data, "thinking_forced_open", false); - params.oaicompat_chat_syntax.parse_tool_calls = json_value(data, "parse_tool_calls", false); - if (data.contains("chat_parser")) { - params.oaicompat_chat_syntax.parser.load(data.at("chat_parser").get()); - } - } - - { - const auto preserved_tokens = data.find("preserved_tokens"); - if (preserved_tokens != data.end()) { - for (const auto & t : *preserved_tokens) { - auto ids = common_tokenize(vocab, t.get(), /* add_special= */ false, /* parse_special= */ true); - if (ids.size() == 1) { - SRV_DBG("Preserved token: %d\n", ids[0]); - params.sampling.preserved_tokens.insert(ids[0]); - } else { - // This may happen when using a tool call style meant for a model with special tokens to preserve on a model without said tokens. 
- SRV_DBG("Not preserved because more than 1 token: %s\n", t.get().c_str()); - } - } - } - const auto grammar_triggers = data.find("grammar_triggers"); - if (grammar_triggers != data.end()) { - for (const auto & t : *grammar_triggers) { - server_grammar_trigger ct(t); - if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_WORD) { - const auto & word = ct.value.value; - auto ids = common_tokenize(vocab, word, /* add_special= */ false, /* parse_special= */ true); - if (ids.size() == 1) { - auto token = ids[0]; - if (std::find(params.sampling.preserved_tokens.begin(), params.sampling.preserved_tokens.end(), (llama_token) token) == params.sampling.preserved_tokens.end()) { - throw std::runtime_error("Grammar trigger word should be marked as preserved token: " + word); - } - SRV_DBG("Grammar trigger token: %d (`%s`)\n", token, word.c_str()); - common_grammar_trigger trigger; - trigger.type = COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN; - trigger.value = word; - trigger.token = token; - params.sampling.grammar_triggers.push_back(std::move(trigger)); - } else { - SRV_DBG("Grammar trigger word: `%s`\n", word.c_str()); - params.sampling.grammar_triggers.push_back({COMMON_GRAMMAR_TRIGGER_TYPE_WORD, word}); - } - } else { - if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN) { - SRV_DBG("Grammar trigger pattern: `%s`\n", ct.value.value.c_str()); - } else if (ct.value.type == COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL) { - SRV_DBG("Grammar trigger pattern full: `%s`\n", ct.value.value.c_str()); - } else { - throw std::runtime_error("Unknown grammar trigger type"); - } - params.sampling.grammar_triggers.emplace_back(std::move(ct.value)); - } - } - } - if (params.sampling.grammar_lazy && params.sampling.grammar_triggers.empty()) { - throw std::runtime_error("Error: no triggers set for lazy grammar!"); - } - } - - { - params.sampling.logit_bias.clear(); - - const auto & logit_bias = data.find("logit_bias"); - if (logit_bias != data.end() && logit_bias->is_array()) { - const int n_vocab = llama_vocab_n_tokens(vocab); - for (const auto & el : *logit_bias) { - // TODO: we may want to throw errors here, in case "el" is incorrect - if (el.is_array() && el.size() == 2) { - float bias; - if (el[1].is_number()) { - bias = el[1].get(); - } else if (el[1].is_boolean() && !el[1].get()) { - bias = -INFINITY; - } else { - continue; - } - - if (el[0].is_number_integer()) { - llama_token tok = el[0].get(); - if (tok >= 0 && tok < n_vocab) { - params.sampling.logit_bias.push_back({tok, bias}); - } - } else if (el[0].is_string()) { - auto toks = common_tokenize(vocab, el[0].get(), false); - for (auto tok : toks) { - params.sampling.logit_bias.push_back({tok, bias}); - } - } - } - } - } else if (logit_bias != data.end() && logit_bias->is_object()) { - const int n_vocab = llama_vocab_n_tokens(vocab); - for (const auto & el : logit_bias->items()) { - float bias; - const auto & key = el.key(); - const auto & value = el.value(); - if (value.is_number()) { - bias = value.get(); - } else if (value.is_boolean() && !value.get()) { - bias = -INFINITY; - } else { - continue; - } - - char *end; - llama_token tok = strtol(key.c_str(), &end, 10); - if (*end == 0) { - if (tok >= 0 && tok < n_vocab) { - params.sampling.logit_bias.push_back({tok, bias}); - } - } else { - auto toks = common_tokenize(vocab, key, false); - for (auto tok : toks) { - params.sampling.logit_bias.push_back({tok, bias}); - } - } - } - } - - params.sampling.ignore_eos = json_value(data, "ignore_eos", params_base.sampling.ignore_eos); - if (params.sampling.ignore_eos) { - 
params.sampling.logit_bias.insert( - params.sampling.logit_bias.end(), - defaults.sampling.logit_bias_eog.begin(), defaults.sampling.logit_bias_eog.end()); - } - } - - { - params.antiprompt.clear(); - - const auto & stop = data.find("stop"); - if (stop != data.end() && stop->is_array()) { - for (const auto & word : *stop) { - if (!word.empty()) { - params.antiprompt.push_back(word); - } - } - } - // set reverse prompt from cli args if not set in the request - if (params.antiprompt.empty()) { - params.antiprompt = defaults.antiprompt; - } - } - - { - const auto samplers = data.find("samplers"); - if (samplers != data.end()) { - if (samplers->is_array()) { - params.sampling.samplers = common_sampler_types_from_names(*samplers, false); - } else if (samplers->is_string()){ - params.sampling.samplers = common_sampler_types_from_chars(samplers->get()); - } - } else { - params.sampling.samplers = defaults.sampling.samplers; - } - } - - if (params.n_cmpl > params_base.n_parallel) { - throw std::runtime_error("n_cmpl cannot be greater than the number of slots, please increase -np"); - } - - return params; -} - -// -// result_timings -// - -json result_timings::to_json() const { - json base = { - {"cache_n", cache_n}, - - {"prompt_n", prompt_n}, - {"prompt_ms", prompt_ms}, - {"prompt_per_token_ms", prompt_per_token_ms}, - {"prompt_per_second", prompt_per_second}, - - {"predicted_n", predicted_n}, - {"predicted_ms", predicted_ms}, - {"predicted_per_token_ms", predicted_per_token_ms}, - {"predicted_per_second", predicted_per_second}, - }; - - if (draft_n > 0) { - base["draft_n"] = draft_n; - base["draft_n_accepted"] = draft_n_accepted; - } - - return base; -} - -// -// result_prompt_progress -// -json result_prompt_progress::to_json() const { - return json { - {"total", total}, - {"cache", cache}, - {"processed", processed}, - {"time_ms", time_ms}, - }; -} - -static inline std::string stop_type_to_str(stop_type type) { - switch (type) { - case STOP_TYPE_EOS: return "eos"; - case STOP_TYPE_WORD: return "word"; - case STOP_TYPE_LIMIT: return "limit"; - default: return "none"; - } -} - -// -// completion_token_output -// - -json completion_token_output::to_json(bool post_sampling_probs) const { - json probs_for_token = json::array(); - for (const auto & p : probs) { - std::string txt(p.txt); - txt.resize(validate_utf8(txt)); - probs_for_token.push_back(json { - {"id", p.tok}, - {"token", txt}, - {"bytes", str_to_bytes(p.txt)}, - { - post_sampling_probs ? "prob" : "logprob", - post_sampling_probs ? p.prob : logarithm(p.prob) - }, - }); - } - return probs_for_token; -} - -json completion_token_output::probs_vector_to_json(const std::vector & probs, bool post_sampling_probs) { - json out = json::array(); - for (const auto & p : probs) { - std::string txt(p.text_to_send); - txt.resize(validate_utf8(txt)); - out.push_back(json { - {"id", p.tok}, - {"token", txt}, - {"bytes", str_to_bytes(p.text_to_send)}, - { - post_sampling_probs ? "prob" : "logprob", - post_sampling_probs ? p.prob : logarithm(p.prob) - }, - { - post_sampling_probs ? "top_probs" : "top_logprobs", - p.to_json(post_sampling_probs) - }, - }); - } - return out; -} - -float completion_token_output::logarithm(float x) { - // nlohmann::json converts -inf to null, so we need to prevent that - return x == 0.0f ? 
std::numeric_limits::lowest() : std::log(x); -} - -std::vector completion_token_output::str_to_bytes(const std::string & str) { - std::vector bytes; - for (unsigned char c : str) { - bytes.push_back(c); - } - return bytes; -} - -// -// server_task_result_cmpl_final -// -json server_task_result_cmpl_final::to_json() { - GGML_ASSERT(is_updated && "update() must be called before to_json()"); - switch (res_type) { - case TASK_RESPONSE_TYPE_NONE: - return to_json_non_oaicompat(); - case TASK_RESPONSE_TYPE_OAI_CMPL: - return to_json_oaicompat(); - case TASK_RESPONSE_TYPE_OAI_CHAT: - return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat(); - case TASK_RESPONSE_TYPE_ANTHROPIC: - return stream ? to_json_anthropic_stream() : to_json_anthropic(); - default: - GGML_ASSERT(false && "Invalid task_response_type"); - } -} - -json server_task_result_cmpl_final::to_json_non_oaicompat() { - json res = json { - {"index", index}, - {"content", content}, - {"tokens", tokens}, - {"id_slot", id_slot}, - {"stop", true}, - {"model", oaicompat_model}, - {"tokens_predicted", n_decoded}, - {"tokens_evaluated", n_prompt_tokens}, - {"generation_settings", generation_params.to_json()}, - {"prompt", prompt}, - {"has_new_line", has_new_line}, - {"truncated", truncated}, - {"stop_type", stop_type_to_str(stop)}, - {"stopping_word", stopping_word}, - {"tokens_cached", n_tokens_cached}, - {"timings", timings.to_json()}, - }; - if (!stream && !probs_output.empty()) { - res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs); - } - return response_fields.empty() ? res : json_get_nested_values(response_fields, res); -} - -json server_task_result_cmpl_final::to_json_oaicompat() { - std::time_t t = std::time(0); - json logprobs = json(nullptr); // OAI default to null - if (!stream && probs_output.size() > 0) { - logprobs = json{ - {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)}, - }; - } - json finish_reason = "length"; - if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { - finish_reason = "stop"; - } - json res = json { - {"choices", json::array({ - json{ - {"text", content}, - {"index", index}, - {"logprobs", logprobs}, - {"finish_reason", finish_reason}, - } - })}, - {"created", t}, - {"model", oaicompat_model}, - {"system_fingerprint", build_info}, - {"object", "text_completion"}, - {"usage", json { - {"completion_tokens", n_decoded}, - {"prompt_tokens", n_prompt_tokens}, - {"total_tokens", n_decoded + n_prompt_tokens} - }}, - {"id", oaicompat_cmpl_id} - }; - - // extra fields for debugging purposes - if (verbose) { - res["__verbose"] = to_json_non_oaicompat(); - } - if (timings.prompt_n >= 0) { - res.push_back({"timings", timings.to_json()}); - } - - return res; -} - -json server_task_result_cmpl_final::to_json_oaicompat_chat() { - std::string finish_reason = "length"; - common_chat_msg msg; - if (!oaicompat_msg.empty()) { - msg = oaicompat_msg; - } else { - msg.role = "assistant"; - msg.content = content; - } - if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { - finish_reason = msg.tool_calls.empty() ? 
"stop" : "tool_calls"; - } - - json choice { - {"finish_reason", finish_reason}, - {"index", index}, - {"message", msg.to_json_oaicompat()}, - }; - - if (!stream && probs_output.size() > 0) { - choice["logprobs"] = json{ - {"content", completion_token_output::probs_vector_to_json(probs_output, post_sampling_probs)}, - }; - } - - std::time_t t = std::time(0); - - json res = json { - {"choices", json::array({choice})}, - {"created", t}, - {"model", oaicompat_model}, - {"system_fingerprint", build_info}, - {"object", "chat.completion"}, - {"usage", json { - {"completion_tokens", n_decoded}, - {"prompt_tokens", n_prompt_tokens}, - {"total_tokens", n_decoded + n_prompt_tokens} - }}, - {"id", oaicompat_cmpl_id} - }; - - // extra fields for debugging purposes - if (verbose) { - res["__verbose"] = to_json_non_oaicompat(); - } - if (timings.prompt_n >= 0) { - res.push_back({"timings", timings.to_json()}); - } - - return res; -} - -common_chat_msg task_result_state::update_chat_msg( - const std::string & text_added, - bool is_partial, - std::vector & diffs) { - generated_text += text_added; - auto msg_prv_copy = chat_msg; - SRV_DBG("Parsing chat message: %s\n", generated_text.c_str()); - auto new_msg = common_chat_parse( - generated_text, - is_partial, - oaicompat_chat_syntax); - if (!new_msg.empty()) { - new_msg.set_tool_call_ids(generated_tool_call_ids, gen_tool_call_id); - chat_msg = new_msg; - diffs = common_chat_msg_diff::compute_diffs(msg_prv_copy, new_msg.empty() ? msg_prv_copy : new_msg); - } - return chat_msg; -} - -json server_task_result_cmpl_final::to_json_oaicompat_chat_stream() { - std::time_t t = std::time(0); - std::string finish_reason = "length"; - if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { - finish_reason = oaicompat_msg.tool_calls.empty() ? 
"stop" : "tool_calls"; - } - - json deltas = json::array(); - for (const auto & diff : oaicompat_msg_diffs) { - deltas.push_back({ - {"choices", json::array({ - json { - {"finish_reason", nullptr}, - {"index", 0}, - {"delta", common_chat_msg_diff_to_json_oaicompat(diff)}, - }, - })}, - {"created", t}, - {"id", oaicompat_cmpl_id}, - {"model", oaicompat_model}, - {"system_fingerprint", build_info}, - {"object", "chat.completion.chunk"}, - }); - } - - deltas.push_back({ - {"choices", json::array({ - json { - {"finish_reason", finish_reason}, - {"index", 0}, - {"delta", json::object()}, - }, - })}, - {"created", t}, - {"id", oaicompat_cmpl_id}, - {"model", oaicompat_model}, - {"system_fingerprint", build_info}, - {"object", "chat.completion.chunk"}, - }); - - if (include_usage) { - // OpenAI API spec for chat.completion.chunks specifies an empty `choices` array for the last chunk when including usage - // https://platform.openai.com/docs/api-reference/chat_streaming/streaming#chat_streaming/streaming-choices - deltas.push_back({ - {"choices", json::array()}, - {"created", t}, - {"id", oaicompat_cmpl_id}, - {"model", oaicompat_model}, - {"system_fingerprint", build_info}, - {"object", "chat.completion.chunk"}, - {"usage", json { - {"completion_tokens", n_decoded}, - {"prompt_tokens", n_prompt_tokens}, - {"total_tokens", n_decoded + n_prompt_tokens}, - }}, - }); - } - - if (timings.prompt_n >= 0) { - deltas.back().push_back({"timings", timings.to_json()}); - } - - // extra fields for debugging purposes - if (verbose && !deltas.empty()) { - deltas.front()["__verbose"] = to_json_non_oaicompat(); - } - - return deltas; -} - -json server_task_result_cmpl_final::to_json_anthropic() { - std::string stop_reason = "max_tokens"; - if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { - stop_reason = oaicompat_msg.tool_calls.empty() ? "end_turn" : "tool_use"; - } - - json content_blocks = json::array(); - - common_chat_msg msg; - if (!oaicompat_msg.empty()) { - msg = oaicompat_msg; - } else { - msg.role = "assistant"; - msg.content = content; - } - - if (!msg.content.empty()) { - content_blocks.push_back({ - {"type", "text"}, - {"text", msg.content} - }); - } - - for (const auto & tool_call : msg.tool_calls) { - json tool_use_block = { - {"type", "tool_use"}, - {"id", tool_call.id}, - {"name", tool_call.name} - }; - - try { - tool_use_block["input"] = json::parse(tool_call.arguments); - } catch (const std::exception &) { - tool_use_block["input"] = json::object(); - } - - content_blocks.push_back(tool_use_block); - } - - json res = { - {"id", oaicompat_cmpl_id}, - {"type", "message"}, - {"role", "assistant"}, - {"content", content_blocks}, - {"model", oaicompat_model}, - {"stop_reason", stop_reason}, - {"stop_sequence", stopping_word.empty() ? nullptr : json(stopping_word)}, - {"usage", { - {"input_tokens", n_prompt_tokens}, - {"output_tokens", n_decoded} - }} - }; - - return res; -} - -json server_task_result_cmpl_final::to_json_anthropic_stream() { - json events = json::array(); - - std::string stop_reason = "max_tokens"; - if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) { - stop_reason = oaicompat_msg.tool_calls.empty() ? 
"end_turn" : "tool_use"; - } - - bool has_text = !oaicompat_msg.content.empty(); - size_t num_tool_calls = oaicompat_msg.tool_calls.size(); - - bool text_block_started = false; - std::unordered_set tool_calls_started; - - for (const auto & diff : oaicompat_msg_diffs) { - if (!diff.content_delta.empty()) { - if (!text_block_started) { - events.push_back({ - {"event", "content_block_start"}, - {"data", { - {"type", "content_block_start"}, - {"index", 0}, - {"content_block", { - {"type", "text"}, - {"text", ""} - }} - }} - }); - text_block_started = true; - } - - events.push_back({ - {"event", "content_block_delta"}, - {"data", { - {"type", "content_block_delta"}, - {"index", 0}, - {"delta", { - {"type", "text_delta"}, - {"text", diff.content_delta} - }} - }} - }); - } - - if (diff.tool_call_index != std::string::npos) { - size_t content_block_index = (has_text ? 1 : 0) + diff.tool_call_index; - - if (tool_calls_started.find(diff.tool_call_index) == tool_calls_started.end()) { - const auto & full_tool_call = oaicompat_msg.tool_calls[diff.tool_call_index]; - - events.push_back({ - {"event", "content_block_start"}, - {"data", { - {"type", "content_block_start"}, - {"index", content_block_index}, - {"content_block", { - {"type", "tool_use"}, - {"id", full_tool_call.id}, - {"name", full_tool_call.name} - }} - }} - }); - tool_calls_started.insert(diff.tool_call_index); - } - - if (!diff.tool_call_delta.arguments.empty()) { - events.push_back({ - {"event", "content_block_delta"}, - {"data", { - {"type", "content_block_delta"}, - {"index", content_block_index}, - {"delta", { - {"type", "input_json_delta"}, - {"partial_json", diff.tool_call_delta.arguments} - }} - }} - }); - } - } - } - - if (has_text) { - events.push_back({ - {"event", "content_block_stop"}, - {"data", { - {"type", "content_block_stop"}, - {"index", 0} - }} - }); - } - - for (size_t i = 0; i < num_tool_calls; i++) { - size_t content_block_index = (has_text ? 1 : 0) + i; - events.push_back({ - {"event", "content_block_stop"}, - {"data", { - {"type", "content_block_stop"}, - {"index", content_block_index} - }} - }); - } - - events.push_back({ - {"event", "message_delta"}, - {"data", { - {"type", "message_delta"}, - {"delta", { - {"stop_reason", stop_reason}, - {"stop_sequence", stopping_word.empty() ? 
nullptr : json(stopping_word)} - }}, - {"usage", { - {"output_tokens", n_decoded} - }} - }} - }); - - events.push_back({ - {"event", "message_stop"}, - {"data", { - {"type", "message_stop"} - }} - }); - - return events; -} - -// -// server_task_result_cmpl_partial -// -json server_task_result_cmpl_partial::to_json() { - GGML_ASSERT(is_updated && "update() must be called before to_json()"); - switch (res_type) { - case TASK_RESPONSE_TYPE_NONE: - return to_json_non_oaicompat(); - case TASK_RESPONSE_TYPE_OAI_CMPL: - return to_json_oaicompat(); - case TASK_RESPONSE_TYPE_OAI_CHAT: - return to_json_oaicompat_chat(); - case TASK_RESPONSE_TYPE_ANTHROPIC: - return to_json_anthropic(); - default: - GGML_ASSERT(false && "Invalid task_response_type"); - } -} - -json server_task_result_cmpl_partial::to_json_non_oaicompat() { - // non-OAI-compat JSON - json res = json { - {"index", index}, - {"content", content}, - {"tokens", tokens}, - {"stop", false}, - {"id_slot", id_slot}, - {"tokens_predicted", n_decoded}, - {"tokens_evaluated", n_prompt_tokens}, - }; - // populate the timings object when needed (usually for the last response or with timings_per_token enabled) - if (timings.prompt_n > 0) { - res.push_back({"timings", timings.to_json()}); - } - if (is_progress) { - res.push_back({"prompt_progress", progress.to_json()}); - } - if (!prob_output.probs.empty()) { - res["completion_probabilities"] = completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs); - } - return res; -} - -json server_task_result_cmpl_partial::to_json_oaicompat() { - std::time_t t = std::time(0); - json logprobs = json(nullptr); // OAI default to null - if (prob_output.probs.size() > 0) { - logprobs = json{ - {"content", completion_token_output::probs_vector_to_json({prob_output}, post_sampling_probs)}, - }; - } - json res = json { - {"choices", json::array({ - json{ - {"text", content}, - {"index", index}, - {"logprobs", logprobs}, - {"finish_reason", nullptr}, - } - })}, - {"created", t}, - {"model", oaicompat_model}, - {"system_fingerprint", build_info}, - {"object", "text_completion"}, - {"id", oaicompat_cmpl_id} - }; - - // extra fields for debugging purposes - if (verbose) { - res["__verbose"] = to_json_non_oaicompat(); - } - if (timings.prompt_n >= 0) { - res.push_back({"timings", timings.to_json()}); - } - if (is_progress) { - res.push_back({"prompt_progress", progress.to_json()}); - } - - return res; -} - -json server_task_result_cmpl_partial::to_json_oaicompat_chat() { - bool first = n_decoded == 1; - std::time_t t = std::time(0); - json choices; - - std::vector deltas; - auto add_delta = [&](const json & delta) { - deltas.push_back({ - {"choices", json::array({ - json { - {"finish_reason", nullptr}, - {"index", index}, - {"delta", delta}, - }, - })}, - {"created", t}, - {"id", oaicompat_cmpl_id}, - {"model", oaicompat_model}, - {"system_fingerprint", build_info}, - {"object", "chat.completion.chunk"}, - }); - }; - // We have to send an initial update to conform to openai behavior - if (first || is_progress) { - add_delta({ - {"role", "assistant"}, - {"content", nullptr}, - }); - } - - for (const auto & diff : oaicompat_msg_diffs) { - add_delta(common_chat_msg_diff_to_json_oaicompat(diff)); - } - - if (!deltas.empty()) { - auto & last_json = deltas[deltas.size() - 1]; - GGML_ASSERT(last_json.at("choices").size() >= 1); - - if (prob_output.probs.size() > 0) { - last_json.at("choices").at(0)["logprobs"] = json { - {"content", completion_token_output::probs_vector_to_json({prob_output}, 
post_sampling_probs)}, - }; - } - - if (timings.prompt_n >= 0) { - last_json.push_back({"timings", timings.to_json()}); - } - if (is_progress) { - last_json.push_back({"prompt_progress", progress.to_json()}); - } - } - - return deltas; -} - -// -// server_task_result_embd -// -json server_task_result_embd::to_json() { - return res_type == TASK_RESPONSE_TYPE_OAI_EMBD - ? to_json_oaicompat() - : to_json_non_oaicompat(); -} - -json server_task_result_embd::to_json_non_oaicompat() { - return json { - {"index", index}, - {"embedding", embedding}, - }; -} - -json server_task_result_embd::to_json_oaicompat() { - return json { - {"index", index}, - {"embedding", embedding[0]}, - {"tokens_evaluated", n_tokens}, - }; -} - -// -// server_task_result_rerank -// -json server_task_result_rerank::to_json() { - return json { - {"index", index}, - {"score", score}, - {"tokens_evaluated", n_tokens}, - }; -} - -json server_task_result_cmpl_partial::to_json_anthropic() { - json events = json::array(); - bool first = (n_decoded == 1); - static bool text_block_started = false; - - if (first) { - text_block_started = false; - - events.push_back({ - {"event", "message_start"}, - {"data", { - {"type", "message_start"}, - {"message", { - {"id", oaicompat_cmpl_id}, - {"type", "message"}, - {"role", "assistant"}, - {"content", json::array()}, - {"model", oaicompat_model}, - {"stop_reason", nullptr}, - {"stop_sequence", nullptr}, - {"usage", { - {"input_tokens", n_prompt_tokens}, - {"output_tokens", 0} - }} - }} - }} - }); - } - - for (const auto & diff : oaicompat_msg_diffs) { - if (!diff.content_delta.empty()) { - if (!text_block_started) { - events.push_back({ - {"event", "content_block_start"}, - {"data", { - {"type", "content_block_start"}, - {"index", 0}, - {"content_block", { - {"type", "text"}, - {"text", ""} - }} - }} - }); - text_block_started = true; - } - - events.push_back({ - {"event", "content_block_delta"}, - {"data", { - {"type", "content_block_delta"}, - {"index", 0}, - {"delta", { - {"type", "text_delta"}, - {"text", diff.content_delta} - }} - }} - }); - } - - if (diff.tool_call_index != std::string::npos) { - size_t content_block_index = (text_block_started ? 
1 : 0) + diff.tool_call_index; - - if (!diff.tool_call_delta.name.empty()) { - events.push_back({ - {"event", "content_block_start"}, - {"data", { - {"type", "content_block_start"}, - {"index", content_block_index}, - {"content_block", { - {"type", "tool_use"}, - {"id", diff.tool_call_delta.id}, - {"name", diff.tool_call_delta.name} - }} - }} - }); - } - - if (!diff.tool_call_delta.arguments.empty()) { - events.push_back({ - {"event", "content_block_delta"}, - {"data", { - {"type", "content_block_delta"}, - {"index", content_block_index}, - {"delta", { - {"type", "input_json_delta"}, - {"partial_json", diff.tool_call_delta.arguments} - }} - }} - }); - } - } - } - - return events; -} - -// -// server_task_result_error -// -json server_task_result_error::to_json() { - json res = format_error_response(err_msg, err_type); - if (err_type == ERROR_TYPE_EXCEED_CONTEXT_SIZE) { - res["n_prompt_tokens"] = n_prompt_tokens; - res["n_ctx"] = n_ctx; - } - return res; -} - -// -// server_task_result_metrics -// -json server_task_result_metrics::to_json() { - return json { - { "idle", n_idle_slots }, - { "processing", n_processing_slots }, - { "deferred", n_tasks_deferred }, - { "t_start", t_start }, - - { "n_prompt_tokens_processed_total", n_prompt_tokens_processed_total }, - { "t_tokens_generation_total", t_tokens_generation_total }, - { "n_tokens_predicted_total", n_tokens_predicted_total }, - { "t_prompt_processing_total", t_prompt_processing_total }, - - { "n_tokens_max", n_tokens_max }, - - { "n_prompt_tokens_processed", n_prompt_tokens_processed }, - { "t_prompt_processing", t_prompt_processing }, - { "n_tokens_predicted", n_tokens_predicted }, - { "t_tokens_generation", t_tokens_generation }, - - { "n_decode_total", n_decode_total }, - { "n_busy_slots_total", n_busy_slots_total }, - - { "slots", slots_data }, - }; -} - -// -// server_task_result_slot_save_load -// -json server_task_result_slot_save_load::to_json() { - if (is_save) { - return json { - { "id_slot", id_slot }, - { "filename", filename }, - { "n_saved", n_tokens }, - { "n_written", n_bytes }, - { "timings", { - { "save_ms", t_ms } - }}, - }; - } - - return json { - { "id_slot", id_slot }, - { "filename", filename }, - { "n_restored", n_tokens }, - { "n_read", n_bytes }, - { "timings", { - { "restore_ms", t_ms } - }}, - }; -} - -// -// server_task_result_slot_erase -// -json server_task_result_slot_erase::to_json() { - return json { - { "id_slot", id_slot }, - { "n_erased", n_erased }, - }; -} - -// -// server_task_result_apply_lora -// - -json server_task_result_apply_lora::to_json() { - return json {{ "success", true }}; -} - -// -// server_prompt_cache -// -size_t server_prompt_cache::size() const { - size_t res = 0; - - for (const auto & state : states) { - res += state.size(); - } - - return res; -} - -size_t server_prompt_cache::n_tokens() const { - size_t res = 0; - - for (const auto & state : states) { - res += state.n_tokens(); - } - - return res; -} - -server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t state_size) { - // first check if the current state is contained fully in the cache - for (auto it = states.begin(); it != states.end(); ++it) { - const int cur_lcp_len = it->tokens.get_common_prefix(prompt.tokens); - - if (cur_lcp_len == (int) prompt.tokens.size()) { - SRV_WRN("%s", " - prompt is already in the cache, skipping\n"); - return nullptr; - } - } - - // next, remove any cached prompts that are fully contained in the current prompt - for (auto it = states.begin(); it != states.end();) { - 
const int len = it->tokens.get_common_prefix(prompt.tokens); - - if (len == (int) it->tokens.size()) { - SRV_WRN(" - removing obsolete cached prompt with length %d\n", len); - - it = states.erase(it); - } else { - ++it; - } - } - - std::vector state_data; - - // check if we can allocate enough memory for the new state - try { - state_data.resize(state_size); - } catch (const std::bad_alloc & e) { - SRV_ERR("failed to allocate memory for prompt cache state: %s\n", e.what()); - - limit_size = std::max(1, 0.4*size()); - - SRV_WRN(" - cache size limit reduced to %.3f MiB\n", limit_size / (1024.0 * 1024.0)); - - update(); - - return nullptr; - } - - // TODO: for some reason we can't copy server_tokens, so we have to do this workaround - auto & cur = states.emplace_back(); - cur = { - /*.tokens =*/ server_tokens(prompt.tokens.get_text_tokens(), false), - /*.data =*/ std::move(state_data), - /*.checkpoints =*/ prompt.checkpoints, - }; - - return &cur; -} - -bool server_prompt_cache::load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot) { - const int lcp_best = prompt.tokens.get_common_prefix(tokens_new); - - float f_keep_best = float(lcp_best) / prompt.tokens.size(); - float sim_best = float(lcp_best) / tokens_new.size(); - - SRV_WRN(" - looking for better prompt, base f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best); - - auto it_best = states.end(); - - // find the most similar cached prompt, that would also preserve the most context - for (auto it = states.begin(); it != states.end(); ++it) { - const int lcp_cur = it->tokens.get_common_prefix(tokens_new); - - const float f_keep_cur = float(lcp_cur) / it->tokens.size(); - const float sim_cur = float(lcp_cur) / tokens_new.size(); - - // don't trash large prompts - if (f_keep_cur < 0.25f) { - continue; - } - - if (f_keep_best < f_keep_cur && sim_best < sim_cur) { - f_keep_best = f_keep_cur; - sim_best = sim_cur; - - it_best = it; - } - } - - if (it_best != states.end()) { - SRV_WRN(" - found better prompt with f_keep = %.3f, sim = %.3f\n", f_keep_best, sim_best); - - const size_t size = it_best->data.size(); - const size_t n = llama_state_seq_set_data_ext(ctx, it_best->data.data(), size, id_slot, 0); - if (n != size) { - SRV_WRN("failed to restore state with size %zu\n", size); - - return false; - } - - it_best->data.clear(); - it_best->data.shrink_to_fit(); - - prompt = std::move(*it_best); - - states.erase(it_best); - } - - return true; -} - -void server_prompt_cache::update() { - if (limit_size > 0) { - // always keep at least one state, regardless of the limits - while (states.size() > 1 && size() > limit_size) { - if (states.empty()) { - break; - } - - SRV_WRN(" - cache size limit reached, removing oldest entry (size = %.3f MiB)\n", states.front().size() / (1024.0 * 1024.0)); - - states.pop_front(); - } - } - - // average size per token - const float size_per_token = std::max(1.0f, float(size()) / (std::max(1, n_tokens()))); - - // dynamically increase the token limit if it can fit in the memory limit - const size_t limit_tokens_cur = limit_size > 0 ? 
std::max(limit_tokens, limit_size/size_per_token) : limit_tokens; - - if (limit_tokens > 0) { - while (states.size() > 1 && n_tokens() > limit_tokens_cur) { - if (states.empty()) { - break; - } - - SRV_WRN(" - cache token limit (%zu, est: %zu) reached, removing oldest entry (size = %.3f MiB)\n", - limit_tokens, limit_tokens_cur, states.front().size() / (1024.0 * 1024.0)); - - states.pop_front(); - } - } - - SRV_WRN(" - cache state: %zu prompts, %.3f MiB (limits: %.3f MiB, %zu tokens, %zu est)\n", - states.size(), size() / (1024.0 * 1024.0), limit_size / (1024.0 * 1024.0), limit_tokens, limit_tokens_cur); - - for (const auto & state : states) { - SRV_WRN(" - prompt %p: %7d tokens, checkpoints: %2zu, %9.3f MiB\n", - (const void *)&state, state.n_tokens(), state.checkpoints.size(), state.size() / (1024.0 * 1024.0)); - } -} diff --git a/llamacpp/native/src/server/server-task.h b/llamacpp/native/src/server/server-task.h deleted file mode 100644 index 0759094a0..000000000 --- a/llamacpp/native/src/server/server-task.h +++ /dev/null @@ -1,531 +0,0 @@ -#pragma once - -#include "common.h" -#include "llama.h" - -#include -#include -#include - -// TODO: prevent including the whole server-common.h as we only use server_tokens -#include "server-common.h" - -using json = nlohmann::ordered_json; - -enum server_task_type { - SERVER_TASK_TYPE_COMPLETION, - SERVER_TASK_TYPE_EMBEDDING, - SERVER_TASK_TYPE_RERANK, - SERVER_TASK_TYPE_INFILL, - SERVER_TASK_TYPE_CANCEL, - SERVER_TASK_TYPE_NEXT_RESPONSE, - SERVER_TASK_TYPE_METRICS, - SERVER_TASK_TYPE_SLOT_SAVE, - SERVER_TASK_TYPE_SLOT_RESTORE, - SERVER_TASK_TYPE_SLOT_ERASE, - SERVER_TASK_TYPE_SET_LORA, -}; - -// TODO: change this to more generic "response_format" to replace the "format_response_*" in server-common -enum task_response_type { - TASK_RESPONSE_TYPE_NONE, // llama.cpp native format - TASK_RESPONSE_TYPE_OAI_CHAT, - TASK_RESPONSE_TYPE_OAI_CMPL, - TASK_RESPONSE_TYPE_OAI_EMBD, - TASK_RESPONSE_TYPE_ANTHROPIC, -}; - -enum stop_type { - STOP_TYPE_NONE, - STOP_TYPE_EOS, - STOP_TYPE_WORD, - STOP_TYPE_LIMIT, -}; - -struct task_params { - bool stream = true; - bool include_usage = false; - bool cache_prompt = true; // remember the prompt to avoid reprocessing all prompt - bool return_tokens = false; - bool return_progress = false; - - int32_t n_keep = 0; // number of tokens to keep from initial prompt - int32_t n_discard = 0; // number of tokens after n_keep that may be discarded when shifting context, 0 defaults to half - int32_t n_predict = -1; // new tokens to predict - int32_t n_indent = 0; // minimum line indentation for the generated text in number of whitespace characters - int32_t n_cmpl = 1; // number of completions to generate from this prompt - - int32_t n_cache_reuse = 0; // min chunk size to attempt reusing from the cache via KV shifting (0 = disabled) - - int64_t t_max_prompt_ms = -1; // TODO: implement - int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit - - std::vector lora; - - std::vector antiprompt; - std::vector response_fields; - - bool timings_per_token = false; - bool post_sampling_probs = false; - - struct common_params_sampling sampling; - struct common_params_speculative speculative; - - // response formatting - bool verbose = false; - task_response_type res_type = TASK_RESPONSE_TYPE_NONE; - std::string oaicompat_model; - std::string oaicompat_cmpl_id; - common_chat_syntax oaicompat_chat_syntax; - - // Embeddings - int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 
2=Euclidean/L2, >2=p-norm) - - json format_logit_bias(const std::vector & logit_bias) const; - json to_json(bool only_metrics = false) const; -}; - -// struct for tracking the state of a task (e.g., for streaming) -struct task_result_state { - // tracking diffs for partial tool calls - std::vector diffs; - common_chat_syntax oaicompat_chat_syntax; - common_chat_msg chat_msg; - std::string generated_text; // append new chunks of generated text here - std::vector generated_tool_call_ids; - - task_result_state(const common_chat_syntax & oaicompat_chat_syntax) - : oaicompat_chat_syntax(oaicompat_chat_syntax) {} - - // parse partial tool calls and update the internal state - common_chat_msg update_chat_msg( - const std::string & text_added, - bool is_partial, - std::vector & diffs); -}; - -struct server_task { - int id = -1; // to be filled by server_queue - int index = -1; // used when there are multiple prompts (batch request) - - // used by SERVER_TASK_TYPE_CANCEL - int id_target = -1; - int id_slot = -1; - - // used by parallel sampling (multiple completions from same prompt) - size_t n_children = 0; // number of tasks reusing this prompt - int id_parent = -1; - - // used by SERVER_TASK_TYPE_INFERENCE - task_params params; - server_tokens tokens; - - // only used by CLI, this delegates the tokenization to the server - json cli_input = nullptr; - std::vector cli_files; - - server_task_type type; - - // used by SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE - struct slot_action { - int slot_id; - std::string filename; - std::string filepath; - }; - slot_action slot_action; - - // used by SERVER_TASK_TYPE_METRICS - bool metrics_reset_bucket = false; - - // used by SERVER_TASK_TYPE_SET_LORA - std::vector set_lora; - - server_task() = default; - - server_task(server_task_type type) : type(type) {} - - int32_t n_tokens() const { - return tokens.size(); - } - - static task_params params_from_json_cmpl( - const llama_context * ctx, - const common_params & params_base, - const json & data); - - // utility function - static std::unordered_set get_list_id(const std::vector & tasks) { - std::unordered_set ids(tasks.size()); - for (size_t i = 0; i < tasks.size(); i++) { - ids.insert(tasks[i].id); - } - return ids; - } - - server_task create_child(int id_parent, int id_child, int idx) const { - server_task copy; - copy.id = id_child; - copy.index = idx; - copy.id_parent = id_parent; - copy.params = params; - copy.type = type; - copy.tokens = tokens.clone(); - return copy; - } - - // the task will be moved into queue, then onto slots - // however, the state must be kept by caller (e.g., HTTP thread) - task_result_state create_state() const { - return task_result_state(params.oaicompat_chat_syntax); - } -}; - -struct result_timings { - int32_t cache_n = -1; - - int32_t prompt_n = -1; - double prompt_ms; - double prompt_per_token_ms; - double prompt_per_second; - - int32_t predicted_n = -1; - double predicted_ms; - double predicted_per_token_ms; - double predicted_per_second; - - // Optional speculative metrics - only included when > 0 - int32_t draft_n = 0; - int32_t draft_n_accepted = 0; - - json to_json() const; -}; - -struct result_prompt_progress { - int32_t total = 0; - int32_t cache = 0; - int32_t processed = 0; - int64_t time_ms = 0; - - json to_json() const; -}; - -struct server_task_result { - int id = -1; - int id_slot = -1; - virtual bool is_error() { - // only used by server_task_result_error - return false; - } - virtual bool is_stop() { - // only used 
by server_task_result_cmpl_* - return true; - } - virtual int get_index() { - return -1; - } - virtual void update(task_result_state &) { - // only used by server_task_result_cmpl_* - } - virtual json to_json() = 0; - virtual ~server_task_result() = default; -}; - -// using shared_ptr for polymorphism of server_task_result -using server_task_result_ptr = std::unique_ptr; - -struct completion_token_output { - llama_token tok; - float prob; - std::string text_to_send; - struct prob_info { - llama_token tok; - std::string txt; - float prob; - }; - std::vector probs; - - json to_json(bool post_sampling_probs) const; - - static json probs_vector_to_json(const std::vector & probs, bool post_sampling_probs); - - static float logarithm(float x); - - static std::vector str_to_bytes(const std::string & str); - -}; - -struct server_task_result_cmpl_final : server_task_result { - int index = 0; - - std::string content; - llama_tokens tokens; - - bool stream; - bool include_usage; - result_timings timings; - std::string prompt; - - bool truncated; - int32_t n_decoded; - int32_t n_prompt_tokens; - int32_t n_tokens_cached; - bool has_new_line; - std::string stopping_word; - stop_type stop = STOP_TYPE_NONE; - - bool post_sampling_probs; - std::vector probs_output; - std::vector response_fields; - - task_params generation_params; - - // response formatting - bool verbose = false; - task_response_type res_type = TASK_RESPONSE_TYPE_NONE; - std::string oaicompat_model; - std::string oaicompat_cmpl_id; - common_chat_msg oaicompat_msg; // to be populated by update() - - std::vector oaicompat_msg_diffs; // to be populated by update() - bool is_updated = false; - - virtual int get_index() override { - return index; - } - - virtual bool is_stop() override { - return true; // in stream mode, final responses are considered stop - } - - virtual json to_json() override; - - virtual void update(task_result_state & state) override { - is_updated = true; - oaicompat_msg = state.update_chat_msg(content, false, oaicompat_msg_diffs); - } - - json to_json_non_oaicompat(); - - json to_json_oaicompat(); - - json to_json_oaicompat_chat(); - - json to_json_oaicompat_chat_stream(); - - json to_json_anthropic(); - - json to_json_anthropic_stream(); -}; - -struct server_task_result_cmpl_partial : server_task_result { - int index = 0; - - std::string content; - llama_tokens tokens; - - int32_t n_decoded; - int32_t n_prompt_tokens; - - bool post_sampling_probs; - bool is_progress = false; - completion_token_output prob_output; - result_timings timings; - result_prompt_progress progress; - - // response formatting - bool verbose = false; - task_response_type res_type = TASK_RESPONSE_TYPE_NONE; - std::string oaicompat_model; - std::string oaicompat_cmpl_id; - std::vector oaicompat_msg_diffs; // to be populated by update() - bool is_updated = false; - - virtual int get_index() override { - return index; - } - - virtual bool is_stop() override { - return false; // in stream mode, partial responses are not considered stop - } - - virtual json to_json() override; - - virtual void update(task_result_state & state) override { - is_updated = true; - state.update_chat_msg(content, true, oaicompat_msg_diffs); - } - - json to_json_non_oaicompat(); - - json to_json_oaicompat(); - - json to_json_oaicompat_chat(); - - json to_json_anthropic(); -}; - -struct server_task_result_embd : server_task_result { - int index = 0; - std::vector> embedding; - - int32_t n_tokens; - - // response formatting - task_response_type res_type = 
TASK_RESPONSE_TYPE_NONE; - - virtual int get_index() override { - return index; - } - - virtual json to_json() override; - - json to_json_non_oaicompat(); - - json to_json_oaicompat(); -}; - -struct server_task_result_rerank : server_task_result { - int index = 0; - float score = -1e6; - - int32_t n_tokens; - - virtual int get_index() override { - return index; - } - - virtual json to_json() override; -}; - -struct server_task_result_error : server_task_result { - int index = 0; - error_type err_type = ERROR_TYPE_SERVER; - std::string err_msg; - - // for ERROR_TYPE_EXCEED_CONTEXT_SIZE - int32_t n_prompt_tokens = 0; - int32_t n_ctx = 0; - - virtual bool is_error() override { - return true; - } - - virtual json to_json() override; -}; - -struct server_task_result_metrics : server_task_result { - int n_idle_slots; - int n_processing_slots; - int n_tasks_deferred; - int64_t t_start; - - // TODO: somehow reuse server_metrics in the future, instead of duplicating the fields - uint64_t n_prompt_tokens_processed_total = 0; - uint64_t t_prompt_processing_total = 0; - uint64_t n_tokens_predicted_total = 0; - uint64_t t_tokens_generation_total = 0; - - uint64_t n_tokens_max = 0; - - uint64_t n_prompt_tokens_processed = 0; - uint64_t t_prompt_processing = 0; - - uint64_t n_tokens_predicted = 0; - uint64_t t_tokens_generation = 0; - - uint64_t n_decode_total = 0; - uint64_t n_busy_slots_total = 0; - - // while we can also use std::vector this requires copying the slot object which can be quite messy - // therefore, we use json to temporarily store the slot.to_json() result - json slots_data = json::array(); - - virtual json to_json() override; -}; - -struct server_task_result_slot_save_load : server_task_result { - std::string filename; - bool is_save; // true = save, false = load - - size_t n_tokens; - size_t n_bytes; - double t_ms; - - virtual json to_json() override; -}; - -struct server_task_result_slot_erase : server_task_result { - size_t n_erased; - - virtual json to_json() override; -}; - -struct server_task_result_apply_lora : server_task_result { - virtual json to_json() override; -}; - -struct server_prompt_checkpoint { - llama_pos pos_min; - llama_pos pos_max; - - std::vector data; - - size_t size() const { - return data.size(); - } -}; - -struct server_prompt { - server_tokens tokens; - - std::vector data; - - std::list checkpoints; - - size_t size() const { - size_t res = data.size(); - - for (const auto & checkpoint : checkpoints) { - res += checkpoint.size(); - } - - return res; - } - - int n_tokens() const { - return tokens.size(); - } - - server_prompt clone() const { - return server_prompt { - tokens.clone(), - data, - checkpoints - }; - } -}; - -struct server_prompt_cache { - server_prompt_cache(int32_t limit_size_mib, size_t limit_tokens) { - this->limit_size = 1024ull*1024ull*(limit_size_mib < 0 ? 
0 : limit_size_mib); - this->limit_tokens = limit_tokens; - } - - std::list states; - - // in bytes, 0 = no limit - size_t limit_size = 0; - - // in tokens, 0 = no limit - size_t limit_tokens = 0; - - size_t size() const; - - size_t n_tokens() const; - - server_prompt * alloc(const server_prompt & prompt, size_t state_size); - - bool load(server_prompt & prompt, const server_tokens & tokens_new, llama_context * ctx, int32_t id_slot); - - void update(); -}; diff --git a/llamacpp/native/src/server/server.cpp b/llamacpp/native/src/server/server.cpp deleted file mode 100644 index d5bef3df4..000000000 --- a/llamacpp/native/src/server/server.cpp +++ /dev/null @@ -1,306 +0,0 @@ -#include "server-context.h" -#include "server-http.h" -#include "server-models.h" - -#include "arg.h" -#include "common.h" -#include "llama.h" -#include "log.h" - -#include -#include -#include // for std::thread::hardware_concurrency - -#if defined(_WIN32) -#include -#endif - -static std::function shutdown_handler; -static std::atomic_flag is_terminating = ATOMIC_FLAG_INIT; - -static inline void signal_handler(int signal) { - if (is_terminating.test_and_set()) { - // in case it hangs, we can force terminate the server by hitting Ctrl+C twice - // this is for better developer experience, we can remove when the server is stable enough - fprintf(stderr, "Received second interrupt, terminating immediately.\n"); - exit(1); - } - - shutdown_handler(signal); -} - -// wrapper function that handles exceptions and logs errors -// this is to make sure handler_t never throws exceptions; instead, it returns an error response -static server_http_context::handler_t ex_wrapper(server_http_context::handler_t func) { - return [func = std::move(func)](const server_http_req & req) -> server_http_res_ptr { - std::string message; - error_type error; - try { - return func(req); - } catch (const std::invalid_argument & e) { - // treat invalid_argument as invalid request (400) - error = ERROR_TYPE_INVALID_REQUEST; - message = e.what(); - } catch (const std::exception & e) { - // treat other exceptions as server error (500) - error = ERROR_TYPE_SERVER; - message = e.what(); - } catch (...) { - error = ERROR_TYPE_SERVER; - message = "unknown error"; - } - - auto res = std::make_unique(); - res->status = 500; - try { - json error_data = format_error_response(message, error); - res->status = json_value(error_data, "code", 500); - res->data = safe_json_to_str({{ "error", error_data }}); - SRV_WRN("got exception: %s\n", res->data.c_str()); - } catch (const std::exception & e) { - SRV_ERR("got another exception: %s | while handling exception: %s\n", e.what(), message.c_str()); - res->data = "Internal Server Error"; - } - return res; - }; -} - -int main(int argc, char ** argv, char ** envp) { - // own arguments required by this example - common_params params; - - if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_SERVER)) { - return 1; - } - - // TODO: should we have a separate n_parallel parameter for the server? 
- // https://github.com/ggml-org/llama.cpp/pull/16736#discussion_r2483763177 - // TODO: this is a common configuration that is suitable for most local use cases - // however, overriding the parameters is a bit confusing - figure out something more intuitive - if (params.n_parallel == 1 && params.kv_unified == false && !params.has_speculative()) { - LOG_WRN("%s: setting n_parallel = 4 and kv_unified = true (add -kvu to disable this)\n", __func__); - - params.n_parallel = 4; - params.kv_unified = true; - } - - // for consistency between server router mode and single-model mode, we set the same model name as alias - if (params.model_alias.empty() && !params.model.name.empty()) { - params.model_alias = params.model.name; - } - - common_init(); - - // struct that contains llama context and inference - server_context ctx_server; - - llama_backend_init(); - llama_numa_init(params.numa); - - LOG_INF("system info: n_threads = %d, n_threads_batch = %d, total_threads = %d\n", params.cpuparams.n_threads, params.cpuparams_batch.n_threads, std::thread::hardware_concurrency()); - LOG_INF("\n"); - LOG_INF("%s\n", common_params_get_system_info(params).c_str()); - LOG_INF("\n"); - - server_http_context ctx_http; - if (!ctx_http.init(params)) { - LOG_ERR("%s: failed to initialize HTTP server\n", __func__); - return 1; - } - - // - // Router - // - - // register API routes - server_routes routes(params, ctx_server, [&ctx_http]() { return ctx_http.is_ready.load(); }); - - bool is_router_server = params.model.path.empty(); - std::optional models_routes{}; - if (is_router_server) { - // setup server instances manager - models_routes.emplace(params, argc, argv, envp); - - // proxy handlers - // note: routes.get_health stays the same - routes.get_metrics = models_routes->proxy_get; - routes.post_props = models_routes->proxy_post; - routes.get_api_show = models_routes->proxy_get; - routes.post_completions = models_routes->proxy_post; - routes.post_completions_oai = models_routes->proxy_post; - routes.post_chat_completions = models_routes->proxy_post; - routes.post_anthropic_messages = models_routes->proxy_post; - routes.post_anthropic_count_tokens = models_routes->proxy_post; - routes.post_infill = models_routes->proxy_post; - routes.post_embeddings = models_routes->proxy_post; - routes.post_embeddings_oai = models_routes->proxy_post; - routes.post_rerank = models_routes->proxy_post; - routes.post_tokenize = models_routes->proxy_post; - routes.post_detokenize = models_routes->proxy_post; - routes.post_apply_template = models_routes->proxy_post; - routes.get_lora_adapters = models_routes->proxy_get; - routes.post_lora_adapters = models_routes->proxy_post; - routes.get_slots = models_routes->proxy_get; - routes.post_slots = models_routes->proxy_post; - - // custom routes for router - routes.get_props = models_routes->get_router_props; - routes.get_models = models_routes->get_router_models; - ctx_http.post("/models/load", ex_wrapper(models_routes->post_router_models_load)); - ctx_http.post("/models/unload", ex_wrapper(models_routes->post_router_models_unload)); - ctx_http.post("/models/status", ex_wrapper(models_routes->post_router_models_status)); - } - - ctx_http.get ("/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check) - ctx_http.get ("/v1/health", ex_wrapper(routes.get_health)); // public endpoint (no API key check) - ctx_http.get ("/metrics", ex_wrapper(routes.get_metrics)); - ctx_http.get ("/props", ex_wrapper(routes.get_props)); - ctx_http.post("/props", 
ex_wrapper(routes.post_props)); - ctx_http.post("/api/show", ex_wrapper(routes.get_api_show)); - ctx_http.get ("/models", ex_wrapper(routes.get_models)); // public endpoint (no API key check) - ctx_http.get ("/v1/models", ex_wrapper(routes.get_models)); // public endpoint (no API key check) - ctx_http.get ("/api/tags", ex_wrapper(routes.get_models)); // ollama specific endpoint. public endpoint (no API key check) - ctx_http.post("/completion", ex_wrapper(routes.post_completions)); // legacy - ctx_http.post("/completions", ex_wrapper(routes.post_completions)); - ctx_http.post("/v1/completions", ex_wrapper(routes.post_completions_oai)); - ctx_http.post("/chat/completions", ex_wrapper(routes.post_chat_completions)); - ctx_http.post("/v1/chat/completions", ex_wrapper(routes.post_chat_completions)); - ctx_http.post("/api/chat", ex_wrapper(routes.post_chat_completions)); // ollama specific endpoint - ctx_http.post("/v1/messages", ex_wrapper(routes.post_anthropic_messages)); // anthropic messages API - ctx_http.post("/v1/messages/count_tokens", ex_wrapper(routes.post_anthropic_count_tokens)); // anthropic token counting - ctx_http.post("/infill", ex_wrapper(routes.post_infill)); - ctx_http.post("/embedding", ex_wrapper(routes.post_embeddings)); // legacy - ctx_http.post("/embeddings", ex_wrapper(routes.post_embeddings)); - ctx_http.post("/v1/embeddings", ex_wrapper(routes.post_embeddings_oai)); - ctx_http.post("/rerank", ex_wrapper(routes.post_rerank)); - ctx_http.post("/reranking", ex_wrapper(routes.post_rerank)); - ctx_http.post("/v1/rerank", ex_wrapper(routes.post_rerank)); - ctx_http.post("/v1/reranking", ex_wrapper(routes.post_rerank)); - ctx_http.post("/tokenize", ex_wrapper(routes.post_tokenize)); - ctx_http.post("/detokenize", ex_wrapper(routes.post_detokenize)); - ctx_http.post("/apply-template", ex_wrapper(routes.post_apply_template)); - // LoRA adapters hotswap - ctx_http.get ("/lora-adapters", ex_wrapper(routes.get_lora_adapters)); - ctx_http.post("/lora-adapters", ex_wrapper(routes.post_lora_adapters)); - // Save & load slots - ctx_http.get ("/slots", ex_wrapper(routes.get_slots)); - ctx_http.post("/slots/:id_slot", ex_wrapper(routes.post_slots)); - - // - // Start the server - // - - std::function clean_up; - - if (is_router_server) { - LOG_INF("%s: starting router server, no model will be loaded in this process\n", __func__); - - clean_up = [&models_routes]() { - SRV_INF("%s: cleaning up before exit...\n", __func__); - if (models_routes.has_value()) { - models_routes->models.unload_all(); - } - llama_backend_free(); - }; - - if (!ctx_http.start()) { - clean_up(); - LOG_ERR("%s: exiting due to HTTP server error\n", __func__); - return 1; - } - ctx_http.is_ready.store(true); - - shutdown_handler = [&](int) { - ctx_http.stop(); - }; - - } else { - // setup clean up function, to be called before exit - clean_up = [&ctx_http, &ctx_server]() { - SRV_INF("%s: cleaning up before exit...\n", __func__); - ctx_http.stop(); - ctx_server.terminate(); - llama_backend_free(); - }; - - // start the HTTP server before loading the model to be able to serve /health requests - if (!ctx_http.start()) { - clean_up(); - LOG_ERR("%s: exiting due to HTTP server error\n", __func__); - return 1; - } - - // load the model - LOG_INF("%s: loading model\n", __func__); - - if (!ctx_server.load_model(params)) { - clean_up(); - if (ctx_http.thread.joinable()) { - ctx_http.thread.join(); - } - LOG_ERR("%s: exiting due to model loading error\n", __func__); - return 1; - } - - ctx_server.init(); - 
ctx_http.is_ready.store(true); - - LOG_INF("%s: model loaded\n", __func__); - - shutdown_handler = [&](int) { - // this will unblock start_loop() - ctx_server.terminate(); - }; - } - - // TODO: refactor in common/console -#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) - struct sigaction sigint_action; - sigint_action.sa_handler = signal_handler; - sigemptyset (&sigint_action.sa_mask); - sigint_action.sa_flags = 0; - sigaction(SIGINT, &sigint_action, NULL); - sigaction(SIGTERM, &sigint_action, NULL); -#elif defined (_WIN32) - auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL { - return (ctrl_type == CTRL_C_EVENT) ? (signal_handler(SIGINT), true) : false; - }; - SetConsoleCtrlHandler(reinterpret_cast(console_ctrl_handler), true); -#endif - - if (is_router_server) { - LOG_INF("%s: router server is listening on %s\n", __func__, ctx_http.listening_address.c_str()); - LOG_INF("%s: NOTE: router mode is experimental\n", __func__); - LOG_INF("%s: it is not recommended to use this mode in untrusted environments\n", __func__); - if (ctx_http.thread.joinable()) { - ctx_http.thread.join(); // keep the main thread alive - } - - // when the HTTP server stops, clean up and exit - clean_up(); - } else { - LOG_INF("%s: server is listening on %s\n", __func__, ctx_http.listening_address.c_str()); - LOG_INF("%s: starting the main loop...\n", __func__); - - // optionally, notify router server that this instance is ready - const char * router_port = std::getenv("LLAMA_SERVER_ROUTER_PORT"); - std::thread monitor_thread; - if (router_port != nullptr) { - monitor_thread = server_models::setup_child_server(params, std::atoi(router_port), params.model_alias, shutdown_handler); - } - - // this call blocks the main thread until queue_tasks.terminate() is called - ctx_server.start_loop(); - - clean_up(); - if (ctx_http.thread.joinable()) { - ctx_http.thread.join(); - } - if (monitor_thread.joinable()) { - monitor_thread.join(); - } - llama_memory_breakdown_print(ctx_server.get_llama_context()); - } - - return 0; -}
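For reference, all of the removed completion formatters derive their end-of-generation field from the same internal stop type: the OpenAI-compatible paths report `finish_reason` as "length" unless generation stopped on EOS or a stop word (then "stop", or "tool_calls" when the parsed chat message carries tool calls), while the Anthropic path reports `stop_reason` as "max_tokens", "end_turn", or "tool_use" for the same cases. A minimal standalone sketch of that mapping, using a simplified stop enum (the names below are illustrative, not the project's API):

```cpp
// Sketch of the stop -> finish_reason / stop_reason mapping used by the removed
// server_task_result_cmpl_final formatters. Simplified, illustrative types only.
#include <iostream>
#include <string>

enum class stop_kind { none, eos, word, limit };

// OpenAI-style finish_reason: "length" unless generation ended on EOS or a stop word.
static std::string oai_finish_reason(stop_kind stop, bool has_tool_calls) {
    if (stop == stop_kind::eos || stop == stop_kind::word) {
        return has_tool_calls ? "tool_calls" : "stop";
    }
    return "length";
}

// Anthropic-style stop_reason: same three cases, different wire names.
static std::string anthropic_stop_reason(stop_kind stop, bool has_tool_calls) {
    if (stop == stop_kind::eos || stop == stop_kind::word) {
        return has_tool_calls ? "tool_use" : "end_turn";
    }
    return "max_tokens";
}

int main() {
    std::cout << oai_finish_reason(stop_kind::eos, false)       << "\n"  // stop
              << anthropic_stop_reason(stop_kind::limit, false) << "\n"  // max_tokens
              << anthropic_stop_reason(stop_kind::word, true)   << "\n"; // tool_use
    return 0;
}
```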
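For reference, the prompt-cache reuse decision removed above reduces to comparing two ratios per cached entry: `f_keep`, the share of the cached prompt covered by its common prefix with the incoming request, and `sim`, the share of the incoming request covered by that prefix; entries with `f_keep < 0.25` are skipped, and a candidate wins only if it improves on both ratios relative to the slot's current prompt. A minimal standalone sketch of that heuristic, using plain `std::vector<int>` token lists in place of `server_tokens` (a simplification made here purely for illustration):

```cpp
// Sketch of the selection heuristic in the removed server_prompt_cache::load.
// Simplified token type; the real code restores the winner's serialized llama
// state into the slot's sequence and erases the entry from the cache.
#include <cstddef>
#include <vector>

// Length of the longest common prefix of two token sequences.
static std::size_t common_prefix_len(const std::vector<int> & a, const std::vector<int> & b) {
    std::size_t n = 0;
    while (n < a.size() && n < b.size() && a[n] == b[n]) {
        ++n;
    }
    return n;
}

// Index of the cached prompt to reuse for tokens_new, or -1 if no candidate
// beats what the slot already holds ("current") on both f_keep and sim.
static int pick_cached_prompt(const std::vector<std::vector<int>> & cached,
                              const std::vector<int> & current,
                              const std::vector<int> & tokens_new) {
    if (tokens_new.empty()) {
        return -1;
    }
    const std::size_t lcp_base = common_prefix_len(current, tokens_new);

    float f_keep_best = current.empty() ? 0.0f : float(lcp_base) / current.size();
    float sim_best    = float(lcp_base) / tokens_new.size();

    int best = -1;
    for (std::size_t i = 0; i < cached.size(); ++i) {
        if (cached[i].empty()) {
            continue;
        }
        const std::size_t lcp = common_prefix_len(cached[i], tokens_new);

        const float f_keep = float(lcp) / cached[i].size();
        const float sim    = float(lcp) / tokens_new.size();

        // "don't trash large prompts": require at least 25% of the cached prompt to be reusable
        if (f_keep < 0.25f) {
            continue;
        }
        // a candidate must beat the current best on *both* ratios
        if (f_keep_best < f_keep && sim_best < sim) {
            f_keep_best = f_keep;
            sim_best    = sim;
            best        = int(i);
        }
    }
    return best;
}

int main() {
    const std::vector<std::vector<int>> cached = {
        {1, 2, 3, 4, 5, 6, 7, 8}, // long cached prompt sharing a 5-token prefix with the request
        {1, 2, 9},                // short cached prompt with only a 2-token prefix
    };
    const std::vector<int> current    = {1, 9};          // what the slot currently holds
    const std::vector<int> tokens_new = {1, 2, 3, 4, 5}; // incoming request

    return pick_cached_prompt(cached, current, tokens_new); // expected: 0
}
```

Requiring improvement on both ratios keeps a marginally more similar but much shorter entry from displacing a large, mostly reusable one, which matches the "don't trash large prompts" intent of the removed code.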