Skip to content

Commit fb371c1

Browse files
committed
bench,common : add CPU extra buffer types
1 parent (8ad7b3e), commit fb371c1

File tree

2 files changed: 38 additions (+38), 0 deletions (-0)

2 files changed: 38 additions (+38), 0 deletions (-0)

common/arg.cpp

Lines changed: 19 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2347,6 +2347,25 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
2347 2347
buft_list[ggml_backend_buft_name(buft)] = buft;
2348 2348
}
2349 2349
}
2350+
2351+
// add CPU extra buffer types
2352+
{
2353+
auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
2354+
if (cpu_dev == nullptr) {
2355+
throw std::runtime_error("no CPU backend found");
2356+
}
2357+
2358+
auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
2359+
auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
2360+
ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts");
2361+
if (ggml_backend_dev_get_extra_bufts_fn) {
2362+
ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(cpu_dev);
2363+
while (extra_bufts && *extra_bufts) {
2364+
buft_list[ggml_backend_buft_name(*extra_bufts)] = *extra_bufts;
2365+
++extra_bufts;
2366+
}
2367+
}
2368+
}
2350 2369
}
2351 2370

2352 2371
for (const auto & override : string_split<std::string>(value, ',')) {

tools/llama-bench/llama-bench.cpp

Lines changed: 19 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -702,6 +702,25 @@ static cmd_params parse_cmd_params(int argc, char ** argv) {
702 702
buft_list[ggml_backend_buft_name(buft)] = buft;
703 703
}
704 704
}
705+
706+
// add CPU extra buffer types
707+
{
708+
auto * cpu_dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
709+
if (cpu_dev == nullptr) {
710+
throw std::runtime_error("no CPU backend found");
711+
}
712+
713+
auto * cpu_reg = ggml_backend_dev_backend_reg(cpu_dev);
714+
auto ggml_backend_dev_get_extra_bufts_fn = (ggml_backend_dev_get_extra_bufts_t)
715+
ggml_backend_reg_get_proc_address(cpu_reg, "ggml_backend_dev_get_extra_bufts");
716+
if (ggml_backend_dev_get_extra_bufts_fn) {
717+
ggml_backend_buffer_type_t * extra_bufts = ggml_backend_dev_get_extra_bufts_fn(cpu_dev);
718+
while (extra_bufts && *extra_bufts) {
719+
buft_list[ggml_backend_buft_name(*extra_bufts)] = *extra_bufts;
720+
++extra_bufts;
721+
}
722+
}
723+
}
705 724
}
706 725
auto override_group_span_len = std::strcspn(value, ",");
707 726
bool last_group = false;

0 commit comments

Comments
 (0)