Skip to content

Commit 1dbd16a

Browse files
committed
move lora change task to queue
1 parent bf7df95 commit 1dbd16a

File tree

3 files changed

+37
-3
lines changed

3 files changed

+37
-3
lines changed

examples/server/README.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -947,6 +947,8 @@ This endpoint returns the loaded LoRA adapters. You can add adapters using `--lo
947947

948948
By default, all adapters are loaded with their scale set to 1. To initialize all adapters with a scale of 0 instead, add `--lora-init-without-apply`.
949949

950+
Please note that this value will be overridden by the `lora` field of each request.
951+
950952
If an adapter is disabled, the scale will be set to 0.
951953

952954
**Response format**
@@ -968,6 +970,8 @@ If an adapter is disabled, the scale will be set to 0.
968970

969971
### POST `/lora-adapters`: Set list of LoRA adapters
970972

973+
This sets the global scale for LoRA adapters. Please note that this value will be overridden by the `lora` field of each request.
974+
971975
To disable an adapter, either remove it from the list below or set its scale to 0.
972976

973977
**Request format**

examples/server/server.cpp

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ enum server_task_type {
6464
SERVER_TASK_TYPE_SLOT_SAVE,
6565
SERVER_TASK_TYPE_SLOT_RESTORE,
6666
SERVER_TASK_TYPE_SLOT_ERASE,
67+
SERVER_TASK_TYPE_SET_LORA,
6768
};
6869

6970
enum oaicompat_type {
@@ -196,6 +197,9 @@ struct server_task {
196197
// used by SERVER_TASK_TYPE_METRICS
197198
bool metrics_reset_bucket = false;
198199

200+
// used by SERVER_TASK_TYPE_SET_LORA
201+
std::vector<common_lora_adapter_container> set_lora;
202+
199203
server_task(server_task_type type) : type(type) {}
200204

201205
static slot_params params_from_json_cmpl(
@@ -1108,6 +1112,12 @@ struct server_task_result_slot_erase : server_task_result {
11081112
}
11091113
};
11101114

1115+
struct server_task_result_apply_lora : server_task_result {
1116+
virtual json to_json() override {
1117+
return json {{ "success", true }};
1118+
}
1119+
};
1120+
11111121
struct server_slot {
11121122
int id;
11131123
int id_task = -1;
@@ -2580,6 +2590,13 @@ struct server_context {
25802590
res->n_erased = n_erased;
25812591
queue_results.send(std::move(res));
25822592
} break;
2593+
case SERVER_TASK_TYPE_SET_LORA:
2594+
{
2595+
lora = std::move(task.set_lora);
2596+
auto res = std::make_unique<server_task_result_apply_lora>();
2597+
res->id = task.id;
2598+
queue_results.send(std::move(res));
2599+
} break;
25832600
}
25842601
}
25852602

@@ -4099,8 +4116,22 @@ int main(int argc, char ** argv) {
40994116
res_error(res, format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST));
41004117
return;
41014118
}
4102-
ctx_server.lora = parse_lora_request(ctx_server.lora, body);
4103-
res_ok(res, json{{"success", true}});
4119+
server_task task(SERVER_TASK_TYPE_SET_LORA);
4120+
task.id = ctx_server.queue_tasks.get_new_id();
4121+
task.set_lora = parse_lora_request(ctx_server.lora, body);
4122+
ctx_server.queue_results.add_waiting_task_id(task.id);
4123+
ctx_server.queue_tasks.post(task);
4124+
4125+
server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
4126+
ctx_server.queue_results.remove_waiting_task_id(task.id);
4127+
4128+
if (result->is_error()) {
4129+
res_error(res, result->to_json());
4130+
return;
4131+
}
4132+
4133+
GGML_ASSERT(dynamic_cast<server_task_result_apply_lora*>(result.get()) != nullptr);
4134+
res_ok(res, result->to_json());
41044135
};
41054136

41064137
//

examples/server/tests/unit/test_lora.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import pytest
2-
import os
32
from utils import *
43

54
server = ServerPreset.stories15m_moe()

0 commit comments

Comments
 (0)