diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h index c36c12d6579ac..2db5c4e0f85e5 100644 --- a/ggml/src/ggml-backend-impl.h +++ b/ggml/src/ggml-backend-impl.h @@ -114,6 +114,9 @@ extern "C" { void (*event_record)(ggml_backend_t backend, ggml_backend_event_t event); // wait for an event on on a different stream void (*event_wait) (ggml_backend_t backend, ggml_backend_event_t event); + + // (optional) sort/optimize the nodes in the graph + void (*optimize_graph) (ggml_backend_t backend, struct ggml_cgraph * cgraph); }; struct ggml_backend { diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index f615ab4beecb1..7646f3f1346a4 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -463,6 +463,13 @@ void ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event) backend->iface.event_wait(backend, event); } +static void ggml_backend_optimize_graph(ggml_backend_t backend, struct ggml_cgraph * cgraph) { + GGML_ASSERT(backend); + if (backend->iface.optimize_graph != NULL) { + backend->iface.optimize_graph(backend, cgraph); + } +} + // Backend device const char * ggml_backend_dev_name(ggml_backend_dev_t device) { @@ -1298,6 +1305,10 @@ void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgra struct ggml_backend_sched_split * split = &sched->splits[i]; split->graph = ggml_graph_view(graph, split->i_start, split->i_end); + // Optimize this split of the graph. This needs to happen before we make graph_copy, + // so they are in sync. + ggml_backend_optimize_graph(sched->backends[split->backend_id], &split->graph); + // add inputs to the graph copy so that they are allocated by ggml-alloc at the start of the split for (int j = 0; j < split->n_inputs; j++) { assert(graph_copy->size > (graph_copy->n_nodes + 1)); diff --git a/ggml/src/ggml-blas/ggml-blas.cpp b/ggml/src/ggml-blas/ggml-blas.cpp index aeac2e57449a2..cdfc5a9bc2340 100644 --- a/ggml/src/ggml-blas/ggml-blas.cpp +++ b/ggml/src/ggml-blas/ggml-blas.cpp @@ -270,6 +270,7 @@ static struct ggml_backend_i blas_backend_i = { /* .graph_compute = */ ggml_backend_blas_graph_compute, /* .event_record = */ NULL, /* .event_wait = */ NULL, + /* .optimize_graph = */ NULL, }; static ggml_guid_t ggml_backend_blas_guid(void) { diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 756ad8dfad7af..17d10dbf38c83 100755 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -2689,6 +2689,7 @@ static const ggml_backend_i ggml_backend_cann_interface = { /* .graph_compute = */ ggml_backend_cann_graph_compute, /* .event_record = */ ggml_backend_cann_event_record, /* .event_wait = */ ggml_backend_cann_event_wait, + /* .optimize_graph = */ NULL, }; /** diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp index 0c15165f1b88e..2b81f8b9afa22 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -190,6 +190,7 @@ static const struct ggml_backend_i ggml_backend_cpu_i = { /* .graph_compute = */ ggml_backend_cpu_graph_compute, /* .event_record = */ NULL, /* .event_wait = */ NULL, + /* .optimize_graph = */ NULL, }; static ggml_guid_t ggml_backend_cpu_guid(void) { diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 0c01eb6fa8359..6e55c75f589e5 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3135,6 +3135,7 @@ static const ggml_backend_i ggml_backend_cuda_interface = { /* .graph_compute = */ ggml_backend_cuda_graph_compute, /* .event_record = */ ggml_backend_cuda_event_record, /* .event_wait = */ ggml_backend_cuda_event_wait, + /* .optimize_graph = */ NULL, }; static ggml_guid_t ggml_backend_cuda_guid() { diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index c1a0a2bef171e..273ac1658005b 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -6513,6 +6513,7 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) { /* .graph_compute = */ ggml_backend_metal_graph_compute, /* .event_record = */ NULL, /* .event_wait = */ NULL, + /* .optimize_graph = */ NULL, }; static ggml_guid_t ggml_backend_metal_guid(void) { diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 727163b7fdf95..b188c5af34562 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -2838,6 +2838,7 @@ static ggml_backend_i ggml_backend_opencl_i = { /* .graph_compute = */ ggml_backend_opencl_graph_compute, /* .event_record = */ NULL, /* .event_wait = */ NULL, + /* .optimize_graph = */ NULL, }; ggml_backend_t ggml_backend_opencl_init(void) { diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index e84ff93efc32c..d4833068d0016 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -795,6 +795,7 @@ static ggml_backend_i ggml_backend_rpc_interface = { /* .graph_compute = */ ggml_backend_rpc_graph_compute, /* .event_record = */ NULL, /* .event_wait = */ NULL, + /* .optimize_graph = */ NULL, }; ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) { diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index 877fbf7e8626d..619ccaefc0bf8 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -4063,6 +4063,7 @@ static ggml_backend_i ggml_backend_sycl_interface = { /* .graph_compute = */ ggml_backend_sycl_graph_compute, /* .event_record = */ ggml_backend_sycl_event_record, /* .event_wait = */ ggml_backend_sycl_event_wait, + /* .optimize_graph = */ NULL, }; static ggml_guid_t ggml_backend_sycl_guid() { diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index cd1c66ba7b476..bfdb662acf47a 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -582,6 +582,7 @@ struct vk_device_struct { bool disable_fusion; bool disable_host_visible_vidmem; bool allow_sysmem_fallback; + bool disable_optimize_graph; #ifdef GGML_VULKAN_MEMORY_DEBUG std::unique_ptr memory_logger; @@ -3502,6 +3503,9 @@ static vk_device ggml_vk_get_device(size_t idx) { const char* GGML_VK_ALLOW_SYSMEM_FALLBACK = getenv("GGML_VK_ALLOW_SYSMEM_FALLBACK"); device->allow_sysmem_fallback = GGML_VK_ALLOW_SYSMEM_FALLBACK != nullptr; + const char* GGML_VK_DISABLE_OPTIMIZE_GRAPH = getenv("GGML_VK_DISABLE_OPTIMIZE_GRAPH"); + device->disable_optimize_graph = GGML_VK_DISABLE_OPTIMIZE_GRAPH != nullptr; + bool fp16_storage = false; bool fp16_compute = false; bool maintenance4_support = false; @@ -11633,6 +11637,131 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg UNUSED(backend); } +// Sort the graph for improved parallelism. +static void ggml_vk_optimize_graph(ggml_backend_t backend, struct ggml_cgraph * graph) +{ + VK_LOG_DEBUG("ggml_vk_optimize_graph(" << graph->n_nodes << " nodes)"); + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + + if (ctx->device->disable_optimize_graph) { + return; + } + + auto const &is_empty = [](ggml_tensor * node) -> bool { + return node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE; + }; + + auto const &is_src_of = [](const ggml_tensor *dst, const ggml_tensor *src) -> bool { + for (uint32_t s = 0; s < GGML_MAX_SRC; ++s) { + if (dst->src[s] == src) { + return true; + } + } + // implicit dependency if they view the same tensor + const ggml_tensor *dst2 = dst->view_src ? dst->view_src : dst; + const ggml_tensor *src2 = src->view_src ? src->view_src : src; + if (dst2 == src2) { + return true; + } + return false; + }; + + // This function tries to reorder the graph to allow nodes to run in parallel. + // This helps with small batches, but for large batches its a slowdown, probably + // due to cache contention. So only reorder if the majority of nodes have few rows. + int num_small_nodes = 0; + int num_counted_nodes = 0; + for (int i = 0; i < graph->n_nodes; ++i) { + if (!is_empty(graph->nodes[i]) && + graph->nodes[i]->op != GGML_OP_SET_ROWS) { + if (ggml_nrows(graph->nodes[i]) <= 8) { + num_small_nodes++; + } + num_counted_nodes++; + } + } + if (num_small_nodes < num_counted_nodes / 2) { + return; + } + + std::vector new_order; + std::vector used(graph->n_nodes, false); + int first_unused = 0; + while (first_unused < graph->n_nodes) { + std::vector current_set; + + // First, grab the next unused node. + current_set.push_back(first_unused); + + // Loop through the next N nodes. Grab any that don't depend on other nodes that + // haven't already been run. Nodes that have already been run have used[i] set + // to true. Allow nodes that depend on the previous node if it's a fusion pattern + // that we support (e.g. RMS_NORM + MUL). + // This first pass only grabs "real" (non-view nodes). Second pass grabs view nodes. + // The goal is to not interleave real and view nodes in a way that breaks fusion. + const int NUM_TO_CHECK = 20; + for (int j = first_unused+1; j < std::min(first_unused + NUM_TO_CHECK, graph->n_nodes); ++j) { + if (used[j]) { + continue; + } + if (is_empty(graph->nodes[j])) { + continue; + } + bool ok = true; + for (int c = first_unused; c < j; ++c) { + if (!used[c] && + is_src_of(graph->nodes[j], graph->nodes[c]) && + !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_RMS_NORM && graph->nodes[j]->op == GGML_OP_MUL)) { + ok = false; + break; + } + } + if (ok) { + current_set.push_back(j); + } + } + // Second pass grabs view nodes. + // Skip this if it would break a fusion optimization (don't split up add->rms_norm or add->add). + if (graph->nodes[current_set.back()]->op != GGML_OP_ADD) { + for (int j = first_unused+1; j < std::min(first_unused + NUM_TO_CHECK, graph->n_nodes); ++j) { + if (used[j]) { + continue; + } + if (!is_empty(graph->nodes[j])) { + continue; + } + bool ok = true; + for (int c = first_unused; c < j; ++c) { + bool c_in_current_set = std::find(current_set.begin(), current_set.end(), c) != current_set.end(); + // skip views whose srcs haven't been processed. + if (!used[c] && + is_src_of(graph->nodes[j], graph->nodes[c]) && + !c_in_current_set) { + ok = false; + break; + } + } + if (ok) { + current_set.push_back(j); + } + } + } + + // Push the current set into new_order + for (auto c : current_set) { + new_order.push_back(graph->nodes[c]); + used[c] = true; + } + while (first_unused < graph->n_nodes && used[first_unused]) { + first_unused++; + } + } + // Replace the graph with the new order. + for (int i = 0; i < graph->n_nodes; ++i) { + graph->nodes[i] = new_order[i]; + } +} + // TODO: enable async and synchronize static ggml_backend_i ggml_backend_vk_interface = { /* .get_name = */ ggml_backend_vk_name, @@ -11648,6 +11777,7 @@ static ggml_backend_i ggml_backend_vk_interface = { /* .graph_compute = */ ggml_backend_vk_graph_compute, /* .event_record = */ NULL, /* .event_wait = */ NULL, + /* .optimize_graph = */ ggml_vk_optimize_graph, }; static ggml_guid_t ggml_backend_vk_guid() { diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index e5df883c1367e..70522a134e9c1 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -665,6 +665,7 @@ static ggml_backend_i ggml_backend_webgpu_i = { /* .graph_compute = */ ggml_backend_webgpu_graph_compute, /* .event_record = */ NULL, /* .event_wait = */ NULL, + /* .optimize_graph = */ NULL, }; /* End GGML Backend Interface */ diff --git a/ggml/src/ggml-zdnn/ggml-zdnn.cpp b/ggml/src/ggml-zdnn/ggml-zdnn.cpp index 7507a52aeae31..a4c51ab4e7c4a 100644 --- a/ggml/src/ggml-zdnn/ggml-zdnn.cpp +++ b/ggml/src/ggml-zdnn/ggml-zdnn.cpp @@ -586,6 +586,7 @@ static ggml_backend_i ggml_backend_zdnn_i = { /* .graph_compute = */ ggml_backend_zdnn_graph_compute, /* .event_record = */ NULL, /* .event_wait = */ NULL, + /* .optimize_graph = */ NULL, }; static ggml_guid_t ggml_backend_zdnn_guid(void) {