Merged
21 changes: 10 additions & 11 deletions lora.hpp
@@ -115,7 +115,7 @@ struct LoraModel : public GGMLRunner {
         return "lora";
     }
 
-    bool load_from_file(bool filter_tensor = false) {
+    bool load_from_file(bool filter_tensor = false, int n_threads = 0) {
         LOG_INFO("loading LoRA from '%s'", file_path.c_str());
 
         if (load_failed) {
@@ -131,15 +131,14 @@ struct LoraModel : public GGMLRunner {
                 // LOG_INFO("skipping LoRA tesnor '%s'", name.c_str());
                 return true;
             }
-            // LOG_INFO("lora_tensor %s", name.c_str());
-            for (int i = 0; i < LORA_TYPE_COUNT; i++) {
-                if (name.find(type_fingerprints[i]) != std::string::npos) {
-                    type = (lora_t)i;
-                    break;
-                }
-            }
 
             if (dry_run) {
+                for (int i = 0; i < LORA_TYPE_COUNT; i++) {
+                    if (name.find(type_fingerprints[i]) != std::string::npos) {
+                        type = (lora_t)i;
+                        break;
+                    }
+                }
                 struct ggml_tensor* real = ggml_new_tensor(params_ctx,
                                                            tensor_storage.type,
                                                            tensor_storage.n_dims,
@@ -153,11 +152,11 @@ struct LoraModel : public GGMLRunner {
             return true;
         };
 
-        model_loader.load_tensors(on_new_tensor_cb);
+        model_loader.load_tensors(on_new_tensor_cb, 1);
         alloc_params_buffer();
         // exit(0);
 
         dry_run = false;
-        model_loader.load_tensors(on_new_tensor_cb);
+        model_loader.load_tensors(on_new_tensor_cb, n_threads);
 
         LOG_DEBUG("lora type: \"%s\"/\"%s\"", lora_downs[type].c_str(), lora_ups[type].c_str());
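In short, load_from_file now accepts an n_threads argument: the first load_tensors pass (the dry run that only registers each LoRA tensor's shape so the parameter buffer can be sized and allocated) stays single-threaded, while the second pass that actually copies the weights uses the caller-supplied thread count. The fingerprint loop that detects the LoRA naming convention now runs only during the dry run. Below is a self-contained sketch of the same dry-run / threaded-load pattern for readers unfamiliar with it; every name in it is hypothetical, and it does not use the real ModelLoader or GGMLRunner API.

```cpp
// Standalone illustration of the dry-run / threaded-load pattern above.
// All names here are hypothetical; this is not the stable-diffusion.cpp API.
#include <cstdio>
#include <functional>
#include <map>
#include <mutex>
#include <string>
#include <thread>
#include <vector>

struct FakeTensorInfo {
    std::string name;
    size_t nbytes;
};

// Toy loader: invokes the callback once per tensor, spread over n_threads workers.
void load_tensors(const std::vector<FakeTensorInfo>& infos,
                  const std::function<void(const FakeTensorInfo&)>& cb,
                  int n_threads) {
    if (n_threads <= 1) {
        for (const auto& info : infos) cb(info);
        return;
    }
    std::vector<std::thread> workers;
    for (int t = 0; t < n_threads; t++) {
        workers.emplace_back([&, t]() {
            for (size_t i = t; i < infos.size(); i += n_threads) cb(infos[i]);
        });
    }
    for (auto& w : workers) w.join();
}

int main() {
    std::vector<FakeTensorInfo> infos = {{"lora_up.weight", 4096},
                                         {"lora_down.weight", 4096}};

    bool dry_run       = true;
    size_t total_bytes = 0;
    std::map<std::string, std::vector<char>> buffers;
    std::mutex mtx;

    auto on_tensor = [&](const FakeTensorInfo& info) {
        if (dry_run) {
            total_bytes += info.nbytes;  // pass 1: only record sizes (single-threaded)
        } else {
            std::lock_guard<std::mutex> lock(mtx);
            buffers[info.name].resize(info.nbytes);  // pass 2: "load" the data
        }
    };

    load_tensors(infos, on_tensor, 1);  // dry run stays on one thread
    std::printf("would allocate %zu bytes\n", total_bytes);

    dry_run = false;
    load_tensors(infos, on_tensor, 4);  // real pass fans out over 4 threads
    std::printf("loaded %zu tensors\n", buffers.size());
    return 0;
}
```

Keeping the bookkeeping pass single-threaded avoids any synchronization during shape collection, while the data-copy pass, which dominates load time, gets the parallelism.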