
Commit 9d84127

lora per request
1 parent 2ba6efc commit 9d84127

File tree

3 files changed: +125 -54 lines changed

examples/server/server.cpp

Lines changed: 50 additions & 54 deletions
@@ -64,7 +64,6 @@ enum server_task_type {
     SERVER_TASK_TYPE_SLOT_SAVE,
     SERVER_TASK_TYPE_SLOT_RESTORE,
     SERVER_TASK_TYPE_SLOT_ERASE,
-    SERVER_TASK_TYPE_SET_LORA,
 };

 // https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
@@ -91,6 +90,8 @@ struct slot_params {
     int64_t t_max_prompt_ms  = -1; // TODO: implement
     int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit

+    std::vector<common_lora_adapter_container> lora;
+
     std::vector<std::string> antiprompt;
     std::vector<std::string> response_fields;
     bool timings_per_token = false;
@@ -114,6 +115,11 @@ struct slot_params {
             samplers.emplace_back(common_sampler_type_to_str(sampler));
         }

+        json lora = json::array();
+        for (size_t i = 0; i < this->lora.size(); ++i) {
+            lora.push_back({{"id", i}, {"scale", this->lora[i].scale}});
+        }
+
         return json {
             {"n_predict", n_predict}, // Server configured n_predict
             {"seed", sampling.seed},
@@ -154,6 +160,7 @@ struct slot_params {
             {"speculative.p_min", speculative.p_min},
             {"timings_per_token", timings_per_token},
             {"post_sampling_probs", post_sampling_probs},
+            {"lora", lora},
         };
     }
 };
@@ -189,6 +196,7 @@ struct server_task {
             const llama_model * model,
             const llama_context * ctx,
             const common_params & params_base,
+            const std::vector<common_lora_adapter_container> & base_lora,
             const json & data) {
         slot_params params;

@@ -245,6 +253,16 @@ struct server_task {
         params.speculative.n_min = std::max(params.speculative.n_min, 2);
         params.speculative.n_max = std::max(params.speculative.n_max, 0);

+        if (data.contains("lora")) {
+            if (data.at("lora").is_array()) {
+                params.lora = parse_lora_request(base_lora, data.at("lora"));
+            } else {
+                throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields");
+            }
+        } else {
+            params.lora = base_lora;
+        }
+
         // TODO: add more sanity checks for the input parameters

         if (params.sampling.penalty_last_n < -1) {
@@ -989,12 +1007,6 @@ struct server_task_result_slot_erase : server_task_result {
     }
 };

-struct server_task_result_apply_lora : server_task_result {
-    virtual json to_json() override {
-        return json {{ "success", true }};
-    }
-};
-
 struct server_slot {
     int id;
     int id_task = -1;
@@ -1009,6 +1021,8 @@ struct server_slot {

     common_speculative * spec = nullptr;

+    std::vector<common_lora_adapter_container> lora;
+
     // the index relative to completion multi-task request
     size_t index = 0;

@@ -1091,7 +1105,8 @@ struct server_slot {
     }

     bool can_batch_with(server_slot & other_slot) {
-        return is_non_causal() == other_slot.is_non_causal();
+        return is_non_causal() == other_slot.is_non_causal()
+            && are_lora_equal(lora, other_slot.lora);
     }

     bool has_budget(const common_params & global_params) {
@@ -1503,7 +1518,7 @@ struct server_context {

     llama_model * model = nullptr;
     llama_context * ctx = nullptr;
-    std::vector<common_lora_adapter_container> loras;
+    std::vector<common_lora_adapter_container> lora;

     llama_model * model_dft = nullptr;
     llama_context_params cparams_dft;
@@ -1570,7 +1585,7 @@ struct server_context {

         model = llama_init.model;
         ctx = llama_init.context;
-        loras = llama_init.lora_adapters;
+        lora = llama_init.lora_adapters;

         if (model == nullptr) {
             SRV_ERR("failed to load model, '%s'\n", params_base.model.c_str());
@@ -1776,6 +1791,12 @@ struct server_context {
         slot.params = std::move(task.params);
         slot.prompt_tokens = std::move(task.prompt_tokens);

+        if (!are_lora_equal(task.params.lora, slot.lora)) {
+            // if the lora set has changed, cached tokens cannot be reused
+            slot.cache_tokens.clear();
+            slot.lora = std::move(task.params.lora);
+        }
+
         SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());

         if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
@@ -2465,13 +2486,6 @@ struct server_context {
                     res->n_erased = n_erased;
                     queue_results.send(std::move(res));
                 } break;
-            case SERVER_TASK_TYPE_SET_LORA:
-                {
-                    common_lora_adapters_apply(ctx, loras);
-                    auto res = std::make_unique<server_task_result_apply_lora>();
-                    res->id = task.id;
-                    queue_results.send(std::move(res));
-                } break;
         }
     }

@@ -2808,8 +2822,12 @@ struct server_context {

         SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);

-        // make sure we're in the right embedding mode
-        llama_set_embeddings(ctx, slot_batched && slot_batched->is_non_causal());
+        if (slot_batched) {
+            // make sure we're in the right embedding mode
+            llama_set_embeddings(ctx, slot_batched->is_non_causal());
+            // apply lora, only need to do it once per batch
+            common_lora_adapters_apply(ctx, slot_batched->lora);
+        }

         // process the created batch of tokens
         for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
@@ -3530,7 +3548,12 @@ int main(int argc, char ** argv) {
             task.index = i;

             task.prompt_tokens = std::move(tokenized_prompts[i]);
-            task.params = server_task::params_from_json_cmpl(ctx_server.model, ctx_server.ctx, ctx_server.params_base, data);
+            task.params = server_task::params_from_json_cmpl(
+                    ctx_server.model,
+                    ctx_server.ctx,
+                    ctx_server.params_base,
+                    ctx_server.lora,
+                    data);
             task.id_selected_slot = json_value(data, "id_slot", -1);

             // OAI-compat
@@ -3944,8 +3967,8 @@ int main(int argc, char ** argv) {

     const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
         json result = json::array();
-        for (size_t i = 0; i < ctx_server.loras.size(); ++i) {
-            auto & lora = ctx_server.loras[i];
+        for (size_t i = 0; i < ctx_server.lora.size(); ++i) {
+            auto & lora = ctx_server.lora[i];
             result.push_back({
                 {"id", i},
                 {"path", lora.path},
@@ -3957,40 +3980,13 @@ int main(int argc, char ** argv) {
     };

     const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) {
-        const std::vector<json> body = json::parse(req.body);
-        int max_idx = ctx_server.loras.size();
-
-        // clear existing value
-        for (auto & lora : ctx_server.loras) {
-            lora.scale = 0.0f;
-        }
-
-        // set value
-        for (auto entry : body) {
-            int id = entry.at("id");
-            float scale = entry.at("scale");
-            if (0 <= id && id < max_idx) {
-                ctx_server.loras[id].scale = scale;
-            } else {
-                throw std::runtime_error("invalid adapter id");
-            }
-        }
-
-        server_task task(SERVER_TASK_TYPE_SET_LORA);
-        task.id = ctx_server.queue_tasks.get_new_id();
-        ctx_server.queue_results.add_waiting_task_id(task.id);
-        ctx_server.queue_tasks.post(task);
-
-        server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
-        ctx_server.queue_results.remove_waiting_task_id(task.id);
-
-        if (result->is_error()) {
-            res_error(res, result->to_json());
+        const json body = json::parse(req.body);
+        if (!body.is_array()) {
+            res_error(res, format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST));
             return;
         }
-
-        GGML_ASSERT(dynamic_cast<server_task_result_apply_lora*>(result.get()) != nullptr);
-        res_ok(res, result->to_json());
+        ctx_server.lora = parse_lora_request(ctx_server.lora, body);
+        res_ok(res, json{{"success", true}});
     };

     //
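
With this change, adapter scales travel with each completion request in the new "lora" field instead of requiring a global SET_LORA round trip. A minimal client sketch of the new field follows; the local URL and the Python requests dependency are assumptions for illustration, not part of this commit:

# Hypothetical client sketch: per-request LoRA scaling on /completion.
# The server address is an assumption; adjust to your deployment.
import requests

res = requests.post("http://localhost:8080/completion", json={
    "prompt": "Look in thy glass",
    # per-request field added by this commit: an array of {"id", "scale"}
    # objects; adapters not listed are applied with scale 0.0
    "lora": [{"id": 0, "scale": 0.5}],
    "temperature": 0.0,
})
print(res.json()["content"])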

examples/server/tests/unit/test_lora.py

Lines changed: 34 additions & 0 deletions
@@ -40,3 +40,37 @@ def test_lora(scale: float, re_content: str):
     assert res.status_code == 200
     assert match_regex(re_content, res.body["content"])

+
+def test_lora_per_request():
+    global server
+    server.n_slots = 4
+    server.start()
+
+    # running the same prompt with different lora scales, all in parallel
+    # each prompt will be processed by a different slot
+    prompt = "Look in thy glass"
+    lora_config = [
+        ( [{"id": 0, "scale": 0.0}], "(bright|day|many|happy)+" ),
+        ( [{"id": 0, "scale": 0.0}], "(bright|day|many|happy)+" ),
+        ( [{"id": 0, "scale": 0.0}], "(bright|day|many|happy)+" ),
+        ( [{"id": 0, "scale": 1.0}], "(eye|love|glass|sun)+" ),
+        ( [{"id": 0, "scale": 1.0}], "(eye|love|glass|sun)+" ),
+        ( [{"id": 0, "scale": 1.0}], "(eye|love|glass|sun)+" ),
+    ]
+    # FIXME: testing with a scale between 0.0 and 1.0 (e.g. 0.2, 0.5, 0.7) produces unreliable results
+
+    tasks = [(
+        server.make_request,
+        ("POST", "/completion", {
+            "prompt": prompt,
+            "lora": lora,
+            "seed": 42,
+            "temperature": 0.0,
+        })
+    ) for lora, _ in lora_config]
+    results = parallel_function_calls(tasks)
+
+    print(results)
+    assert all([res.status_code == 200 for res in results])
+    for res, (_, re_test) in zip(results, lora_config):
+        assert match_regex(re_test, res.body["content"])
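
The test fans its requests out through parallel_function_calls from the suite's helper utilities. For orientation, here is a minimal sketch of what such a helper might look like (an assumption for illustration; the real helper ships with the test suite, not with this commit):

# Assumed sketch of a parallel_function_calls-style helper: run (fn, args)
# pairs concurrently and return their results in input order.
from concurrent.futures import ThreadPoolExecutor

def parallel_function_calls(tasks):
    with ThreadPoolExecutor(max_workers=len(tasks)) as pool:
        futures = [pool.submit(fn, *args) for fn, args in tasks]
        return [f.result() for f in futures]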

examples/server/utils.hpp

Lines changed: 41 additions & 0 deletions
@@ -771,3 +771,44 @@ static std::vector<llama_token_data> get_token_probabilities(llama_context * ctx

     return cur;
 }
+
+static bool are_lora_equal(
+        const std::vector<common_lora_adapter_container> & l1,
+        const std::vector<common_lora_adapter_container> & l2) {
+    if (l1.size() != l2.size()) {
+        return false;
+    }
+    for (size_t i = 0; i < l1.size(); ++i) {
+        // we don't check lora.path to reduce the time complexity
+        if (l1[i].scale != l2[i].scale || l1[i].adapter != l2[i].adapter) {
+            return false;
+        }
+    }
+    return true;
+}
+
+// parse the lora config from a JSON request; returns a copy of base_lora with updated scales
+static std::vector<common_lora_adapter_container> parse_lora_request(
+        const std::vector<common_lora_adapter_container> & base_lora,
+        const json & data) {
+    std::vector<common_lora_adapter_container> lora(base_lora);
+    int max_idx = lora.size();
+
+    // clear existing values
+    for (auto & entry : lora) {
+        entry.scale = 0.0f;
+    }
+
+    // set new values
+    for (const auto & entry : data) {
+        int id = json_value(entry, "id", -1);
+        float scale = json_value(entry, "scale", 0.0f);
+        if (0 <= id && id < max_idx) {
+            lora[id].scale = scale;
+        } else {
+            throw std::runtime_error("invalid adapter id");
+        }
+    }
+
+    return lora;
+}
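
Taken together, handle_lora_adapters_apply now simply validates the body and calls parse_lora_request to swap in new server-wide scales, while per-batch application happens in the batched decode path shown earlier. A hedged usage sketch of the two endpoints follows; the /lora-adapters route and the local port are assumptions based on the upstream server and do not appear in this diff:

# Hedged sketch: exercising the list/apply handlers over HTTP.
# Route and port are assumptions, not shown in this diff.
import requests

BASE = "http://localhost:8080"

# list loaded adapters; each entry carries at least "id" and "path"
# (see handle_lora_adapters_list)
print(requests.get(f"{BASE}/lora-adapters").json())

# set new server-wide scales; the body must be a JSON array, and unlisted
# adapters are reset to scale 0.0, mirroring parse_lora_request
res = requests.post(f"{BASE}/lora-adapters", json=[{"id": 0, "scale": 1.0}])
assert res.json() == {"success": True}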
