
Commit 8b86d0e

slaren authored and Nexesenex committed
llama : allow other bufts when overriding to CPU, add --no-repack option (ggml-org#14990)
1 parent c39cbf1 commit 8b86d0e

5 files changed (+41, -21 lines)

common/arg.cpp

Lines changed: 7 additions & 0 deletions
@@ -2094,6 +2094,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.no_kv_offload = true;
         }
     ).set_env("LLAMA_ARG_NO_KV_OFFLOAD"));
+    add_opt(common_arg(
+        {"-nr", "--no-repack"},
+        "disable weight repacking",
+        [](common_params & params) {
+            params.no_extra_bufts = true;
+        }
+    ).set_env("LLAMA_ARG_NO_REPACK"));
     add_opt(common_arg(
         {"-ctk", "--cache-type-k"}, "TYPE",
         string_format(
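
Together with the matching changes below, this lets users turn off weight repacking from the command line with -nr / --no-repack, or by setting the LLAMA_ARG_NO_REPACK environment variable (hypothetical invocation: llama-cli -m model.gguf --no-repack, where model.gguf is a placeholder path).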

common/common.cpp

Lines changed: 3 additions & 0 deletions
@@ -1130,8 +1130,11 @@ struct llama_model_params common_model_params_to_llama(common_params & params) {
     mparams.use_mmap = params.use_mmap;
     mparams.use_mlock = params.use_mlock;
     mparams.check_tensors = params.check_tensors;
+
     mparams.requested_n_ctx = params.n_ctx;

+    mparams.use_extra_bufts = !params.no_extra_bufts;
+
     if (params.kv_overrides.empty()) {
         mparams.kv_overrides = NULL;
     } else {

common/common.h

Lines changed: 1 addition & 0 deletions
@@ -354,6 +354,7 @@ struct common_params {
     bool warmup = true; // warmup run
     bool check_tensors = false; // validate tensor data
     bool no_op_offload = false; // globally disable offload host tensor operations to device
+    bool no_extra_bufts = false; // disable extra buffer types (used for weight repacking)

     bool single_turn = false; // single turn chat conversation

include/llama.h

Lines changed: 5 additions & 4 deletions
@@ -342,10 +342,11 @@ extern "C" {
         uint32_t requested_n_ctx;

         // Keep the booleans together to avoid misalignment during copy-by-value.
-        bool vocab_only;    // only load the vocabulary, no weights
-        bool use_mmap;      // use mmap if possible
-        bool use_mlock;     // force system to keep model in RAM
-        bool check_tensors; // validate model tensor data
+        bool vocab_only;      // only load the vocabulary, no weights
+        bool use_mmap;        // use mmap if possible
+        bool use_mlock;       // force system to keep model in RAM
+        bool check_tensors;   // validate model tensor data
+        bool use_extra_bufts; // use extra buffer types (used for weight repacking)
     };

     // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations
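
For reference, here is a minimal sketch of how the new field could be set through the public API; it assumes the existing llama.h entry points llama_model_default_params(), llama_model_load_from_file(), and llama_model_free(), and uses a placeholder model path:

// Hypothetical example: the C++ equivalent of passing --no-repack.
#include "llama.h"

#include <cstdio>

int main() {
    llama_model_params mparams = llama_model_default_params();
    mparams.use_extra_bufts = false; // do not add the extra (repacked) CPU buffer types

    // "model.gguf" is a placeholder path
    llama_model * model = llama_model_load_from_file("model.gguf", mparams);
    if (model == nullptr) {
        fprintf(stderr, "failed to load model\n");
        return 1;
    }

    llama_model_free(model);
    return 0;
}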

src/llama-model.cpp

Lines changed: 25 additions & 17 deletions
@@ -297,7 +297,7 @@ static ggml_backend_buffer_type_t select_weight_buft(const llama_hparams & hpara
 }

 // CPU: ACCEL -> GPU host -> CPU extra -> CPU
-static buft_list_t make_cpu_buft_list(const std::vector<ggml_backend_dev_t> & devices) {
+static buft_list_t make_cpu_buft_list(const std::vector<ggml_backend_dev_t> & devices, bool use_extra_bufts) {
     buft_list_t buft_list;

     // add ACCEL buffer types
@@ -326,21 +326,22 @@ static buft_list_t make_cpu_buft_list(const std::vector<ggml_backend_dev_t> & de
         }
     }

-    // add extra buffer types, only if no GPU device is present
-    // ref: https://github.com/ggml-org/llama.cpp/issues/12481#issuecomment-2743136094
-    auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
-    if (cpu_dev == nullptr) {
-        throw std::runtime_error(format("%s: no CPU backend found", __func__));
-    }
+    // add extra buffer types
+    if (use_extra_bufts) {
+        auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
+        if (cpu_dev == nullptr) {
+            throw std::runtime_error(format("%s: no CPU backend found", __func__));
+        }

-    auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
-    auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
-        ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts");
-    if (ggml_backend_dev_get_extra_bufts_fn) {
-        ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(cpu_dev);
-        while (extra_bufts && *extra_bufts) {
-            buft_list.emplace_back(cpu_dev, *extra_bufts);
-            ++extra_bufts;
+        auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
+        auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
+            ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts");
+        if (ggml_backend_dev_get_extra_bufts_fn) {
+            ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(cpu_dev);
+            while (extra_bufts && *extra_bufts) {
+                buft_list.emplace_back(cpu_dev, *extra_bufts);
+                ++extra_bufts;
+            }
         }
     }

@@ -1846,7 +1847,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
     LLAMA_LOG_INFO("%s: loading model tensors, this can take a while... (mmap = %s)\n", __func__, ml.use_mmap ? "true" : "false");

     // build a list of buffer types for the CPU and GPU devices
-    pimpl->cpu_buft_list = make_cpu_buft_list(devices);
+    pimpl->cpu_buft_list = make_cpu_buft_list(devices, params.use_extra_bufts);
     for (auto * dev : devices) {
         buft_list_t buft_list = make_gpu_buft_list(dev, split_mode, tensor_split);
         // add CPU buffer types as a fallback
@@ -2367,7 +2368,13 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
        for (const auto * overrides = ml.tensor_buft_overrides; overrides->pattern != nullptr; ++overrides) {
            std::regex pattern(overrides->pattern);
            if (std::regex_search(tensor_name, pattern)) {
-               buft = overrides->buft;
+               if (overrides->buft == ggml_backend_cpu_buffer_type()) {
+                   // when overriding to a CPU buffer, consider the extra buffer types
+                   buft = select_weight_buft(hparams, t_meta, op, pimpl->cpu_buft_list);
+               } else {
+                   buft = overrides->buft;
+               }
+
                LLAMA_LOG_DEBUG("tensor %s (%zu MiB %s) buffer type overridden to %s\n",
                        tensor_name.c_str(),
                        ggml_nbytes(t_meta) / 1024 / 1024, ggml_type_name(t_meta->type),
@@ -18249,6 +18256,7 @@ llama_model_params llama_model_default_params() {
         /*.use_mmap =*/ true,
         /*.use_mlock =*/ false,
         /*.check_tensors =*/ false,
+        /*.use_extra_bufts =*/ true,
     };

 #ifdef GGML_USE_METAL
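
The override branch above interacts with the tensor buffer-type overrides already exposed through llama_model_params. Below is a minimal sketch, assuming the llama_model_tensor_buft_override struct from llama.h and ggml_backend_cpu_buffer_type() from the CPU backend header (the "ffn_.*" pattern is only an illustration, not taken from this commit):

#include "llama.h"
#include "ggml-cpu.h" // assumed location of ggml_backend_cpu_buffer_type()

// Request that every tensor matching "ffn_.*" be placed on the CPU buffer type.
// After this commit, such an override is no longer pinned to the plain CPU buffer:
// select_weight_buft() is consulted against cpu_buft_list, so a repacked "extra"
// CPU buffer type may still be picked unless use_extra_bufts is disabled.
static const llama_model_tensor_buft_override tensor_overrides[] = {
    { "ffn_.*", ggml_backend_cpu_buffer_type() },
    { nullptr,  nullptr }, // terminator (pattern == nullptr)
};

llama_model_params make_override_params() {
    llama_model_params mparams = llama_model_default_params();
    mparams.tensor_buft_overrides = tensor_overrides;
    // mparams.use_extra_bufts = false; // uncomment to also disable repacking here
    return mparams;
}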
