
Commit f58c912

llama : make loras compatible with repacking
ggml-ci
1 parent 02082f1 commit f58c912

File tree: 1 file changed, +35 -1 lines

src/llama-adapter.cpp (35 additions & 1 deletion)
@@ -247,6 +247,27 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
         }
     }
 
+    // get extra buffer types of the CPU
+    std::vector<ggml_backend_buffer_type_t> buft_extra;
+    {
+        auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
+        auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
+
+        // add the default CPU buffer type which will be used as a fallback if the lora needs to be loaded to an extra buft
+        buft_extra.emplace_back(ggml_backend_dev_buffer_type(cpu_dev));
+
+        auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
+            ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts");
+
+        if (ggml_backend_dev_get_extra_bufts_fn) {
+            ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(cpu_dev);
+            while (extra_bufts && *extra_bufts) {
+                buft_extra.emplace_back(*extra_bufts);
+                ++extra_bufts;
+            }
+        }
+    }
+
     // add tensors
     for (auto & it : ab_map) {
         const std::string & name = it.first;
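
The hunk above queries the ggml backend registry for the CPU device's "extra" buffer types (the repacked weight layouts) and keeps the default CPU buffer type at index 0 as a fallback. Below is a minimal standalone sketch, not part of the commit, that runs the same registry query and prints the buffer type names. It assumes the public ggml-backend header is available and, instead of relying on the exact header that declares ggml_backend_dev_get_extra_bufts_t, defines a local alias for the proc-address signature used in the diff.

// Sketch only: enumerate the CPU device's extra buffer types.
#include <cstdio>

#include "ggml-backend.h"

// hypothetical local alias mirroring the cast used in the diff above
typedef ggml_backend_buffer_type_t * (*get_extra_bufts_fn_t)(ggml_backend_dev_t device);

int main() {
    // ensure backends are registered (needed when backends are built as dynamic libraries)
    ggml_backend_load_all();

    ggml_backend_dev_t cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
    ggml_backend_reg_t cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);

    // default CPU buffer type - this is the fallback the adapter loader uses
    printf("default: %s\n", ggml_backend_buft_name(ggml_backend_dev_buffer_type(cpu_dev)));

    auto get_extra_bufts = (get_extra_bufts_fn_t)
        ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts");

    if (get_extra_bufts) {
        // NULL-terminated array of extra buffer types (e.g. repacked layouts)
        for (ggml_backend_buffer_type_t * it = get_extra_bufts(cpu_dev); it && *it; ++it) {
            printf("extra:   %s\n", ggml_backend_buft_name(*it));
        }
    }
    return 0;
}
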
@@ -263,7 +284,20 @@ static void llama_adapter_lora_init_impl(llama_model & model, const char * path_
             throw std::runtime_error("LoRA tensor '" + name + "' does not exist in base model (hint: maybe wrong base model?)");
         }
 
-        ggml_context * dev_ctx = ctx_for_buft(ggml_backend_buffer_get_type(model_tensor->buffer));
+        auto * buft = ggml_backend_buffer_get_type(model_tensor->buffer);
+
+        // do not load loras to extra buffer types (i.e. bufts for repacking) -> use the CPU in that case
+        for (auto & ex : buft_extra) {
+            if (ex == buft) {
+                LLAMA_LOG_WARN("%s: lora for '%s' cannot use buft '%s'\n", __func__, model_tensor->name, ggml_backend_buft_name(buft));
+                buft = buft_extra[0];
+                break;
+            }
+        }
+
+        LLAMA_LOG_DEBUG("%s: lora for '%s' -> '%s'\n", __func__, model_tensor->name, ggml_backend_buft_name(buft));
+
+        ggml_context * dev_ctx = ctx_for_buft(buft);
         // validate tensor shape
         if (is_token_embd) {
             // expect B to be non-transposed, A and B are flipped; see llm_build_inp_embd()
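
For completeness, a hedged usage sketch of the path this change affects: loading a quantized model whose CPU weights may be repacked into an extra buffer type, then attaching a LoRA adapter, which goes through llama_adapter_lora_init_impl() shown above. The file paths are placeholders, and the public API names (llama_model_load_from_file, llama_init_from_model, llama_adapter_lora_init, llama_set_adapter_lora) are assumed from the llama.h of this period; older trees spell them differently.

// Sketch only: attach a LoRA adapter; with this commit its tensors fall back
// to the default CPU buffer type instead of a repacked extra buffer type.
#include <cstdio>

#include "llama.h"

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file("model-q4_0.gguf", mparams); // hypothetical path
    if (!model) { fprintf(stderr, "failed to load model\n"); return 1; }

    // loads the adapter tensors via llama_adapter_lora_init_impl()
    llama_adapter_lora * adapter = llama_adapter_lora_init(model, "lora.gguf"); // hypothetical path
    if (!adapter) { fprintf(stderr, "failed to load lora\n"); return 1; }

    llama_context_params cparams = llama_context_default_params();
    llama_context * ctx = llama_init_from_model(model, cparams);

    // enable the adapter for this context at full strength
    llama_set_adapter_lora(ctx, adapter, 1.0f);

    // ... run inference as usual ...

    llama_free(ctx);
    llama_adapter_lora_free(adapter);
    llama_model_free(model);
    llama_backend_free();
    return 0;
}
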
