
Commit b5c3eaa

Added --lora-layer-range option
1 parent 494c589 commit b5c3eaa


5 files changed (+58, -15 lines)


common/arg.cpp

Lines changed: 8 additions & 0 deletions
@@ -2472,6 +2472,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         }
         // we define this arg on both COMMON and EXPORT_LORA, so when showing help message of export-lora, it will be categorized as "example-specific" arg
     ).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}));
+    add_opt(common_arg(
+        {"--lora-layer-range"}, "START", "END",
+        "layer range to apply the lora(s) to, start and end inclusive",
+        [](common_params & params, const std::string & start, const std::string & end) {
+            params.lora_layer_start = std::stoi(start);
+            params.lora_layer_end = std::stoi(end);
+        }
+    ));
     add_opt(common_arg(
         {"--control-vector"}, "FNAME",
         "add a control vector\nnote: this argument can be repeated to add multiple control vectors",

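The new handler converts START and END straight to integers with std::stoi. The fragment below is a minimal standalone sketch of that conversion, not code from the commit: parse_layer_range() is a hypothetical name, and the START <= END check is extra illustration of validation the option itself does not perform.

// Minimal sketch of the START/END parsing done by the new option handler.
// parse_layer_range() is a hypothetical helper used only for illustration.
#include <cstdint>
#include <stdexcept>
#include <string>
#include <utility>

static std::pair<int32_t, int32_t> parse_layer_range(const std::string & start, const std::string & end) {
    // std::stoi throws std::invalid_argument / std::out_of_range on bad input
    const int32_t s = std::stoi(start);
    const int32_t e = std::stoi(end);
    if (s > e) {
        // not checked by the commit; shown here only as an obvious extension
        throw std::invalid_argument("--lora-layer-range: START must be <= END");
    }
    return {s, e};
}

On the command line the option would sit next to an existing --lora adapter argument, e.g. --lora-layer-range 10 20.
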
common/common.cpp

Lines changed: 16 additions & 11 deletions
@@ -982,18 +982,23 @@ struct common_init_result common_init_from_params(common_params & params) {
     }
 
     // load and optionally apply lora adapters
-    for (auto & la : params.lora_adapters) {
-        llama_adapter_lora_ptr lora;
-        lora.reset(llama_adapter_lora_init(model, la.path.c_str()));
-        if (lora == nullptr) {
-            LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
-            llama_free(lctx);
-            llama_model_free(model);
-            return iparams;
-        }
+    if (!params.lora_adapters.empty()) {
+        if (params.lora_layer_start <= 0) params.lora_layer_start = 1;
+        if (params.lora_layer_end <= 0) params.lora_layer_end = llama_model_n_layer(model);
+
+        for (auto & la : params.lora_adapters) {
+            llama_adapter_lora_ptr lora;
+            lora.reset(llama_adapter_lora_init(model, la.path.c_str()));
+            if (lora == nullptr) {
+                LOG_ERR("%s: failed to apply lora adapter '%s'\n", __func__, la.path.c_str());
+                llama_free(lctx);
+                llama_model_free(model);
+                return iparams;
+            }
 
-        la.ptr = lora.get();
-        iparams.lora.emplace_back(std::move(lora)); // copy to list of loaded adapters
+            la.ptr = lora.get();
+            iparams.lora.emplace_back(std::move(lora)); // copy to list of loaded adapters
+        }
     }
 
     if (!params.lora_init_without_apply) {
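As shown above, adapter loading is now wrapped in a non-empty check and the layer range is resolved before the loop: a non-positive (i.e. unset) start falls back to 1, a non-positive end falls back to llama_model_n_layer(model). Below is a condensed sketch of that defaulting rule; resolve_lora_layer_range() is a hypothetical helper used only for illustration, the commit performs the checks inline.

// Condensed sketch of the defaulting shown above; not part of the commit.
#include <cstdint>
#include <utility>

static std::pair<int32_t, int32_t> resolve_lora_layer_range(int32_t start, int32_t end, int32_t n_layer) {
    if (start <= 0) start = 1;       // not set on the command line -> start at the first layer
    if (end   <= 0) end   = n_layer; // not set on the command line -> run through the last layer
    return {start, end};
}

With the -1 defaults declared in common.h, omitting --lora-layer-range therefore resolves to the range 1..llama_model_n_layer(model).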

common/common.h

Lines changed: 2 additions & 0 deletions
@@ -296,6 +296,8 @@ struct common_params {
     int32_t verbosity = 0;
     int32_t control_vector_layer_start = -1; // layer range for control vector
     int32_t control_vector_layer_end = -1; // layer range for control vector
+    int32_t lora_layer_start = -1; // layer range for lora
+    int32_t lora_layer_end = -1; // layer range for lora
     bool offline = false;
 
     int32_t ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used.

include/llama.h

Lines changed: 4 additions & 1 deletion
@@ -544,9 +544,12 @@ extern "C" {
     //
 
     // Load a LoRA adapter from file
+    // il_start and il_end are the layer range the lora should apply to (both inclusive)
     LLAMA_API struct llama_adapter_lora * llama_adapter_lora_init(
                          struct llama_model * model,
-                         const char * path_lora);
+                         const char * path_lora,
+                         int32_t il_start,
+                         int32_t il_end);
 
     // Manually free a LoRA adapter
     // Note: loaded adapters will be free when the associated model is deleted
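For callers of the C API, the layer range is now passed explicitly when the adapter is loaded. Below is a minimal, hypothetical caller sketch; the file names are placeholders, error handling is reduced to early returns, and the 10..20 range is arbitrary.

// Hypothetical caller of the extended API; file names are placeholders.
#include "llama.h"

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file("model.gguf", mparams);
    if (model == nullptr) {
        llama_backend_free();
        return 1;
    }

    // restrict the adapter to layers 10..20, both inclusive
    llama_adapter_lora * adapter = llama_adapter_lora_init(model, "adapter.gguf", 10, 20);
    if (adapter == nullptr) {
        llama_model_free(model);
        llama_backend_free();
        return 1;
    }

    llama_adapter_lora_free(adapter);
    llama_model_free(model);
    llama_backend_free();
    return 0;
}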

src/llama-adapter.cpp

Lines changed: 28 additions & 3 deletions
@@ -145,7 +145,12 @@ llama_adapter_lora_weight * llama_adapter_lora::get_weight(ggml_tensor * w) {
     return nullptr;
 }
 
-static void llama_adapter_lora_init_impl(llama_model & model, const char * path_lora, llama_adapter_lora & adapter) {
+static void llama_adapter_lora_init_impl(
+        llama_model & model,
+        const char * path_lora,
+        llama_adapter_lora & adapter,
+        int32_t il_start,
+        int32_t il_end) {
     LLAMA_LOG_INFO("%s: loading lora adapter from '%s' ...\n", __func__, path_lora);
 
     ggml_context * ctx_init;
@@ -224,6 +229,22 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
 
     for (ggml_tensor * cur = ggml_get_first_tensor(ctx.get()); cur; cur = ggml_get_next_tensor(ctx.get(), cur)) {
         std::string name(cur->name);
+
+        // check if this tensor has a layer number and is outside our range
+        size_t blk_pos = name.find("blk.");
+        if (blk_pos != std::string::npos) {
+            size_t start = blk_pos + 4; // skip "blk."
+            size_t end = name.find('.', start);
+            try {
+                int layer_num = std::stoi(name.substr(start, end - start));
+                if (layer_num < il_start || layer_num > il_end) {
+                    continue; // skip this tensor
+                }
+            } catch (const std::exception & err) {
+                LLAMA_LOG_ERROR("%s: failed to parse layer number from tensor '%s': %s\n", __func__, name.c_str(), err.what());
+            }
+        }
+
         if (str_endswith(name, ".lora_a")) {
             replace_all(name, ".lora_a", "");
             if (ab_map.find(name) == ab_map.end()) {
@@ -368,11 +389,15 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
     LLAMA_LOG_INFO("%s: loaded %zu tensors from lora file\n", __func__, adapter.ab_map.size()*2);
 }
 
-llama_adapter_lora * llama_adapter_lora_init(llama_model * model, const char * path_lora) {
+llama_adapter_lora * llama_adapter_lora_init(
+        llama_model * model,
+        const char * path_lora,
+        int32_t il_start,
+        int32_t il_end) {
     llama_adapter_lora * adapter = new llama_adapter_lora();
 
     try {
-        llama_adapter_lora_init_impl(*model, path_lora, *adapter);
+        llama_adapter_lora_init_impl(*model, path_lora, *adapter, il_start, il_end);
         return adapter;
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("%s: failed to apply lora adapter: %s\n", __func__, err.what());
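The filter above keys off the "blk.<n>." prefix that per-layer tensors carry; names without it are never skipped, and a name whose layer number cannot be parsed is logged and kept rather than dropped. The following standalone sketch mirrors that extraction; parse_blk_layer() is an illustrative name, not part of the commit.

// Standalone sketch of the layer-number extraction above.
// Returns -1 when the tensor name carries no layer number.
#include <cstddef>
#include <exception>
#include <string>

static int parse_blk_layer(const std::string & name) {
    const size_t blk_pos = name.find("blk.");
    if (blk_pos == std::string::npos) {
        return -1; // not a per-layer tensor, e.g. "token_embd.weight"
    }
    const size_t start = blk_pos + 4;           // skip "blk."
    const size_t end   = name.find('.', start); // npos is fine: substr clamps the length
    try {
        return std::stoi(name.substr(start, end - start));
    } catch (const std::exception &) {
        return -1; // malformed name: keep the tensor, as the commit does
    }
}

Given a name such as blk.12.attn_q.weight this yields 12, which the loader compares against il_start and il_end before registering the tensor's lora_a/lora_b pair.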
