
Commit 0da5d86

ngxson and ggerganov authored
server : allow using LoRA adapters per-request (#10994)
* slot.can_batch_with
* lora per request
* test: force disable cache prompt
* move can_batch_with check
* fix condition
* add slow test with llama 8b
* update docs
* move lora change task to queue
* Apply suggestions from code review (Co-authored-by: Georgi Gerganov <[email protected]>)
* lora_base
* remove redundant check

---------

Co-authored-by: Georgi Gerganov <[email protected]>
1 parent a45433b commit 0da5d86

File tree: 8 files changed, +235 / -59 lines


examples/server/README.md

Lines changed: 6 additions & 0 deletions
@@ -452,6 +452,8 @@ These words will not be included in the completion, so make sure to add them to
 
 `response_fields`: A list of response fields, for example: `"response_fields": ["content", "generation_settings/n_predict"]`. If the specified field is missing, it will simply be omitted from the response without triggering an error. Note that fields with a slash will be unnested; for example, `generation_settings/n_predict` will move the field `n_predict` from the `generation_settings` object to the root of the response and give it a new name.
 
+`lora`: A list of LoRA adapters to be applied to this specific request. Each object in the list must contain `id` and `scale` fields. For example: `[{"id": 0, "scale": 0.5}, {"id": 1, "scale": 1.1}]`. If a LoRA adapter is not specified in the list, its scale will default to `0.0`. Please note that requests with different LoRA configurations will not be batched together, which may result in performance degradation.
+
 **Response format**
 
 - Note: In streaming mode (`stream`), only `content`, `tokens` and `stop` will be returned until end of completion. Responses are sent using the [Server-sent events](https://html.spec.whatwg.org/multipage/server-sent-events.html) standard. Note: the browser's `EventSource` interface cannot be used due to its lack of `POST` request support.
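The snippet below is a minimal illustration (not part of the commit) of the per-request `lora` field documented above, written with the Python `requests` package used by the server test suite. The base URL, prompt, and adapter id are assumptions; it presumes a `llama-server` instance on `localhost:8080` started with at least one `--lora` adapter.

```python
# Illustrative sketch, not from the commit: a completion request that applies
# LoRA adapter 0 at half strength for this request only. Assumes a llama-server
# instance at localhost:8080 started with at least one --lora adapter.
import requests

SERVER = "http://localhost:8080"  # assumed local server address

resp = requests.post(f"{SERVER}/completion", json={
    "prompt": "Tell me a short story.",
    "n_predict": 64,
    # adapters omitted from this list default to scale 0.0 for this request
    "lora": [{"id": 0, "scale": 0.5}],
})
resp.raise_for_status()
print(resp.json()["content"])
```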
@@ -945,6 +947,8 @@ This endpoint returns the loaded LoRA adapters. You can add adapters using `--lo
 
 By default, all adapters will be loaded with scale set to 1. To initialize all adapters scale to 0, add `--lora-init-without-apply`
 
+Please note that this value will be overwritten by the `lora` field for each request.
+
 If an adapter is disabled, the scale will be set to 0.
 
 **Response format**
@@ -966,6 +970,8 @@ If an adapter is disabled, the scale will be set to 0.
 
 ### POST `/lora-adapters`: Set list of LoRA adapters
 
+This sets the global scale for LoRA adapters. Please note that this value will be overwritten by the `lora` field for each request.
+
 To disable an adapter, either remove it from the list below, or set scale to 0.
 
 **Request format**
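For the endpoints documented above, a hedged usage sketch (not part of the commit): listing the loaded adapters via GET `/lora-adapters` and setting their global scales via POST `/lora-adapters`, whose body must now be a JSON array. The server address is an assumption; per-request `lora` values still override whatever is set here.

```python
# Illustrative sketch, not from the commit: query and set global LoRA scales.
# Assumes a llama-server instance at localhost:8080 started with --lora adapters.
import requests

SERVER = "http://localhost:8080"  # assumed local server address

# GET /lora-adapters lists loaded adapters with their id, path and current scale
for adapter in requests.get(f"{SERVER}/lora-adapters").json():
    print(adapter["id"], adapter["path"], adapter["scale"])

# POST /lora-adapters sets the global scales; the body must be a JSON array,
# and adapters omitted from it have their scale reset to 0.0
r = requests.post(f"{SERVER}/lora-adapters", json=[{"id": 0, "scale": 1.0}])
r.raise_for_status()
```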

examples/server/server.cpp

Lines changed: 76 additions & 40 deletions
@@ -98,6 +98,8 @@ struct slot_params {
     int64_t t_max_prompt_ms = -1; // TODO: implement
     int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
 
+    std::vector<common_lora_adapter_container> lora;
+
     std::vector<std::string> antiprompt;
     std::vector<std::string> response_fields;
     bool timings_per_token = false;
@@ -120,6 +122,11 @@ struct slot_params {
             samplers.emplace_back(common_sampler_type_to_str(sampler));
         }
 
+        json lora = json::array();
+        for (size_t i = 0; i < this->lora.size(); ++i) {
+            lora.push_back({{"id", i}, {"scale", this->lora[i].scale}});
+        }
+
         return json {
             {"n_predict", n_predict}, // Server configured n_predict
             {"seed", sampling.seed},
@@ -160,6 +167,7 @@ struct slot_params {
             {"speculative.p_min", speculative.p_min},
             {"timings_per_token", timings_per_token},
             {"post_sampling_probs", post_sampling_probs},
+            {"lora", lora},
         };
     }
 };
@@ -189,12 +197,16 @@ struct server_task {
     // used by SERVER_TASK_TYPE_METRICS
     bool metrics_reset_bucket = false;
 
+    // used by SERVER_TASK_TYPE_SET_LORA
+    std::vector<common_lora_adapter_container> set_lora;
+
     server_task(server_task_type type) : type(type) {}
 
     static slot_params params_from_json_cmpl(
             const llama_model * model,
             const llama_context * ctx,
             const common_params & params_base,
+            const std::vector<common_lora_adapter_container> & lora_base,
             const json & data) {
         slot_params params;
 
@@ -251,6 +263,16 @@ struct server_task {
         params.speculative.n_min = std::max(params.speculative.n_min, 2);
         params.speculative.n_max = std::max(params.speculative.n_max, 0);
 
+        if (data.contains("lora")) {
+            if (data.at("lora").is_array()) {
+                params.lora = parse_lora_request(lora_base, data.at("lora"));
+            } else {
+                throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields");
+            }
+        } else {
+            params.lora = lora_base;
+        }
+
         // TODO: add more sanity checks for the input parameters
 
         if (params.sampling.penalty_last_n < -1) {
@@ -1110,6 +1132,8 @@ struct server_slot {
 
     common_speculative * spec = nullptr;
 
+    std::vector<common_lora_adapter_container> lora;
+
     // the index relative to completion multi-task request
     size_t index = 0;
 
@@ -1191,6 +1215,11 @@ struct server_slot {
         return task_type == SERVER_TASK_TYPE_EMBEDDING || task_type == SERVER_TASK_TYPE_RERANK;
     }
 
+    bool can_batch_with(server_slot & other_slot) {
+        return is_non_causal() == other_slot.is_non_causal()
+            && are_lora_equal(lora, other_slot.lora);
+    }
+
     bool has_budget(const common_params & global_params) {
         if (params.n_predict == -1 && global_params.n_predict == -1) {
             return true; // limitless
@@ -1600,7 +1629,7 @@ struct server_context {
 
     llama_model * model = nullptr;
     llama_context * ctx = nullptr;
-    std::vector<common_lora_adapter_container> loras;
+    std::vector<common_lora_adapter_container> lora;
 
     llama_model * model_dft = nullptr;
     llama_context_params cparams_dft;
@@ -1667,7 +1696,7 @@ struct server_context {
 
         model = llama_init.model;
         ctx = llama_init.context;
-        loras = llama_init.lora_adapters;
+        lora = llama_init.lora_adapters;
 
         if (model == nullptr) {
             SRV_ERR("failed to load model, '%s'\n", params_base.model.c_str());
@@ -1866,6 +1895,12 @@ struct server_context {
         slot.params = std::move(task.params);
         slot.prompt_tokens = std::move(task.prompt_tokens);
 
+        if (!are_lora_equal(task.params.lora, slot.lora)) {
+            // if lora is changed, we cannot reuse cached tokens
+            slot.cache_tokens.clear();
+            slot.lora = std::move(task.params.lora);
+        }
+
         SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());
 
         if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
@@ -2557,7 +2592,7 @@ struct server_context {
                 } break;
             case SERVER_TASK_TYPE_SET_LORA:
                 {
-                    common_lora_adapters_apply(ctx, loras);
+                    lora = std::move(task.set_lora);
                     auto res = std::make_unique<server_task_result_apply_lora>();
                     res->id = task.id;
                     queue_results.send(std::move(res));
@@ -2634,12 +2669,22 @@ struct server_context {
         // start populating the batch for this iteration
         common_batch_clear(batch);
 
+        // track if given slot can be batched with slots already in the batch
+        server_slot * slot_batched = nullptr;
+
         // frist, add sampled tokens from any ongoing sequences
         for (auto & slot : slots) {
             if (slot.state != SLOT_STATE_GENERATING) {
                 continue;
             }
 
+            // check if we can batch this slot with the previous one
+            if (!slot_batched) {
+                slot_batched = &slot;
+            } else if (!slot_batched->can_batch_with(slot)) {
+                continue;
+            }
+
             slot.i_batch = batch.n_tokens;
 
             common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true);
@@ -2658,15 +2703,18 @@ struct server_context {
         int32_t n_batch = llama_n_batch(ctx);
         int32_t n_ubatch = llama_n_ubatch(ctx);
 
-        // track if this is an embedding or non-embedding batch
-        // if we've added sampled tokens above, we are in non-embedding mode
-        // -1: none, 0: non-embedding, 1: embedding
-        // TODO: make enum
-        int32_t batch_type = batch.n_tokens > 0 ? 0 : -1;
-
         // next, batch any pending prompts without exceeding n_batch
         if (params_base.cont_batching || batch.n_tokens == 0) {
             for (auto & slot : slots) {
+                // check if we can batch this slot with the previous one
+                if (slot.is_processing()) {
+                    if (!slot_batched) {
+                        slot_batched = &slot;
+                    } else if (!slot_batched->can_batch_with(slot)) {
+                        continue;
+                    }
+                }
+
                 // this slot still has a prompt to be processed
                 if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_STARTED) {
                     auto & prompt_tokens = slot.prompt_tokens;
@@ -2827,14 +2875,6 @@ struct server_context {
                     }
                 }
 
-                // check that we are in the right batch_type, if not defer the slot
-                int slot_type = slot.is_non_causal();
-                if (batch_type == -1) {
-                    batch_type = slot_type;
-                } else if (batch_type != slot_type) {
-                    continue;
-                }
-
                 // keep only the common part
                 if (!llama_kv_cache_seq_rm(ctx, slot.id, slot.n_past, -1)) {
                     // could not partially delete (likely using a non-Transformer model)
@@ -2902,8 +2942,12 @@ struct server_context {
 
         SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
 
-        // make sure we're in the right embedding mode
-        llama_set_embeddings(ctx, batch_type == 1);
+        if (slot_batched) {
+            // make sure we're in the right embedding mode
+            llama_set_embeddings(ctx, slot_batched->is_non_causal());
+            // apply lora, only need to do it once per batch
+            common_lora_adapters_apply(ctx, slot_batched->lora);
+        }
 
         // process the created batch of tokens
         for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
@@ -3623,7 +3667,12 @@ int main(int argc, char ** argv) {
             task.index = i;
 
             task.prompt_tokens = std::move(tokenized_prompts[i]);
-            task.params = server_task::params_from_json_cmpl(ctx_server.model, ctx_server.ctx, ctx_server.params_base, data);
+            task.params = server_task::params_from_json_cmpl(
+                ctx_server.model,
+                ctx_server.ctx,
+                ctx_server.params_base,
+                ctx_server.lora,
+                data);
             task.id_selected_slot = json_value(data, "id_slot", -1);
 
             // OAI-compat
@@ -4049,8 +4098,8 @@ int main(int argc, char ** argv) {
 
     const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
         json result = json::array();
-        for (size_t i = 0; i < ctx_server.loras.size(); ++i) {
-            auto & lora = ctx_server.loras[i];
+        for (size_t i = 0; i < ctx_server.lora.size(); ++i) {
+            auto & lora = ctx_server.lora[i];
             result.push_back({
                 {"id", i},
                 {"path", lora.path},
@@ -4062,27 +4111,14 @@ int main(int argc, char ** argv) {
     };
 
     const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) {
-        const std::vector<json> body = json::parse(req.body);
-        int max_idx = ctx_server.loras.size();
-
-        // clear existing value
-        for (auto & lora : ctx_server.loras) {
-            lora.scale = 0.0f;
-        }
-
-        // set value
-        for (auto entry : body) {
-            int id = entry.at("id");
-            float scale = entry.at("scale");
-            if (0 <= id && id < max_idx) {
-                ctx_server.loras[id].scale = scale;
-            } else {
-                throw std::runtime_error("invalid adapter id");
-            }
+        const json body = json::parse(req.body);
+        if (!body.is_array()) {
+            res_error(res, format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST));
+            return;
         }
-
         server_task task(SERVER_TASK_TYPE_SET_LORA);
         task.id = ctx_server.queue_tasks.get_new_id();
+        task.set_lora = parse_lora_request(ctx_server.lora, body);
         ctx_server.queue_results.add_waiting_task_id(task.id);
         ctx_server.queue_tasks.post(task);
 
examples/server/tests/README.md

Lines changed: 6 additions & 0 deletions
@@ -44,6 +44,12 @@ To run with stdout/stderr display in real time (verbose output, but useful for d
 DEBUG=1 ./tests.sh -s -v -x
 ```
 
+To run single test unit:
+
+```shell
+./tests.sh unit/test_{name of test case here}.py -v -x
+```
+
 Hint: You can compile and run test in single command, useful for local developement:
 
 ```shell

examples/server/tests/requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -5,3 +5,4 @@ numpy~=1.26.4
 openai~=1.55.3
 prometheus-client~=0.20.0
 requests~=2.32.3
+wget~=3.2
