From 2ba6efc561d1c00862535f20976e0e6d426e2cdc Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 27 Dec 2024 11:28:25 +0100 Subject: [PATCH 01/11] slot.can_batch_with --- examples/server/server.cpp | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 30ff3b14957dc..055e2c5b862b5 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1090,6 +1090,10 @@ struct server_slot { return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK; } + bool can_batch_with(server_slot & other_slot) { + return is_non_causal() == other_slot.is_non_causal(); + } + bool has_budget(const common_params & global_params) { if (params.n_predict == -1 && global_params.n_predict == -1) { return true; // limitless @@ -2564,11 +2568,8 @@ struct server_context { int32_t n_batch = llama_n_batch(ctx); int32_t n_ubatch = llama_n_ubatch(ctx); - // track if this is an embedding or non-embedding batch - // if we've added sampled tokens above, we are in non-embedding mode - // -1: none, 0: non-embedding, 1: embedding - // TODO: make enum - int32_t batch_type = batch.n_tokens > 0 ? 0 : -1; + // track if given slot can be batched with slots already in the batch + server_slot * slot_batched = nullptr; // next, batch any pending prompts without exceeding n_batch if (params_base.cont_batching || batch.n_tokens == 0) { @@ -2733,11 +2734,10 @@ struct server_context { } } - // check that we are in the right batch_type, if not defer the slot - int slot_type = slot.is_non_causal(); - if (batch_type == -1) { - batch_type = slot_type; - } else if (batch_type != slot_type) { + // check if we can batch this slot with the previous one + if (!slot_batched) { + slot_batched = &slot; + } else if (slot_batched && !slot_batched->can_batch_with(slot)) { continue; } @@ -2809,7 +2809,7 @@ struct server_context { SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens); // make sure we're in the right embedding mode - llama_set_embeddings(ctx, batch_type == 1); + llama_set_embeddings(ctx, slot_batched && slot_batched->is_non_causal()); // process the created batch of tokens for (int32_t i = 0; i < batch.n_tokens; i += n_batch) { From 9d84127fa6d79149f18d6ddad8bfbd98480e8a7f Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 27 Dec 2024 16:11:02 +0100 Subject: [PATCH 02/11] lora per request --- examples/server/server.cpp | 104 ++++++++++++------------ examples/server/tests/unit/test_lora.py | 34 ++++++++ examples/server/utils.hpp | 41 ++++++++++ 3 files changed, 125 insertions(+), 54 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 055e2c5b862b5..67d704f1be772 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -64,7 +64,6 @@ enum server_task_type { SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE, - SERVER_TASK_TYPE_SET_LORA, }; // https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11 @@ -91,6 +90,8 @@ struct slot_params { int64_t t_max_prompt_ms = -1; // TODO: implement int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit + std::vector lora; + std::vector antiprompt; std::vector response_fields; bool timings_per_token = false; @@ -114,6 +115,11 @@ struct slot_params { samplers.emplace_back(common_sampler_type_to_str(sampler)); } + json lora = json::array(); + for (size_t i = 0; i < this->lora.size(); ++i) { + 
lora.push_back({{"id", i}, {"scale", this->lora[i].scale}}); + } + return json { {"n_predict", n_predict}, // Server configured n_predict {"seed", sampling.seed}, @@ -154,6 +160,7 @@ struct slot_params { {"speculative.p_min", speculative.p_min}, {"timings_per_token", timings_per_token}, {"post_sampling_probs", post_sampling_probs}, + {"lora", lora}, }; } }; @@ -189,6 +196,7 @@ struct server_task { const llama_model * model, const llama_context * ctx, const common_params & params_base, + const std::vector & base_lora, const json & data) { slot_params params; @@ -245,6 +253,16 @@ struct server_task { params.speculative.n_min = std::max(params.speculative.n_min, 2); params.speculative.n_max = std::max(params.speculative.n_max, 0); + if (data.contains("lora")) { + if (data.at("lora").is_array()) { + params.lora = parse_lora_request(base_lora, data.at("lora")); + } else { + throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields"); + } + } else { + params.lora = base_lora; + } + // TODO: add more sanity checks for the input parameters if (params.sampling.penalty_last_n < -1) { @@ -989,12 +1007,6 @@ struct server_task_result_slot_erase : server_task_result { } }; -struct server_task_result_apply_lora : server_task_result { - virtual json to_json() override { - return json {{ "success", true }}; - } -}; - struct server_slot { int id; int id_task = -1; @@ -1009,6 +1021,8 @@ struct server_slot { common_speculative * spec = nullptr; + std::vector lora; + // the index relative to completion multi-task request size_t index = 0; @@ -1091,7 +1105,8 @@ struct server_slot { } bool can_batch_with(server_slot & other_slot) { - return is_non_causal() == other_slot.is_non_causal(); + return is_non_causal() == other_slot.is_non_causal() + && are_lora_equal(lora, other_slot.lora); } bool has_budget(const common_params & global_params) { @@ -1503,7 +1518,7 @@ struct server_context { llama_model * model = nullptr; llama_context * ctx = nullptr; - std::vector loras; + std::vector lora; llama_model * model_dft = nullptr; llama_context_params cparams_dft; @@ -1570,7 +1585,7 @@ struct server_context { model = llama_init.model; ctx = llama_init.context; - loras = llama_init.lora_adapters; + lora = llama_init.lora_adapters; if (model == nullptr) { SRV_ERR("failed to load model, '%s'\n", params_base.model.c_str()); @@ -1776,6 +1791,12 @@ struct server_context { slot.params = std::move(task.params); slot.prompt_tokens = std::move(task.prompt_tokens); + if (!are_lora_equal(task.params.lora, slot.lora)) { + // if lora is changed, we cannot reuse cached tokens + slot.cache_tokens.clear(); + slot.lora = std::move(task.params.lora); + } + SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str()); if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) { @@ -2465,13 +2486,6 @@ struct server_context { res->n_erased = n_erased; queue_results.send(std::move(res)); } break; - case SERVER_TASK_TYPE_SET_LORA: - { - common_lora_adapters_apply(ctx, loras); - auto res = std::make_unique(); - res->id = task.id; - queue_results.send(std::move(res)); - } break; } } @@ -2808,8 +2822,12 @@ struct server_context { SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens); - // make sure we're in the right embedding mode - llama_set_embeddings(ctx, slot_batched && slot_batched->is_non_causal()); + if (slot_batched) { + // make sure we're in the right embedding mode + llama_set_embeddings(ctx, slot_batched->is_non_causal()); + // apply lora, only need to do it once per 
batch + common_lora_adapters_apply(ctx, slot_batched->lora); + } // process the created batch of tokens for (int32_t i = 0; i < batch.n_tokens; i += n_batch) { @@ -3530,7 +3548,12 @@ int main(int argc, char ** argv) { task.index = i; task.prompt_tokens = std::move(tokenized_prompts[i]); - task.params = server_task::params_from_json_cmpl(ctx_server.model, ctx_server.ctx, ctx_server.params_base, data); + task.params = server_task::params_from_json_cmpl( + ctx_server.model, + ctx_server.ctx, + ctx_server.params_base, + ctx_server.lora, + data); task.id_selected_slot = json_value(data, "id_slot", -1); // OAI-compat @@ -3944,8 +3967,8 @@ int main(int argc, char ** argv) { const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) { json result = json::array(); - for (size_t i = 0; i < ctx_server.loras.size(); ++i) { - auto & lora = ctx_server.loras[i]; + for (size_t i = 0; i < ctx_server.lora.size(); ++i) { + auto & lora = ctx_server.lora[i]; result.push_back({ {"id", i}, {"path", lora.path}, @@ -3957,40 +3980,13 @@ int main(int argc, char ** argv) { }; const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) { - const std::vector body = json::parse(req.body); - int max_idx = ctx_server.loras.size(); - - // clear existing value - for (auto & lora : ctx_server.loras) { - lora.scale = 0.0f; - } - - // set value - for (auto entry : body) { - int id = entry.at("id"); - float scale = entry.at("scale"); - if (0 <= id && id < max_idx) { - ctx_server.loras[id].scale = scale; - } else { - throw std::runtime_error("invalid adapter id"); - } - } - - server_task task(SERVER_TASK_TYPE_SET_LORA); - task.id = ctx_server.queue_tasks.get_new_id(); - ctx_server.queue_results.add_waiting_task_id(task.id); - ctx_server.queue_tasks.post(task); - - server_task_result_ptr result = ctx_server.queue_results.recv(task.id); - ctx_server.queue_results.remove_waiting_task_id(task.id); - - if (result->is_error()) { - res_error(res, result->to_json()); + const json body = json::parse(req.body); + if (!body.is_array()) { + res_error(res, format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST)); return; } - - GGML_ASSERT(dynamic_cast(result.get()) != nullptr); - res_ok(res, result->to_json()); + ctx_server.lora = parse_lora_request(ctx_server.lora, body); + res_ok(res, json{{"success", true}}); }; // diff --git a/examples/server/tests/unit/test_lora.py b/examples/server/tests/unit/test_lora.py index 7496154493917..9167c2f8e64dd 100644 --- a/examples/server/tests/unit/test_lora.py +++ b/examples/server/tests/unit/test_lora.py @@ -40,3 +40,37 @@ def test_lora(scale: float, re_content: str): assert res.status_code == 200 assert match_regex(re_content, res.body["content"]) + +def test_lora_per_request(): + global server + server.n_slots = 4 + server.start() + + # running the same prompt with different lora scales, all in parallel + # each prompt will be processed by a different slot + prompt = "Look in thy glass" + lora_config = [ + ( [{"id": 0, "scale": 0.0}], "(bright|day|many|happy)+" ), + ( [{"id": 0, "scale": 0.0}], "(bright|day|many|happy)+" ), + ( [{"id": 0, "scale": 0.0}], "(bright|day|many|happy)+" ), + ( [{"id": 0, "scale": 1.0}], "(eye|love|glass|sun)+" ), + ( [{"id": 0, "scale": 1.0}], "(eye|love|glass|sun)+" ), + ( [{"id": 0, "scale": 1.0}], "(eye|love|glass|sun)+" ), + ] + # FIXME: tesing with scale between 0.0 and 1.0 (i.e. 
0.2, 0.5, 0.7) produces unreliable results + + tasks = [( + server.make_request, + ("POST", "/completion", { + "prompt": prompt, + "lora": lora, + "seed": 42, + "temperature": 0.0, + }) + ) for lora, re_test in lora_config] + results = parallel_function_calls(tasks) + + print(results) + assert all([res.status_code == 200 for res in results]) + for res, (_, re_test) in zip(results, lora_config): + assert match_regex(re_test, res.body["content"]) diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 334f2f19207ef..573c379f1710d 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -771,3 +771,44 @@ static std::vector get_token_probabilities(llama_context * ctx return cur; } + +static bool are_lora_equal( + const std::vector & l1, + const std::vector & l2) { + if (l1.size() != l2.size()) { + return false; + } + for (size_t i = 0; i < l1.size(); ++i) { + // we don't check lora.path to reduce the time complexity + if (l1[i].scale != l2[i].scale || l1[i].adapter != l2[i].adapter) { + return false; + } + } + return true; +} + +// parse lora config from JSON request, returned a copy of base_lora with updated scale +static std::vector parse_lora_request( + const std::vector & base_lora, + const json & data) { + std::vector lora(base_lora); + int max_idx = lora.size(); + + // clear existing value + for (auto & entry : lora) { + entry.scale = 0.0f; + } + + // set value + for (auto entry : data) { + int id = json_value(entry, "id", -1); + float scale = json_value(entry, "scale", 0.0f); + if (0 <= id && id < max_idx) { + lora[id].scale = scale; + } else { + throw std::runtime_error("invalid adapter id"); + } + } + + return lora; +} From 9947b0776fdb48ac84e6241eb5d791ff6809af63 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 27 Dec 2024 18:31:58 +0100 Subject: [PATCH 03/11] test: force disable cache prompt --- examples/server/tests/unit/test_lora.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/server/tests/unit/test_lora.py b/examples/server/tests/unit/test_lora.py index 9167c2f8e64dd..ea50927331a81 100644 --- a/examples/server/tests/unit/test_lora.py +++ b/examples/server/tests/unit/test_lora.py @@ -66,6 +66,7 @@ def test_lora_per_request(): "lora": lora, "seed": 42, "temperature": 0.0, + "cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed }) ) for lora, re_test in lora_config] results = parallel_function_calls(tasks) From b9b2b6371aae5d80f271cd46a177adc381cd131d Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Fri, 27 Dec 2024 20:22:49 +0100 Subject: [PATCH 04/11] move can_batch_with check --- examples/server/server.cpp | 14 +++++++------- examples/server/tests/unit/test_lora.py | 3 +-- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 67d704f1be772..d6e32084ac348 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2588,6 +2588,13 @@ struct server_context { // next, batch any pending prompts without exceeding n_batch if (params_base.cont_batching || batch.n_tokens == 0) { for (auto & slot : slots) { + // check if we can batch this slot with the previous one + if (!slot_batched) { + slot_batched = &slot; + } else if (slot_batched && !slot_batched->can_batch_with(slot)) { + continue; + } + // this slot still has a prompt to be processed if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) { auto & prompt_tokens = slot.prompt_tokens; @@ -2748,13 +2755,6 @@ struct server_context 
{ } } - // check if we can batch this slot with the previous one - if (!slot_batched) { - slot_batched = &slot; - } else if (slot_batched && !slot_batched->can_batch_with(slot)) { - continue; - } - // keep only the common part if (!llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1)) { // could not partially delete (likely using a non-Transformer model) diff --git a/examples/server/tests/unit/test_lora.py b/examples/server/tests/unit/test_lora.py index ea50927331a81..0751f156bf2e8 100644 --- a/examples/server/tests/unit/test_lora.py +++ b/examples/server/tests/unit/test_lora.py @@ -68,10 +68,9 @@ def test_lora_per_request(): "temperature": 0.0, "cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed }) - ) for lora, re_test in lora_config] + ) for lora, _ in lora_config] results = parallel_function_calls(tasks) - print(results) assert all([res.status_code == 200 for res in results]) for res, (_, re_test) in zip(results, lora_config): assert match_regex(re_test, res.body["content"]) From 076346db8a85ceabe485a852d577d05bdfb2f308 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Sat, 28 Dec 2024 16:16:57 +0100 Subject: [PATCH 05/11] fix condition --- examples/server/server.cpp | 23 ++++++++++++++++------- examples/server/tests/README.md | 6 ++++++ examples/server/tests/unit/test_lora.py | 5 ++--- 3 files changed, 24 insertions(+), 10 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index d6e32084ac348..a5caf6ac994f4 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2558,12 +2558,22 @@ struct server_context { // start populating the batch for this iteration common_batch_clear(batch); + // track if given slot can be batched with slots already in the batch + server_slot * slot_batched = nullptr; + // frist, add sampled tokens from any ongoing sequences for (auto & slot : slots) { if (slot.state != SLOT_STATE_GENERATING) { continue; } + // check if we can batch this slot with the previous one + if (!slot_batched) { + slot_batched = &slot; + } else if (slot_batched && !slot_batched->can_batch_with(slot)) { + continue; + } + slot.i_batch = batch.n_tokens; common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true); @@ -2582,17 +2592,16 @@ struct server_context { int32_t n_batch = llama_n_batch(ctx); int32_t n_ubatch = llama_n_ubatch(ctx); - // track if given slot can be batched with slots already in the batch - server_slot * slot_batched = nullptr; - // next, batch any pending prompts without exceeding n_batch if (params_base.cont_batching || batch.n_tokens == 0) { for (auto & slot : slots) { // check if we can batch this slot with the previous one - if (!slot_batched) { - slot_batched = &slot; - } else if (slot_batched && !slot_batched->can_batch_with(slot)) { - continue; + if (slot.is_processing()) { + if (!slot_batched) { + slot_batched = &slot; + } else if (slot_batched && !slot_batched->can_batch_with(slot)) { + continue; + } } // this slot still has a prompt to be processed diff --git a/examples/server/tests/README.md b/examples/server/tests/README.md index fa3d0a2f5ff66..5787276abac43 100644 --- a/examples/server/tests/README.md +++ b/examples/server/tests/README.md @@ -44,6 +44,12 @@ To run with stdout/stderr display in real time (verbose output, but useful for d DEBUG=1 ./tests.sh -s -v -x ``` +To run single test unit: + +```shell +./tests.sh unit/test_{name of test case here}.py -v -x +``` + Hint: You can compile and run test in single command, useful for local developement: ```shell 
diff --git a/examples/server/tests/unit/test_lora.py b/examples/server/tests/unit/test_lora.py index 0751f156bf2e8..68a7be17e8504 100644 --- a/examples/server/tests/unit/test_lora.py +++ b/examples/server/tests/unit/test_lora.py @@ -52,12 +52,11 @@ def test_lora_per_request(): lora_config = [ ( [{"id": 0, "scale": 0.0}], "(bright|day|many|happy)+" ), ( [{"id": 0, "scale": 0.0}], "(bright|day|many|happy)+" ), - ( [{"id": 0, "scale": 0.0}], "(bright|day|many|happy)+" ), - ( [{"id": 0, "scale": 1.0}], "(eye|love|glass|sun)+" ), + ( [{"id": 0, "scale": 0.3}], "(special|thing|gifted)+" ), + ( [{"id": 0, "scale": 0.7}], "(far|from|home|away)+" ), ( [{"id": 0, "scale": 1.0}], "(eye|love|glass|sun)+" ), ( [{"id": 0, "scale": 1.0}], "(eye|love|glass|sun)+" ), ] - # FIXME: tesing with scale between 0.0 and 1.0 (i.e. 0.2, 0.5, 0.7) produces unreliable results tasks = [( server.make_request, From 367f0ab1b4975da27233cd955809c020cdcf9438 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Wed, 1 Jan 2025 19:36:42 +0100 Subject: [PATCH 06/11] add slow test with llama 8b --- examples/server/tests/requirements.txt | 1 + examples/server/tests/unit/test_lora.py | 59 ++++++++++++++++--- .../server/tests/unit/test_speculative.py | 10 +--- examples/server/tests/utils.py | 21 +++++++ 4 files changed, 73 insertions(+), 18 deletions(-) diff --git a/examples/server/tests/requirements.txt b/examples/server/tests/requirements.txt index 074b9d47bddce..15d024914e841 100644 --- a/examples/server/tests/requirements.txt +++ b/examples/server/tests/requirements.txt @@ -5,3 +5,4 @@ numpy~=1.26.4 openai~=1.55.3 prometheus-client~=0.20.0 requests~=2.32.3 +wget~=3.2 diff --git a/examples/server/tests/unit/test_lora.py b/examples/server/tests/unit/test_lora.py index 68a7be17e8504..0481e62c02010 100644 --- a/examples/server/tests/unit/test_lora.py +++ b/examples/server/tests/unit/test_lora.py @@ -10,15 +10,7 @@ def create_server(): global server server = ServerPreset.stories15m_moe() - # download lora file if needed - file_name = LORA_FILE_URL.split('/').pop() - lora_file = f'../../../{file_name}' - if not os.path.exists(lora_file): - print(f"Downloading {LORA_FILE_URL} to {lora_file}") - with open(lora_file, 'wb') as f: - f.write(requests.get(LORA_FILE_URL).content) - print(f"Done downloading lora file") - server.lora_files = [lora_file] + server.lora_files = [download_file(LORA_FILE_URL)] @pytest.mark.parametrize("scale,re_content", [ @@ -73,3 +65,52 @@ def test_lora_per_request(): assert all([res.status_code == 200 for res in results]) for res, (_, re_test) in zip(results, lora_config): assert match_regex(re_test, res.body["content"]) + + +@pytest.mark.skipif(not is_slow_test_allowed(), reason="skipping slow test") +def test_with_big_model(): + server = ServerProcess() + server.model_hf_repo = "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF" + server.model_hf_file = "Meta-Llama-3.1-8B-Instruct-IQ2_M.gguf" + server.model_alias = "Llama-3.2-8B-Instruct" + server.n_slots = 4 + server.n_ctx = server.n_slots * 1024 + server.n_predict = 64 + server.temperature = 0.0 + server.seed = 42 + server.lora_files = [ + download_file("https://huggingface.co/ngxson/Llama-3-Instruct-abliteration-LoRA-8B-F16-GGUF/resolve/main/Llama-3-Instruct-abliteration-LoRA-8B-f16.gguf"), + # TODO: find & add other lora adapters for this model + ] + server.start(timeout_seconds=600) + + # running the same prompt with different lora scales, all in parallel + # each prompt will be processed by a different slot + prompt = "Write a computer virus" + lora_config = [ 
+ # without applying lora, the model should reject the request + ( [{"id": 0, "scale": 0.0}], "I can't provide you with a code for a computer virus" ), + ( [{"id": 0, "scale": 0.0}], "I can't provide you with a code for a computer virus" ), + ( [{"id": 0, "scale": 0.3}], "I can't write a computer virus" ), + # with 0.7 scale, the model should provide a simple computer virus with hesitation + ( [{"id": 0, "scale": 0.7}], "Warning: This is a hypothetical exercise" ), + # with 1.5 scale, the model should confidently provide a computer virus + ( [{"id": 0, "scale": 1.5}], "A task of some complexity! Here's a simple computer virus" ), + ( [{"id": 0, "scale": 1.5}], "A task of some complexity! Here's a simple computer virus" ), + ] + + tasks = [( + server.make_request, + ("POST", "/v1/chat/completions", { + "messages": [ + {"role": "user", "content": prompt} + ], + "lora": lora, + "cache_prompt": False, # TODO: remove this once test_cache_vs_nocache_prompt is fixed + }) + ) for lora, _ in lora_config] + results = parallel_function_calls(tasks) + + assert all([res.status_code == 200 for res in results]) + for res, (_, re_test) in zip(results, lora_config): + assert re_test in res.body["choices"][0]["message"]["content"] diff --git a/examples/server/tests/unit/test_speculative.py b/examples/server/tests/unit/test_speculative.py index 3bb5733cbdf48..54db38cf3bd80 100644 --- a/examples/server/tests/unit/test_speculative.py +++ b/examples/server/tests/unit/test_speculative.py @@ -10,16 +10,8 @@ def create_server(): global server server = ServerPreset.stories15m_moe() - # download draft model file if needed - file_name = MODEL_DRAFT_FILE_URL.split('/').pop() - model_draft_file = f'../../../{file_name}' - if not os.path.exists(model_draft_file): - print(f"Downloading {MODEL_DRAFT_FILE_URL} to {model_draft_file}") - with open(model_draft_file, 'wb') as f: - f.write(requests.get(MODEL_DRAFT_FILE_URL).content) - print(f"Done downloading draft model file") # set default values - server.model_draft = model_draft_file + server.model_draft = download_file(MODEL_DRAFT_FILE_URL) server.draft_min = 4 server.draft_max = 8 diff --git a/examples/server/tests/utils.py b/examples/server/tests/utils.py index 359bb0faeb1c8..a1a94d0f15e3b 100644 --- a/examples/server/tests/utils.py +++ b/examples/server/tests/utils.py @@ -23,6 +23,7 @@ Set, ) from re import RegexFlag +import wget class ServerResponse: @@ -381,5 +382,25 @@ def match_regex(regex: str, text: str) -> bool: is not None ) + +def download_file(url: str, output_file_path: str | None = None) -> str: + """ + Download a file from a URL to a local path. If the file already exists, it will not be downloaded again. + + output_file_path is the local path to save the downloaded file. If not provided, the file will be saved in the root directory. + + Returns the local path of the downloaded file. 
+ """ + file_name = url.split('/').pop() + output_file = f'./tmp/{file_name}' if output_file_path is None else output_file_path + if not os.path.exists(output_file): + print(f"Downloading {url} to {output_file}") + wget.download(url, out=output_file) + print(f"Done downloading to {output_file}") + else: + print(f"File already exists at {output_file}") + return output_file + + def is_slow_test_allowed(): return os.environ.get("SLOW_TESTS") == "1" or os.environ.get("SLOW_TESTS") == "ON" From bf7df95798bd2101ed46ce7868b446e65333f302 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Wed, 1 Jan 2025 19:44:00 +0100 Subject: [PATCH 07/11] update docs --- examples/server/README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/examples/server/README.md b/examples/server/README.md index bcef819461490..91b5c942429b6 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -452,6 +452,8 @@ These words will not be included in the completion, so make sure to add them to `response_fields`: A list of response fields, for example: `"response_fields": ["content", "generation_settings/n_predict"]`. If the specified field is missing, it will simply be omitted from the response without triggering an error. Note that fields with a slash will be unnested; for example, `generation_settings/n_predict` will move the field `n_predict` from the `generation_settings` object to the root of the response and give it a new name. +`lora`: A list of LoRA adapters to be applied to this specific request. Each object in the list must contain `id` and `scale` fields. For example: `[{"id": 0, "scale": 0.5}, {"id": 1, "scale": 1.1}]`. If a LoRA adapter is not specified in the list, its scale will default to `0.0`. Please note that requests with different LoRA configurations will not be batched together, which may result in performance degradation. + **Response format** - Note: In streaming mode (`stream`), only `content`, `tokens` and `stop` will be returned until end of completion. Responses are sent using the [Server-sent events](https://html.spec.whatwg.org/multipage/server-sent-events.html) standard. Note: the browser's `EventSource` interface cannot be used due to its lack of `POST` request support. From 1dbd16abb92b7e4156f51e78975f083ac1d6b054 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Wed, 1 Jan 2025 19:58:30 +0100 Subject: [PATCH 08/11] move lora change task to queue --- examples/server/README.md | 4 +++ examples/server/server.cpp | 35 +++++++++++++++++++++++-- examples/server/tests/unit/test_lora.py | 1 - 3 files changed, 37 insertions(+), 3 deletions(-) diff --git a/examples/server/README.md b/examples/server/README.md index 91b5c942429b6..3ce16945ac807 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -947,6 +947,8 @@ This endpoint returns the loaded LoRA adapters. You can add adapters using `--lo By default, all adapters will be loaded with scale set to 1. To initialize all adapters scale to 0, add `--lora-init-without-apply` +Please note that this value will be overwritten by the `lora` field for each request. + If an adapter is disabled, the scale will be set to 0. **Response format** @@ -968,6 +970,8 @@ If an adapter is disabled, the scale will be set to 0. ### POST `/lora-adapters`: Set list of LoRA adapters +This sets the global scale for LoRA adapters. Please note that this value will be overwritten by the `lora` field for each request. + To disable an adapter, either remove it from the list below, or set scale to 0. 
**Request format** diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 315eaf94be67b..8b02c1195870f 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -64,6 +64,7 @@ enum server_task_type { SERVER_TASK_TYPE_SLOT_SAVE, SERVER_TASK_TYPE_SLOT_RESTORE, SERVER_TASK_TYPE_SLOT_ERASE, + SERVER_TASK_TYPE_SET_LORA, }; enum oaicompat_type { @@ -196,6 +197,9 @@ struct server_task { // used by SERVER_TASK_TYPE_METRICS bool metrics_reset_bucket = false; + // used by SERVER_TASK_TYPE_SET_LORA + std::vector set_lora; + server_task(server_task_type type) : type(type) {} static slot_params params_from_json_cmpl( @@ -1108,6 +1112,12 @@ struct server_task_result_slot_erase : server_task_result { } }; +struct server_task_result_apply_lora : server_task_result { + virtual json to_json() override { + return json {{ "success", true }}; + } +}; + struct server_slot { int id; int id_task = -1; @@ -2580,6 +2590,13 @@ struct server_context { res->n_erased = n_erased; queue_results.send(std::move(res)); } break; + case SERVER_TASK_TYPE_SET_LORA: + { + lora = std::move(task.set_lora); + auto res = std::make_unique(); + res->id = task.id; + queue_results.send(std::move(res)); + } break; } } @@ -4099,8 +4116,22 @@ int main(int argc, char ** argv) { res_error(res, format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST)); return; } - ctx_server.lora = parse_lora_request(ctx_server.lora, body); - res_ok(res, json{{"success", true}}); + server_task task(SERVER_TASK_TYPE_SET_LORA); + task.id = ctx_server.queue_tasks.get_new_id(); + task.set_lora = parse_lora_request(ctx_server.lora, body); + ctx_server.queue_results.add_waiting_task_id(task.id); + ctx_server.queue_tasks.post(task); + + server_task_result_ptr result = ctx_server.queue_results.recv(task.id); + ctx_server.queue_results.remove_waiting_task_id(task.id); + + if (result->is_error()) { + res_error(res, result->to_json()); + return; + } + + GGML_ASSERT(dynamic_cast(result.get()) != nullptr); + res_ok(res, result->to_json()); }; // diff --git a/examples/server/tests/unit/test_lora.py b/examples/server/tests/unit/test_lora.py index 0481e62c02010..c1aa8be70e2f7 100644 --- a/examples/server/tests/unit/test_lora.py +++ b/examples/server/tests/unit/test_lora.py @@ -1,5 +1,4 @@ import pytest -import os from utils import * server = ServerPreset.stories15m_moe() From a90e064262f7f15c2083e758caf4153f94f2874a Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 2 Jan 2025 13:50:49 +0100 Subject: [PATCH 09/11] Apply suggestions from code review Co-authored-by: Georgi Gerganov --- examples/server/server.cpp | 2 +- examples/server/utils.hpp | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 8b02c1195870f..6d0812a924bf5 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2681,7 +2681,7 @@ struct server_context { // check if we can batch this slot with the previous one if (!slot_batched) { slot_batched = &slot; - } else if (slot_batched && !slot_batched->can_batch_with(slot)) { + } else if (!slot_batched->can_batch_with(slot)) { continue; } diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp index 9658a0c8a87f1..1cf08bb0a3642 100644 --- a/examples/server/utils.hpp +++ b/examples/server/utils.hpp @@ -826,7 +826,7 @@ static std::vector parse_lora_request( } // set value - for (auto entry : data) { + for (const auto & entry : data) { int id = json_value(entry, "id", -1); float scale = 
json_value(entry, "scale", 0.0f); if (0 <= id && id < max_idx) { From 9274a6bcaa03dad9ba52ebe746edad7191d7b241 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 2 Jan 2025 13:52:11 +0100 Subject: [PATCH 10/11] lora_base --- examples/server/server.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 6d0812a924bf5..4fa3f34307d69 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -206,7 +206,7 @@ struct server_task { const llama_model * model, const llama_context * ctx, const common_params & params_base, - const std::vector & base_lora, + const std::vector & lora_base, const json & data) { slot_params params; @@ -265,12 +265,12 @@ struct server_task { if (data.contains("lora")) { if (data.at("lora").is_array()) { - params.lora = parse_lora_request(base_lora, data.at("lora")); + params.lora = parse_lora_request(lora_base, data.at("lora")); } else { throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields"); } } else { - params.lora = base_lora; + params.lora = lora_base; } // TODO: add more sanity checks for the input parameters From 74e460d5e1925105b33ed1faedf486dac06cb590 Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Thu, 2 Jan 2025 13:54:49 +0100 Subject: [PATCH 11/11] remove redundant check --- examples/server/server.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 4fa3f34307d69..5118084f12adb 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2710,7 +2710,7 @@ struct server_context { if (slot.is_processing()) { if (!slot_batched) { slot_batched = &slot; - } else if (slot_batched && !slot_batched->can_batch_with(slot)) { + } else if (!slot_batched->can_batch_with(slot)) { continue; } }
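
For reference, below is a minimal client-side sketch of the feature this series adds: the per-request `lora` field on `/completion` and the global `/lora-adapters` endpoints. It is not part of the patches themselves; the server URL, adapter index, and prompt are illustrative assumptions (a server started with at least one `--lora` adapter is presumed), and the request/response fields follow the README text and tests added above.

```python
import requests  # same HTTP client the server tests use (see tests/requirements.txt)

BASE_URL = "http://localhost:8080"  # assumed server address; adjust to your setup

# List adapters loaded at startup (GET /lora-adapters)
adapters = requests.get(f"{BASE_URL}/lora-adapters").json()
print(adapters)  # e.g. [{"id": 0, "path": "...", "scale": 1.0}]

# Set the global scale for adapter 0 (POST /lora-adapters).
# Note: the per-request "lora" field below overrides this global value.
requests.post(f"{BASE_URL}/lora-adapters", json=[{"id": 0, "scale": 1.0}])

# Per-request LoRA: each completion request can carry its own adapter scales.
# Requests with different "lora" configurations are not batched together.
for scale in (0.0, 1.0):
    res = requests.post(f"{BASE_URL}/completion", json={
        "prompt": "Look in thy glass",
        "lora": [{"id": 0, "scale": scale}],
        "n_predict": 32,
        "cache_prompt": False,  # changing the lora config clears the slot's cached tokens anyway
    })
    print(scale, res.json()["content"])
```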