diff --git a/common/common.cpp b/common/common.cpp index dd45f83c4..024228a9c 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1378,6 +1378,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa } return true; } + if (arg == "--offload-only-active-experts" || arg == "-ooae") { + params.only_active_exps = true; + return true; + } if (arg == "--host") { CHECK_ARG params.hostname = argv[i]; @@ -2746,6 +2750,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.fused_up_gate = params.fused_up_gate; cparams.min_experts = params.min_experts; cparams.thresh_experts = params.thresh_experts; + cparams.only_active_experts = params.only_active_exps; cparams.type_k = kv_cache_type_from_str(params.cache_type_k); cparams.type_v = kv_cache_type_from_str(params.cache_type_v); diff --git a/common/common.h b/common/common.h index 5ba6ad609..38027ca27 100644 --- a/common/common.h +++ b/common/common.h @@ -223,6 +223,7 @@ struct gpt_params { bool repack_tensors = false; // repack tensors if interleaved variant is available bool use_thp = false; // use transparent huge pages (linux only) bool validate_quants = false; // if true, check for NaNs while loading the model + bool only_active_exps = false; // if true, offload only active experts (relevant only for hybrid CPU/GPU) std::string cache_type_k = "f16"; // KV cache data type for the K std::string cache_type_v = "f16"; // KV cache data type for the V diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index 1110ff3aa..6c843fa81 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -210,6 +210,7 @@ extern "C" { // enable or disable op offload for a given op GGML_API void ggml_backend_sched_set_op_offload(ggml_backend_sched_t sched, enum ggml_op op, bool on_or_off); + GGML_API void ggml_backend_sched_set_only_active_experts(ggml_backend_sched_t sched, bool on_or_off); // // Utils diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 5924805bd..1c9fcc248 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -1493,7 +1493,7 @@ add_library(ggml ../include/ggml-backend.h ggml.c ggml-alloc.c - ggml-backend.c + ggml-backend.cpp ggml-quants.c ggml-quants.h ${GGML_SOURCES_CUDA} ${GGML_HEADERS_CUDA} diff --git a/ggml/src/ggml-backend.c b/ggml/src/ggml-backend.cpp similarity index 92% rename from ggml/src/ggml-backend.c rename to ggml/src/ggml-backend.cpp index 07b879f12..4e9ef4739 100644 --- a/ggml/src/ggml-backend.c +++ b/ggml/src/ggml-backend.cpp @@ -3,12 +3,14 @@ #include "ggml-impl.h" #include "ggml-rpc.h" -#include -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include +#include +#include +#include #define IK_PRINT_TIMING 0 @@ -60,9 +62,7 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_buffer_init( struct ggml_backend_buffer_i iface, ggml_backend_buffer_context_t context, size_t size) { - ggml_backend_buffer_t buffer = malloc(sizeof(struct ggml_backend_buffer)); - - (*buffer) = (struct ggml_backend_buffer) { + ggml_backend_buffer_t buffer = new ggml_backend_buffer { /* .interface = */ iface, /* .buft = */ buft, /* .context = */ context, @@ -200,6 +200,7 @@ size_t ggml_backend_get_max_size(ggml_backend_t backend) { void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + if (offset + size > ggml_nbytes(tensor)) fprintf(stderr, "%s(%s): offset = %zu, size = %zu, nbytes = %zu\n", __func__, tensor->name, offset, size, ggml_nbytes(tensor)); GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); if (backend->iface.set_tensor_async == NULL) { @@ -442,6 +443,29 @@ static size_t ggml_backend_registry_count = 0; GGML_CALL static ggml_backend_t ggml_backend_reg_cpu_init(const char * params, void * user_data); +#ifdef GGML_USE_CUDA +extern "C" GGML_CALL void ggml_backend_cuda_reg_devices(void); +#endif +#ifdef GGML_USE_SYCL +extern "C" void ggml_backend_sycl_reg_devices(void); +#endif +#ifdef GGML_USE_METAL +extern "C" GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); +extern "C" GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void); +#endif +#ifdef GGML_USE_VULKAN +extern "C" GGML_CALL int ggml_backend_vk_reg_devices(void); +#endif +#ifdef GGML_USE_KOMPUTE +extern "C" GGML_CALL void ggml_backend_kompute_reg_devices(void); +#endif +#ifdef GGML_USE_CANN +extern "C" GGML_CALL int ggml_backend_cann_reg_devices(void); +#endif +#ifdef GGML_USE_RPC +extern "C" GGML_CALL void ggml_backend_rpc_reg_devices(void); +#endif + GGML_CALL static void ggml_backend_registry_init(void) { static bool initialized = false; @@ -455,37 +479,29 @@ GGML_CALL static void ggml_backend_registry_init(void) { // add forward decls here to avoid including the backend headers #ifdef GGML_USE_CUDA - extern GGML_CALL void ggml_backend_cuda_reg_devices(void); ggml_backend_cuda_reg_devices(); #endif #ifdef GGML_USE_SYCL - extern void ggml_backend_sycl_reg_devices(void); ggml_backend_sycl_reg_devices(); #endif #ifdef GGML_USE_METAL - extern GGML_CALL ggml_backend_t ggml_backend_reg_metal_init(const char * params, void * user_data); - extern GGML_CALL ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void); ggml_backend_register("Metal", ggml_backend_reg_metal_init, ggml_backend_metal_buffer_type(), NULL); #endif #ifdef GGML_USE_VULKAN - extern GGML_CALL int ggml_backend_vk_reg_devices(void); ggml_backend_vk_reg_devices(); #endif #ifdef GGML_USE_KOMPUTE - extern GGML_CALL void ggml_backend_kompute_reg_devices(void); ggml_backend_kompute_reg_devices(); #endif #ifdef GGML_USE_CANN - extern GGML_CALL int ggml_backend_cann_reg_devices(void); ggml_backend_cann_reg_devices(); #endif #ifdef GGML_USE_RPC - extern GGML_CALL void ggml_backend_rpc_reg_devices(void); ggml_backend_rpc_reg_devices(); #endif } @@ -495,11 +511,11 @@ GGML_CALL void ggml_backend_register(const char * name, ggml_backend_init_fn ini size_t id = ggml_backend_registry_count; - ggml_backend_registry[id] = (struct ggml_backend_reg) { + ggml_backend_registry[id] = ggml_backend_reg { /* .name = */ {0}, /* .fn = */ init_fn, /* .default_buffer_type = */ default_buffer_type, - /* .user_data = */ user_data, + /* .user_data = */ user_data }; snprintf(ggml_backend_registry[id].name, sizeof(ggml_backend_registry[id].name), "%s", name); @@ -804,13 +820,13 @@ struct ggml_backend_plan_cpu { GGML_CALL static ggml_backend_graph_plan_t ggml_backend_cpu_graph_plan_create(ggml_backend_t backend, const struct ggml_cgraph * cgraph) { struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context; - struct ggml_backend_plan_cpu * cpu_plan = malloc(sizeof(struct ggml_backend_plan_cpu)); + struct ggml_backend_plan_cpu * cpu_plan = (ggml_backend_plan_cpu *)malloc(sizeof(struct ggml_backend_plan_cpu)); cpu_plan->cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads); cpu_plan->cgraph = *cgraph; // FIXME: deep copy if (cpu_plan->cplan.work_size > 0) { - cpu_plan->cplan.work_data = malloc(cpu_plan->cplan.work_size); + cpu_plan->cplan.work_data = (uint8_t *)malloc(cpu_plan->cplan.work_size); if (cpu_plan->cplan.work_data == NULL) { free(cpu_plan); return NULL; @@ -854,7 +870,7 @@ GGML_CALL static enum ggml_status ggml_backend_cpu_graph_compute(ggml_backend_t } cpu_ctx->work_size = cplan.work_size; } - cplan.work_data = cpu_ctx->work_data; + cplan.work_data = (uint8_t *)cpu_ctx->work_data; cplan.abort_callback = cpu_ctx->abort_callback; cplan.abort_callback_data = cpu_ctx->abort_callback_data; @@ -915,7 +931,7 @@ static ggml_guid_t ggml_backend_cpu_guid(void) { } ggml_backend_t ggml_backend_cpu_init(void) { - struct ggml_backend_cpu_context * ctx = malloc(sizeof(struct ggml_backend_cpu_context)); + struct ggml_backend_cpu_context * ctx = (ggml_backend_cpu_context *)malloc(sizeof(struct ggml_backend_cpu_context)); if (ctx == NULL) { return NULL; } @@ -926,13 +942,13 @@ ggml_backend_t ggml_backend_cpu_init(void) { ctx->abort_callback = NULL; ctx->abort_callback_data = NULL; - ggml_backend_t cpu_backend = malloc(sizeof(struct ggml_backend)); + ggml_backend_t cpu_backend = (ggml_backend_t)malloc(sizeof(struct ggml_backend)); if (cpu_backend == NULL) { free(ctx); return NULL; } - *cpu_backend = (struct ggml_backend) { + *cpu_backend = ggml_backend { /* .guid = */ ggml_backend_cpu_guid(), /* .interface = */ cpu_backend_i, /* .context = */ ctx @@ -1144,6 +1160,7 @@ struct ggml_backend_sched { uint32_t op_offload[(GGML_OP_COUNT + 31)/32]; + bool only_active_experts; bool debug; }; @@ -1164,6 +1181,11 @@ void ggml_backend_sched_set_op_offload(ggml_backend_sched_t sched, enum ggml_op } } +void ggml_backend_sched_set_only_active_experts(ggml_backend_sched_t sched, bool on_or_off) { + if (!sched) return; + sched->only_active_experts = on_or_off; +} + static inline bool ggml_backend_sched_offload_enabled(ggml_backend_sched_t sched, enum ggml_op op) { int int_op = (int)op; if (!sched || op < 0 || op >= GGML_OP_COUNT) return false; @@ -1630,7 +1652,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg i_split++; if (i_split >= sched->splits_capacity) { sched->splits_capacity *= 2; - sched->splits = realloc(sched->splits, sched->splits_capacity * sizeof(struct ggml_backend_sched_split)); + sched->splits = (ggml_backend_sched_split *)realloc(sched->splits, sched->splits_capacity * sizeof(struct ggml_backend_sched_split)); GGML_ASSERT(sched->splits != NULL); } GGML_ASSERT(i_split < GGML_SCHED_MAX_SPLITS); @@ -1720,8 +1742,8 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg int graph_size = graph->n_nodes + sched->n_splits*GGML_SCHED_MAX_SPLIT_INPUTS*2; if (sched->graph.size < graph_size) { sched->graph.size = graph_size; - sched->graph.nodes = realloc(sched->graph.nodes, graph_size * sizeof(struct ggml_tensor *)); - sched->graph.leafs = realloc(sched->graph.leafs, graph_size * sizeof(struct ggml_tensor *)); + sched->graph.nodes = (ggml_tensor **)realloc(sched->graph.nodes, graph_size * sizeof(struct ggml_tensor *)); + sched->graph.leafs = (ggml_tensor **)realloc(sched->graph.leafs, graph_size * sizeof(struct ggml_tensor *)); GGML_ASSERT(sched->graph.nodes != NULL); GGML_ASSERT(sched->graph.leafs != NULL); } @@ -1844,6 +1866,14 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s int split_backend_id = split->backend_id; ggml_backend_t split_backend = sched->backends[split_backend_id]; + int cur_arg = 0; + std::vector ids; + std::set unique_ids; + + //printf("Graph split %d has %d inputs:\n", i, split->n_inputs); + //for (int j = 0; j < split->n_inputs; j++) printf(" %s, %s\n", split->inputs[j]->name, + // split->inputs[j]->src[0] ? split->inputs[j]->src[0]->name : "none"); + // copy the input tensors to the split backend for (int j = 0; j < split->n_inputs; j++) { ggml_backend_t input_backend = ggml_backend_sched_get_tensor_backend(sched, split->inputs[j]); @@ -1865,6 +1895,71 @@ static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t s } else { ggml_backend_synchronize(split_backend); } + + ggml_tensor * node = split->graph.nodes[0]; + if (sched->only_active_experts && split->graph.n_nodes > 0 && + ggml_backend_buffer_get_usage(input->buffer) == GGML_BACKEND_BUFFER_USAGE_WEIGHTS && + ggml_backend_buffer_is_host(input->buffer) && + node->src[cur_arg] == input_cpy && + (node->op == GGML_OP_MUL_MAT_ID || node->op == GGML_OP_MOE_FUSED_UP_GATE)) { + + if (ids.empty()) { + // find the ids + ggml_tensor * ids_tensor = node->op == GGML_OP_MUL_MAT_ID ? node->src[2] : node->src[3]; + ids.resize(ggml_nbytes(ids_tensor) / sizeof(int32_t)); + ggml_backend_synchronize(input_backend); + + ggml_backend_tensor_get_async(split_backend, ids_tensor, ids.data(), 0, ggml_nbytes(ids_tensor)); + + ggml_backend_synchronize(split_backend); + + for (int64_t i1 = 0; i1 < ids_tensor->ne[1]; i1++) { + for (int64_t i0 = 0; i0 < ids_tensor->ne[0]; i0++) { + int32_t id = ids[i1 * ids_tensor->nb[1]/sizeof(int32_t) + i0 * ids_tensor->nb[0]/sizeof(int32_t)]; + unique_ids.insert(id); + } + } + + // group consecutive experts and copy them together + GGML_ASSERT(!unique_ids.empty()); + + } + + auto it = unique_ids.begin(); + int32_t first_id = *it; + int32_t last_id = first_id; + + auto copy_experts = [&](int32_t first_id, int32_t last_id) { + const size_t expert_size = (node->op == GGML_OP_MUL_MAT_ID || node->op == GGML_OP_MOE_FUSED_UP_GATE) ? input->nb[2] : input->nb[1]; + const size_t expert_offset = first_id * expert_size; + const size_t expert_size_copy = (last_id - first_id + 1) * expert_size; + const size_t padding = 512; + const size_t padding_end = last_id < input->ne[2] - 1 ? std::min(expert_size, padding) : 0; + + ggml_backend_tensor_set_async(split_backend, + input_cpy, + (const uint8_t *)input->data + expert_offset, expert_offset, + // copy a bit extra to ensure there are no NaNs in the padding + expert_size_copy + padding_end); + + }; + + for (++it; it != unique_ids.end(); ++it) { + const int32_t id = *it; + + if (id == last_id + 1) { + last_id = id; + continue; + } + + copy_experts(first_id, last_id); + + first_id = id; + last_id = id; + } + copy_experts(first_id, last_id); + if (node->op == GGML_OP_MOE_FUSED_UP_GATE) ++cur_arg; + } else // try async copy, but if not possible, we can still use a sync copy without synchronizing the dst backend, since we handle the synchronization here with multiple copies and events // TODO: add public function to facilitate this, since applications do not have direct access to the backend interface if (!split_backend->iface.cpy_tensor_async || !split_backend->iface.cpy_tensor_async(input_backend, split_backend, input, input_cpy)) { @@ -1950,7 +2045,7 @@ ggml_backend_sched_t ggml_backend_sched_new( GGML_ASSERT(n_backends <= GGML_SCHED_MAX_BACKENDS); GGML_ASSERT(ggml_backend_is_cpu(backends[n_backends - 1])); // last backend must be CPU - struct ggml_backend_sched * sched = calloc(1, sizeof(struct ggml_backend_sched)); + struct ggml_backend_sched * sched = (ggml_backend_sched *)calloc(1, sizeof(struct ggml_backend_sched)); for (int i = 0; i < (GGML_OP_COUNT + 31)/32; ++i) sched->op_offload[i] = 0xffffffff; @@ -1961,20 +2056,20 @@ ggml_backend_sched_t ggml_backend_sched_new( // initialize hash table // FIXME: needs to be size*2 to account for leafs (do it in graph_split instead) sched->hash_set = ggml_hash_set_new(graph_size); - sched->hv_tensor_backend_ids = malloc(sched->hash_set.size * sizeof(sched->hv_tensor_backend_ids[0])); - sched->hv_tensor_copies = malloc(sched->hash_set.size * sched->n_backends * sched->n_copies * sizeof(struct ggml_tensor *)); + sched->hv_tensor_backend_ids = (int *)malloc(sched->hash_set.size * sizeof(sched->hv_tensor_backend_ids[0])); + sched->hv_tensor_copies = (ggml_tensor **)malloc(sched->hash_set.size * sched->n_backends * sched->n_copies * sizeof(struct ggml_tensor *)); const size_t nodes_size = graph_size + GGML_SCHED_MAX_SPLITS*GGML_SCHED_MAX_SPLIT_INPUTS*2; - sched->node_backend_ids = calloc(nodes_size, sizeof(sched->node_backend_ids[0])); - sched->leaf_backend_ids = calloc(nodes_size, sizeof(sched->leaf_backend_ids[0])); - sched->prev_node_backend_ids = calloc(nodes_size, sizeof(sched->prev_node_backend_ids[0])); - sched->prev_leaf_backend_ids = calloc(nodes_size, sizeof(sched->prev_leaf_backend_ids[0])); + sched->node_backend_ids = (int *)calloc(nodes_size, sizeof(sched->node_backend_ids[0])); + sched->leaf_backend_ids = (int *)calloc(nodes_size, sizeof(sched->leaf_backend_ids[0])); + sched->prev_node_backend_ids = (int *)calloc(nodes_size, sizeof(sched->prev_node_backend_ids[0])); + sched->prev_leaf_backend_ids = (int *)calloc(nodes_size, sizeof(sched->prev_leaf_backend_ids[0])); sched->context_buffer_size = GGML_SCHED_MAX_SPLITS*GGML_SCHED_MAX_SPLIT_INPUTS*2*sizeof(struct ggml_tensor) + ggml_graph_overhead_custom(graph_size, false); - sched->context_buffer = malloc(sched->context_buffer_size); + sched->context_buffer = (char *)malloc(sched->context_buffer_size); const int initial_splits_capacity = 16; - sched->splits = calloc(initial_splits_capacity, sizeof(sched->splits[0])); + sched->splits = (ggml_backend_sched_split *)calloc(initial_splits_capacity, sizeof(sched->splits[0])); sched->splits_capacity = initial_splits_capacity; for (int b = 0; b < n_backends; b++) { @@ -2219,8 +2314,8 @@ static void graph_copy_init_tensor(struct ggml_hash_set * hash_set, struct ggml_ struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph) { struct ggml_hash_set hash_set = ggml_hash_set_new(graph->visited_hash_set.size); - struct ggml_tensor ** node_copies = calloc(hash_set.size, sizeof(node_copies[0])); // NOLINT - bool * node_init = calloc(hash_set.size, sizeof(node_init[0])); + struct ggml_tensor ** node_copies = (ggml_tensor **)calloc(hash_set.size, sizeof(node_copies[0])); // NOLINT + bool * node_init = (bool *)calloc(hash_set.size, sizeof(node_init[0])); struct ggml_init_params params = { /* .mem_size = */ ggml_tensor_overhead()*hash_set.size + ggml_graph_overhead_custom(graph->size, false), @@ -2238,7 +2333,7 @@ struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, s free(node_init); ggml_free(ctx_allocated); ggml_free(ctx_unallocated); - return (struct ggml_backend_graph_copy) { + return { /* .buffer = */ NULL, /* .ctx_allocated = */ NULL, /* .ctx_unallocated = */ NULL, @@ -2261,7 +2356,7 @@ struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, s free(node_init); ggml_free(ctx_allocated); ggml_free(ctx_unallocated); - return (struct ggml_backend_graph_copy) { + return { /* .buffer = */ NULL, /* .ctx_allocated = */ NULL, /* .ctx_unallocated = */ NULL, @@ -2290,7 +2385,7 @@ struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, s free(node_copies); free(node_init); - return (struct ggml_backend_graph_copy) { + return { /* .buffer = */ buffer, /* .ctx_allocated = */ ctx_allocated, /* .ctx_unallocated = */ ctx_unallocated, diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index ac8246b92..5272a38bf 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -4288,8 +4288,9 @@ GGML_CALL static bool ggml_backend_cuda_offload_op(ggml_backend_t backend, const if (batch_size < min_batch_size) return false; int64_t n_experts_tot = op->src[0]->ne[2]; int64_t n_experts_active = ids->ne[0]; - //printf("%s(%s): op->ne[2] = %ld, n_experts_tot = %ld, n_experts_active = %ld, ids: %s, %ld x %ld x %ld x %ld\n", __func__, op->name, op->ne[2], n_experts_tot, n_experts_active, ids->name, ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3]); - return batch_size*n_experts_active >= min_batch_size*n_experts_tot; + bool should_offload = batch_size*n_experts_active >= min_batch_size*n_experts_tot; + //printf("%s(%s): op->ne[2] = %ld, n_experts_tot = %ld, n_experts_active = %ld, ids: %s, %ld x %ld x %ld x %ld -> %d (%ld, %ld)\n", __func__, op->name, op->ne[2], n_experts_tot, n_experts_active, ids->name, ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3], should_offload, batch_size*n_experts_active, min_batch_size*n_experts_tot); + return should_offload; } return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS; diff --git a/include/llama.h b/include/llama.h index 27b0298f9..28b54392c 100644 --- a/include/llama.h +++ b/include/llama.h @@ -424,6 +424,7 @@ extern "C" { bool fused_up_gate; // whether to use fused up/gate op [EXPERIMENTAL] int min_experts; float thresh_experts; + bool only_active_experts; // Abort callback // if it returns true, execution of llama_decode() will be aborted diff --git a/src/llama.cpp b/src/llama.cpp index d793d0062..8154bce85 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -18957,6 +18957,7 @@ struct llama_context_params llama_context_default_params() { /*.fused_up_gate =*/ true, /*.min_experts =*/ -1, /*.thtesh_experts =*/ 0.0f, + /*.only_active_experts =*/ false, /*.abort_callback =*/ nullptr, /*.abort_callback_data =*/ nullptr, /*.offload_policy =*/ nullptr, @@ -19548,6 +19549,11 @@ struct llama_context * llama_new_context_with_model( } } + if (params.only_active_experts) { + LLAMA_LOG_INFO("XXXXXXXXXXXXXXXXXXXXX Setting only active experts offload\n"); + ggml_backend_sched_set_only_active_experts(ctx->sched, true); + } + return ctx; }