2 changes: 2 additions & 0 deletions ggml/include/ggml-backend.h
@@ -109,6 +109,8 @@ extern "C" {

GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void);

GGML_API GGML_CALL ggml_backend_buffer_type_t ggml_backend_numa_buffer_type(void);

#ifdef GGML_USE_CPU_HBM
GGML_API ggml_backend_buffer_type_t ggml_backend_cpu_hbm_buffer_type(void);
#endif
90 changes: 90 additions & 0 deletions ggml/src/ggml-backend.c
@@ -936,6 +936,96 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(void * ptr, siz
return ggml_backend_buffer_init(ggml_backend_cpu_buffer_type(), cpu_backend_buffer_i_from_ptr, ptr, size);
}

// NUMA buffer interface - similar to the CPU buffer, but with pages allocated according to the NUMA first-touch policy

#include <sys/mman.h>
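
// [editor's sketch, not part of this PR] under Linux's default first-touch policy, an
// anonymous page is physically placed on the NUMA node of the thread that first writes it;
// a hypothetical helper like the one below (name and page size are illustrative, and the
// OpenMP pragma is simply ignored in non-OpenMP builds) shows how worker threads could
// fault buffer pages onto their local nodes:
static void ggml_numa_first_touch_sketch(char * data, size_t size) {
    const size_t page_size = 4096; // assumed page size; real code would query sysconf(_SC_PAGESIZE)
    #pragma omp parallel for schedule(static)
    for (size_t off = 0; off < size; off += page_size) {
        data[off] = 0; // the faulting thread's NUMA node receives this page
    }
}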

GGML_CALL static void ggml_backend_numa_buffer_free_buffer(ggml_backend_buffer_t buffer) {
if (munmap((char *) buffer->context, buffer->size)) {
//GGML_LOG_WARN("warning: munmap failed: %s\n", strerror(errno));
}
}

GGML_CALL static void ggml_backend_numa_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
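    // note (editor): the apparent intent is that dropped anonymous pages come back zero-filled
    // on the next write fault, which clears the buffer and re-applies first-touch placement
    // (this assumes value == 0); beware that glibc has treated POSIX_MADV_DONTNEED as a no-op
    // since 2.6, so a raw madvise(MADV_DONTNEED) may be needed for this to actually release pages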
if (posix_madvise(buffer->context, buffer->size, POSIX_MADV_DONTNEED)) {
//GGML_LOG_WARN("warning: posix_madvise(.., POSIX_MADV_DONTNEED) failed: %s\n",
// strerror(errno));
}
}

GGML_CALL static const char * ggml_backend_numa_buffer_name(ggml_backend_buffer_t buffer) {
return "CPU NUMA";

GGML_UNUSED(buffer);
}

static const struct ggml_backend_buffer_i ggml_backend_numa_buffer_i = {
/* .get_name = */ ggml_backend_numa_buffer_name,
/* .free_buffer = */ ggml_backend_numa_buffer_free_buffer,
/* .get_base = */ ggml_backend_cpu_buffer_get_base,
/* .init_tensor = */ NULL, // no initialization required
    // /* .memset_tensor = */ ggml_backend_cpu_buffer_memset_tensor,
/* .set_tensor = */ ggml_backend_cpu_buffer_set_tensor,
/* .get_tensor = */ ggml_backend_cpu_buffer_get_tensor,
/* .cpy_tensor = */ ggml_backend_cpu_buffer_cpy_tensor,
/* .clear = */ ggml_backend_numa_buffer_clear,
/* .reset = */ NULL,
};

// NUMA buffer type - similar to the CPU buffer type, but with pages allocated according to the NUMA first-touch policy

GGML_CALL static const char * ggml_backend_numa_buffer_type_get_name(ggml_backend_buffer_type_t buft) {
return "NUMA";

GGML_UNUSED(buft);
}

GGML_CALL static ggml_backend_buffer_t ggml_backend_numa_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {
int flags = MAP_SHARED | MAP_ANONYMOUS;
void * data = mmap(NULL, size, PROT_READ|PROT_WRITE, flags, -1, 0);
if (data == MAP_FAILED) {
//GGML_LOG_ERROR("%s: failed to allocate buffer of size %zu\n", __func__, size);
return NULL;
}
if (posix_madvise(data, size, POSIX_MADV_RANDOM)) {
//GGML_LOG_WARN("warning: posix_madvise(.., POSIX_MADV_RANDOM) failed: %s\n",
// strerror(errno));
}

return ggml_backend_buffer_init(buft, ggml_backend_numa_buffer_i, data, size);
}
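
// [editor's sketch, not part of this PR] allocating through the new type goes through the
// standard ggml-backend entry points; everything called below is existing ggml API:
static void ggml_backend_numa_buffer_usage_sketch(void) {
    ggml_backend_buffer_type_t buft = ggml_backend_numa_buffer_type();
    ggml_backend_buffer_t buf = ggml_backend_buft_alloc_buffer(buft, 16u << 20); // 16 MiB
    if (buf != NULL) {
        ggml_backend_buffer_clear(buf, 0); // see the POSIX_MADV_DONTNEED note above
        ggml_backend_buffer_free(buf);
    }
}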

GGML_CALL static size_t ggml_backend_numa_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) {
return TENSOR_ALIGNMENT;

GGML_UNUSED(buft);
}

GGML_CALL static bool ggml_backend_numa_buffer_type_is_host(ggml_backend_buffer_type_t buft) {
return true;

GGML_UNUSED(buft);
}

GGML_CALL ggml_backend_buffer_type_t ggml_backend_numa_buffer_type(void) {
static struct ggml_backend_buffer_type ggml_backend_numa_buffer_type = {
/* .iface = */ {
/* .get_name = */ ggml_backend_numa_buffer_type_get_name,
/* .alloc_buffer = */ ggml_backend_numa_buffer_type_alloc_buffer,
/* .get_alignment = */ ggml_backend_numa_buffer_type_get_alignment,
/* .get_max_size = */ NULL, // defaults to SIZE_MAX
/* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
/* .is_host = */ ggml_backend_numa_buffer_type_is_host,
},
/* .context = */ NULL,
};

return &ggml_backend_numa_buffer_type;
}

GGML_CALL static ggml_backend_t ggml_backend_reg_cpu_init(const char * params, void * user_data) {
return ggml_backend_cpu_init();

14 changes: 11 additions & 3 deletions src/llama.cpp
@@ -2260,7 +2260,7 @@ static ggml_backend_buffer_type_t llama_default_buffer_type_cpu(bool host_buffer
#endif

if (buft == nullptr) {
-       buft = ggml_backend_cpu_buffer_type();
+       buft = ggml_backend_numa_buffer_type();
}
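    // (editor's note: with the fallback above changed, any host buffer request that is not
    // satisfied by an HBM or accelerator-specific type is now allocated through the NUMA-aware type)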
return buft;

@@ -3249,14 +3249,22 @@ static bool llama_kv_cache_init(

bool warn = true;
int n_mla = 0;
//auto * reg = ggml_backend_dev_backend_reg(ggml_backend_dev_by_type(GGML_BACKEND_TYPE_CPU));
//auto * is_numa_fn = (decltype(ggml_is_numa) *) ggml_backend_reg_get_proc_address(reg, "ggml_backend_cpu_is_numa");
//bool is_numa = is_numa_fn();
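    // (editor's note: the commented lookup above mirrors the newer ggml backend-registry API,
    //  where ggml_backend_reg_get_proc_address() exposes "ggml_backend_cpu_is_numa"; until
    //  that is wired up, the message below fires for every non-offloaded KV cache)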
if (!offload) {
LLAMA_LOG_INFO("%s: NUMA usage detected, using NUMA-aware buffer for KV cache\n", __func__);
    }

    for (int i = 0; i < (int) n_layer; i++) {
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s();
const uint32_t n_head = hparams.n_head(i);
const uint32_t n_head_kv = hparams.n_head_kv(i);
        const uint32_t n_embd_head_k = hparams.n_embd_head_k;

struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
ggml_tensor * k;
ggml_tensor * v;
@@ -3265,7 +3273,7 @@
const uint32_t n_embd_head_qk_rope = hparams.n_rot;
const uint32_t n_embd_head_qk_nope = hparams.n_embd_head_k - hparams.n_rot;
const uint32_t kv_lora_rank = hparams.n_lora_kv;
LLAMA_LOG_INFO("%s: layer %d: n_embd_head_qk_rope = %d, kv_lora_rank = %d\n", __func__, i, n_embd_head_qk_rope, kv_lora_rank);
//LLAMA_LOG_INFO("%s: layer %d: n_embd_head_qk_rope = %d, kv_lora_rank = %d\n", __func__, i, n_embd_head_qk_rope, kv_lora_rank);
if (cparams.flash_attn) {
ggml_tensor * kv = ggml_new_tensor_2d(ctx, cache.type_k, kv_lora_rank + n_embd_head_qk_rope, kv_size);
ggml_format_name(kv, "cache_kv_l%d", i);