From 4cb12309af8ab1a8be93d690369dbcd0f1ed0ed9 Mon Sep 17 00:00:00 2001
From: Aman Gupta
Date: Sat, 1 Nov 2025 14:30:00 +0800
Subject: [PATCH 01/10] CUDA: add stream-based concurrency

---
 ggml/src/ggml-cuda/common.cuh   |  46 ++++-
 ggml/src/ggml-cuda/ggml-cuda.cu | 298 ++++++++++++++++++++++++++++++--
 2 files changed, 324 insertions(+), 20 deletions(-)

diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 25e9308d7..e7d2071e7 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -964,6 +964,32 @@ struct ggml_cuda_graph {
 #endif
 };
 
+struct ggml_cuda_concurrent_event {
+    std::vector<cudaEvent_t> per_stream_events;
+    cudaEvent_t fork_event;
+    cudaEvent_t join_event;
+
+    int n_streams = 0;
+    std::unordered_map<const ggml_tensor *, int> stream_mapping;
+
+    const ggml_tensor * join_node;
+
+    ggml_cuda_concurrent_event() = default;
+
+    explicit ggml_cuda_concurrent_event(int n_streams) : n_streams(n_streams) {
+        per_stream_events.resize(n_streams);
+
+        for (size_t i = 0; i < per_stream_events.size(); ++i) {
+            cudaEventCreateWithFlags(&per_stream_events[i], cudaEventDisableTiming);
+        }
+
+        cudaEventCreateWithFlags(&fork_event, cudaEventDisableTiming);
+        cudaEventCreateWithFlags(&join_event, cudaEventDisableTiming);
+    }
+};
+
+using ggml_cuda_stream_context = std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event>;
+
 struct ggml_backend_cuda_context {
     int device;
     std::string name;
@@ -974,11 +1000,15 @@ struct ggml_backend_cuda_context {
 
     std::unique_ptr<ggml_cuda_graph> cuda_graph;
 
+    int curr_stream_no = 0;
+
     explicit ggml_backend_cuda_context(int device) :
         device(device),
         name(GGML_CUDA_NAME + std::to_string(device)) {
     }
 
+    ggml_cuda_stream_context concurrent_stream_context;
+
     ~ggml_backend_cuda_context();
 
     cudaStream_t stream(int device, int stream) {
@@ -989,9 +1019,9 @@ struct ggml_backend_cuda_context {
         return streams[device][stream];
     }
 
-    cudaStream_t stream() {
-        return stream(device, 0);
-    }
+    cudaStream_t stream() { return stream(device, curr_stream_no); }
+
+    ggml_cuda_stream_context & stream_context() { return concurrent_stream_context; }
 
     cublasHandle_t cublas_handle(int device) {
         if (cublas_handles[device] == nullptr) {
@@ -1007,15 +1037,15 @@ struct ggml_backend_cuda_context {
     }
 
     // pool
-    std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES];
+    std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS];
 
-    static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device);
+    static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device, int stream_no);
 
     ggml_cuda_pool & pool(int device) {
-        if (pools[device] == nullptr) {
-            pools[device] = new_pool_for_device(device);
+        if (pools[device][curr_stream_no] == nullptr) {
+            pools[device][curr_stream_no] = new_pool_for_device(device, curr_stream_no);
         }
-        return *pools[device];
+        return *pools[device][curr_stream_no];
     }
 
     ggml_cuda_pool & pool() {
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 7d792e60c..c6223891a 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -521,7 +521,8 @@ struct ggml_cuda_pool_vmm : public ggml_cuda_pool {
 };
 #endif // defined(GGML_USE_VMM)
 
-std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device) {
+std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(int device,
+                                                                               [[maybe_unused]] int stream_no) {
 #if defined(GGML_USE_VMM)
     if (ggml_cuda_info().devices[device].vmm) {
         return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_vmm(device));
@@ -3032,7 +3033,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
 #ifndef NDEBUG
     const size_t num_unary = std::count(ops.begin(),
ops.end(), GGML_OP_UNARY); GGML_ASSERT(unary_ops.size() == num_unary); -#endif +#endif; //TODO: remove special case once ggml_can_fuse can handle empty nodes std::initializer_list topk_moe_ops = @@ -3188,18 +3189,44 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx // flag used to determine whether it is an integrated_gpu const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated; + bool is_concurrent_event_active = false; + ggml_cuda_concurrent_event * concurrent_event = nullptr; + while (!graph_evaluated_or_captured) { // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph. // With the use of CUDA graphs, the execution will be performed by the graph launch. if (!use_cuda_graph || cuda_graph_update_required) { - [[maybe_unused]] int prev_i = 0; + ggml_cuda_stream_context & stream_ctx = cuda_ctx->stream_context(); + for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; + if (is_concurrent_event_active) { + GGML_ASSERT(concurrent_event); + + if (node == concurrent_event->join_node) { + cuda_ctx->curr_stream_no = 0; + for (int i = 1; i <= concurrent_event->n_streams; ++i) { + CUDA_CHECK(cudaEventRecord(concurrent_event->per_stream_events[i - 1], + cuda_ctx->stream(cuda_ctx->device, i))); + CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx->stream(), concurrent_event->per_stream_events[i - 1])); + } + + is_concurrent_event_active = false; + concurrent_event = nullptr; + + } else { + GGML_ASSERT (concurrent_event->stream_mapping.find(node) != concurrent_event->stream_mapping.end()); + const int stream_mapping = concurrent_event->stream_mapping[node]; + cuda_ctx->curr_stream_no = stream_mapping; + GGML_LOG_DEBUG("Setting stream no to %d for node %s\n", stream_mapping, node->name); + } + } + prev_i = i; + #ifdef GGML_CUDA_DEBUG const int nodes_fused = i - prev_i - 1; - prev_i = i; if (nodes_fused > 0) { GGML_LOG_INFO("nodes_fused: %d\n", nodes_fused); } @@ -3209,6 +3236,8 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx continue; } + + // start of fusion operations static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr); if (!disable_fusion) { @@ -3472,16 +3501,23 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx continue; } + //TODO: fix this + static const bool graph_opt = (getenv("GGML_CUDA_GRAPH_OPT") != nullptr) && atoi(getenv("GGML_CUDA_GRAPH_OPT")) == 1; + if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD}, {})) { - ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]); - i += 2; - continue; + if (strncmp(cgraph->nodes[i+2]->name, "attn_norm", strlen("attn_norm")) != 0 || !graph_opt) { + ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]); + i += 2; + continue; + } } if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL}, {})) { - ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]); - i++; - continue; + if (strncmp(cgraph->nodes[i+1]->name, "attn_norm", strlen("attn_norm")) != 0 || !graph_opt) { + ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]); + i++; + continue; + } } if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) { @@ -3501,13 +3537,35 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx } #else GGML_UNUSED(integrated); -#endif // NDEBUG +#endif // 
NDEBUG bool ok = ggml_cuda_compute_forward(*cuda_ctx, node); if (!ok) { GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op)); } GGML_ASSERT(ok); + + if (!is_concurrent_event_active) { + //const ggml_tensor * adjusted_node = node; + // the forking node may have been fused, e.g (RMS_NORM_MUL + MUL + ADD), + // we can safely use the previous node to check if it can be forked + if (stream_ctx.find(node) != stream_ctx.end()) { + concurrent_event = &stream_ctx[node]; + + GGML_LOG_DEBUG("Launching %d streams at %s\n", concurrent_event->n_streams, node->name); + + cudaStream_t main_stream = cuda_ctx->stream(); // this should be stream 0 + GGML_ASSERT(cuda_ctx->curr_stream_no == 0); + CUDA_CHECK(cudaEventRecord(concurrent_event->fork_event, main_stream)); + + for (int i = 1; i <= concurrent_event->n_streams; ++i) { + cudaStream_t stream = cuda_ctx->stream(cuda_ctx->device, i); + CUDA_CHECK(cudaStreamWaitEvent(stream, concurrent_event->fork_event)); + } + + is_concurrent_event_active = true; + } + } } } @@ -3647,6 +3705,222 @@ static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_ev } } +static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; + + static bool enable_graph_optimization = [] { + const char * env = getenv("GGML_CUDA_GRAPH_OPT"); + return env != nullptr && atoi(env) == 1; + }(); + + if (!enable_graph_optimization) { + return; + } + + GGML_ASSERT(ggml_backend_cuda_get_device_count() == 1 && "cuda graph optimization is only supported on single GPU"); + GGML_LOG_DEBUG("Optimizing CUDA graph %p %d\n", cgraph->nodes, cgraph->n_nodes); + + ggml_cuda_stream_context & stream_context = cuda_ctx->stream_context(); + stream_context.clear(); + + std::unordered_map fan_out; + std::unordered_map node_indices; + + const auto & is_empty = [](const ggml_tensor * node) -> bool { + return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || + node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE; + }; + + const auto & is_src_of = [](const ggml_tensor * dst, const ggml_tensor * src) -> bool { + for (uint32_t s = 0; s < GGML_MAX_SRC; ++s) { + if (dst->src[s] == src) { + return true; + } + } + // implicit dependency if they view the same tensor + const ggml_tensor * dst2 = dst->view_src ? dst->view_src : dst; + const ggml_tensor * src2 = src->view_src ? 
src->view_src : src; + if (dst2 == src2) { + return true; + } + return false; + }; + + for (int node_idx = 0; node_idx < cgraph->n_nodes; node_idx++) { + const ggml_tensor * node = cgraph->nodes[node_idx]; + node_indices[node] = node_idx; + + if (is_empty(node)) { + continue; + } + for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) { + const ggml_tensor * node = cgraph->nodes[node_idx]->src[src_idx]; + //TODO: check why nrows > 1 fails, probably related to CUDA graphs + if (node && !is_empty(node) && ggml_nrows(node) <= 1) { + fan_out[node] += 1; + } + } + } + + //Target Q, K, V + const int min_fan_out = 3; + const int max_fan_out = 3; + + std::vector> concurrent_node_ranges; + for (const auto & [root_node, count] : fan_out) { + if (count >= min_fan_out && count <= max_fan_out) { + const int root_node_idx = node_indices[root_node]; + + bool is_part_of_event = false; + for (const auto & [start, end] : concurrent_node_ranges) { + if (root_node_idx >= start && root_node_idx <= end) { + is_part_of_event = true; + } + } + + if (is_part_of_event) { + continue; + } + + std::vector> nodes_per_branch; + for (int i = root_node_idx + 1; i < cgraph->n_nodes; ++i) { + const ggml_tensor * node = cgraph->nodes[i]; + if (!is_empty(node) && is_src_of(node, root_node)) { + nodes_per_branch.push_back({ node }); + } + } + + GGML_ASSERT(nodes_per_branch.size() == (size_t) count); + + //find the join point + const ggml_tensor * join_node = nullptr; + + auto belongs_to_branch = [&](const ggml_tensor * node, std::vector & branch) -> bool { + for (const ggml_tensor * n : branch) { + if (n == node) { + return false; + } + + if (is_src_of(node, n)) { + return true; + } + } + return false; + }; + + for (int i = root_node_idx + 1; i < cgraph->n_nodes; ++i) { + const ggml_tensor * curr_node = cgraph->nodes[i]; + + int num_joins = 0; + for (size_t branch_idx = 0; branch_idx < nodes_per_branch.size(); branch_idx++) { + if (belongs_to_branch(curr_node, nodes_per_branch[branch_idx])) { + num_joins++; + } + } + + if (num_joins >= 2) { + join_node = curr_node; + break; + } + + bool found_branch = false; + for (size_t branch_idx = 0; branch_idx < nodes_per_branch.size(); branch_idx++) { + if (belongs_to_branch(curr_node, nodes_per_branch[branch_idx])) { + //continue accumulating + nodes_per_branch[branch_idx].push_back(curr_node); + found_branch = true; + } else { + if (std::find(nodes_per_branch[branch_idx].begin(), nodes_per_branch[branch_idx].end(), + curr_node) != nodes_per_branch[branch_idx].end()) { + found_branch = true; + } + } + } + + if (!found_branch) { + if (is_empty(curr_node)) { + // we can put it in any branch because it will be ignored + nodes_per_branch[0].push_back({ curr_node }); + } + } + } + + if (join_node) { + //Create ggml_cuda_concurrent_event + ggml_cuda_concurrent_event concurrent_event(nodes_per_branch.size()); + concurrent_event.join_node = join_node; + + for (size_t branch_idx = 0; branch_idx < nodes_per_branch.size(); branch_idx++) { + for (const ggml_tensor * n : nodes_per_branch[branch_idx]) { + concurrent_event.stream_mapping[n] = branch_idx + 1; + } + } + + int fork_node_idx = node_indices[root_node]; + int join_node_idx = node_indices[join_node]; + + int current_branch_idx = 0; + int current_node_idx = fork_node_idx + 1; + const int n_branches = nodes_per_branch.size(); + + int total_branch_nodes = 0; + for (std::vector branch_nodes : nodes_per_branch) { + total_branch_nodes += branch_nodes.size(); + } + + // there are other nodes in the middle which are unaccounted for + // usually 
(cpy) nodes, then ignore this fork + if (join_node_idx - fork_node_idx - 1 != total_branch_nodes) { + GGML_LOG_DEBUG( + "Skipping %s because the number of nodes in the middle is not equal to the total number of " + "branch nodes %d != %d\n", + root_node->name, join_node_idx - fork_node_idx - 1, total_branch_nodes); + continue; + } + + GGML_ASSERT(cuda_ctx->stream_context().find(root_node) == cuda_ctx->stream_context().end()); + cuda_ctx->stream_context().emplace(root_node, concurrent_event); + GGML_LOG_DEBUG("Adding stream at node %s %p\n", root_node->name, root_node); + concurrent_node_ranges.emplace_back(fork_node_idx, join_node_idx); + + // interleave tensors to extend lifetimes so that ggml graph doesn't recycle them + // example transformation: + // [attn-norm, QMul, QNorm, QRope, KMul, KNorm, KRope, VMul, attn] -> + // [attn-norm, QMul, KMul, VMul, QNorm, VNorm, QRope, KRope, attn] + // TODO: This breaks fusion within streams, how do we fix this? + while (current_node_idx < join_node_idx) { + std::vector & branch_nodes = nodes_per_branch[current_branch_idx]; + + bool has_node = false; + for (std::vector branch_node : nodes_per_branch) { + has_node |= branch_node.size() > 0; + } + + GGML_ASSERT(has_node); + + if (branch_nodes.empty()) { + current_branch_idx = (current_branch_idx + 1) % n_branches; + continue; + } + + cgraph->nodes[current_node_idx] = const_cast(branch_nodes.front()); + current_node_idx++; + branch_nodes.erase(branch_nodes.begin()); + + // append all empty nodes + while (!branch_nodes.empty() && is_empty(branch_nodes.front())) { + cgraph->nodes[current_node_idx] = const_cast(branch_nodes.front()); + current_node_idx++; + branch_nodes.erase(branch_nodes.begin()); + } + + current_branch_idx = (current_branch_idx + 1) % n_branches; + } + } + } + } +} + static const ggml_backend_i ggml_backend_cuda_interface = { /* .get_name = */ ggml_backend_cuda_get_name, /* .free = */ ggml_backend_cuda_free, @@ -3661,7 +3935,7 @@ static const ggml_backend_i ggml_backend_cuda_interface = { /* .graph_compute = */ ggml_backend_cuda_graph_compute, /* .event_record = */ ggml_backend_cuda_event_record, /* .event_wait = */ ggml_backend_cuda_event_wait, - /* .graph_optimize = */ NULL, + /* .graph_optimize = */ ggml_backend_cuda_graph_optimize, }; static ggml_guid_t ggml_backend_cuda_guid() { From d9d6f27b056d01f4b1d5977f6bc9a2b0dbc361f6 Mon Sep 17 00:00:00 2001 From: Carl Philipp Klemm Date: Thu, 6 Nov 2025 13:35:19 +0100 Subject: [PATCH 02/10] HIP: fix hipStreamWaitEvent define and nodiscard warnings --- ggml/src/ggml-cuda/common.cuh | 7 ++++--- ggml/src/ggml-cuda/vendors/hip.h | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index e7d2071e7..aab837ef2 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -25,6 +25,7 @@ #include #include #include +#include #include #if defined(GGML_USE_HIP) @@ -980,11 +981,11 @@ struct ggml_cuda_concurrent_event { per_stream_events.resize(n_streams); for (size_t i = 0; i < per_stream_events.size(); ++i) { - cudaEventCreateWithFlags(&per_stream_events[i], cudaEventDisableTiming); + CUDA_CHECK(cudaEventCreateWithFlags(&per_stream_events[i], cudaEventDisableTiming)); } - cudaEventCreateWithFlags(&fork_event, cudaEventDisableTiming); - cudaEventCreateWithFlags(&join_event, cudaEventDisableTiming); + CUDA_CHECK(cudaEventCreateWithFlags(&fork_event, cudaEventDisableTiming)); + CUDA_CHECK(cudaEventCreateWithFlags(&join_event, cudaEventDisableTiming)); } }; 
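For context, the fork/join discipline that ggml_cuda_concurrent_event encodes maps onto plain CUDA runtime calls: record one event on the main stream at the fork, make every worker stream wait on it, run one branch per worker stream, then record one event per worker and make the main stream wait on all of them at the join. The following is a minimal standalone sketch of that pattern, not part of the patch; the stream and function names are illustrative, and error checking and cleanup are omitted for brevity.

// Sketch only: fork/join synchronization between one main stream and n worker streams.
#include <cuda_runtime.h>
#include <vector>

void fork_join_sketch(cudaStream_t main_stream, int n_branches) {
    std::vector<cudaStream_t> workers(n_branches);
    std::vector<cudaEvent_t>  per_stream_events(n_branches);
    cudaEvent_t fork_event;

    cudaEventCreateWithFlags(&fork_event, cudaEventDisableTiming);
    for (int i = 0; i < n_branches; ++i) {
        cudaStreamCreateWithFlags(&workers[i], cudaStreamNonBlocking);
        cudaEventCreateWithFlags(&per_stream_events[i], cudaEventDisableTiming);
    }

    // fork: every worker stream waits until all work queued on the main stream so far is done
    cudaEventRecord(fork_event, main_stream);
    for (int i = 0; i < n_branches; ++i) {
        cudaStreamWaitEvent(workers[i], fork_event);
        // ... launch this branch's kernels on workers[i] ...
    }

    // join: the main stream resumes only after every branch has drained
    for (int i = 0; i < n_branches; ++i) {
        cudaEventRecord(per_stream_events[i], workers[i]);
        cudaStreamWaitEvent(main_stream, per_stream_events[i]);
    }
    // work issued on main_stream after this point runs only once all branches are complete
}
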
diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index 890c10364..b7d6edf7f 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -105,7 +105,7 @@ #define cudaStreamNonBlocking hipStreamNonBlocking #define cudaStreamPerThread hipStreamPerThread #define cudaStreamSynchronize hipStreamSynchronize -#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags) +#define cudaStreamWaitEvent hipStreamWaitEvent #define cudaGraphExec_t hipGraphExec_t #define cudaGraphNode_t hipGraphNode_t #define cudaKernelNodeParams hipKernelNodeParams From 5f28c4ea7a783f14fd94fa20b56191db8fae713d Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Thu, 13 Nov 2025 11:08:05 +0800 Subject: [PATCH 03/10] ggml-cuda: fix fusion inside stream --- ggml/src/ggml-cuda/common.cuh | 10 ++++- ggml/src/ggml-cuda/ggml-cuda.cu | 65 +++++++++++++++++++++++---------- 2 files changed, 54 insertions(+), 21 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index aab837ef2..ca0e54cdd 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -989,7 +989,15 @@ struct ggml_cuda_concurrent_event { } }; -using ggml_cuda_stream_context = std::unordered_map; +struct ggml_cuda_stream_context { + std::vector original_graph; + std::unordered_map concurrent_events; + + void reset() { + original_graph.clear(); + concurrent_events.clear(); + } +}; struct ggml_backend_cuda_context { int device; diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index c6223891a..23986446a 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3199,6 +3199,9 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx [[maybe_unused]] int prev_i = 0; ggml_cuda_stream_context & stream_ctx = cuda_ctx->stream_context(); + if (stream_ctx.concurrent_events.size() > 0) { + cgraph->nodes = const_cast(stream_ctx.original_graph.data()); + } for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; @@ -3222,6 +3225,26 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx cuda_ctx->curr_stream_no = stream_mapping; GGML_LOG_DEBUG("Setting stream no to %d for node %s\n", stream_mapping, node->name); } + } else if (i - prev_i > 1) { + + //the previous node was fused + const ggml_tensor * prev_node = cgraph->nodes[i - 1]; + if (stream_ctx.concurrent_events.find(prev_node) != stream_ctx.concurrent_events.end()) { + concurrent_event = &stream_ctx.concurrent_events[prev_node]; + + GGML_LOG_DEBUG("Launching %d streams at %s\n", concurrent_event->n_streams, node->name); + + cudaStream_t main_stream = cuda_ctx->stream(); // this should be stream 0 + GGML_ASSERT(cuda_ctx->curr_stream_no == 0); + CUDA_CHECK(cudaEventRecord(concurrent_event->fork_event, main_stream)); + + for (int i = 1; i <= concurrent_event->n_streams; ++i) { + cudaStream_t stream = cuda_ctx->stream(cuda_ctx->device, i); + CUDA_CHECK(cudaStreamWaitEvent(stream, concurrent_event->fork_event)); + } + + is_concurrent_event_active = true; + } } prev_i = i; @@ -3501,23 +3524,16 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx continue; } - //TODO: fix this - static const bool graph_opt = (getenv("GGML_CUDA_GRAPH_OPT") != nullptr) && atoi(getenv("GGML_CUDA_GRAPH_OPT")) == 1; - if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD}, {})) { - if (strncmp(cgraph->nodes[i+2]->name, 
"attn_norm", strlen("attn_norm")) != 0 || !graph_opt) { - ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]); - i += 2; - continue; - } + ggml_cuda_op_rms_norm_fused_add(*cuda_ctx, node, cgraph->nodes[i+1], cgraph->nodes[i+2]); + i += 2; + continue; } if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL}, {})) { - if (strncmp(cgraph->nodes[i+1]->name, "attn_norm", strlen("attn_norm")) != 0 || !graph_opt) { - ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]); - i++; - continue; - } + ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]); + i++; + continue; } if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) { @@ -3549,8 +3565,8 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx //const ggml_tensor * adjusted_node = node; // the forking node may have been fused, e.g (RMS_NORM_MUL + MUL + ADD), // we can safely use the previous node to check if it can be forked - if (stream_ctx.find(node) != stream_ctx.end()) { - concurrent_event = &stream_ctx[node]; + if (stream_ctx.concurrent_events.find(node) != stream_ctx.concurrent_events.end()) { + concurrent_event = &stream_ctx.concurrent_events[node]; GGML_LOG_DEBUG("Launching %d streams at %s\n", concurrent_event->n_streams, node->name); @@ -3721,7 +3737,7 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph GGML_LOG_DEBUG("Optimizing CUDA graph %p %d\n", cgraph->nodes, cgraph->n_nodes); ggml_cuda_stream_context & stream_context = cuda_ctx->stream_context(); - stream_context.clear(); + stream_context.reset(); std::unordered_map fan_out; std::unordered_map node_indices; @@ -3767,6 +3783,15 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph const int max_fan_out = 3; std::vector> concurrent_node_ranges; + + //save the original graph + std::vector original_graph; + original_graph.reserve(cgraph->n_nodes); + for (int i = 0; i < cgraph->n_nodes; ++i) { + original_graph.push_back(cgraph->nodes[i]); + } + cuda_ctx->stream_context().original_graph = std::move(original_graph); + for (const auto & [root_node, count] : fan_out) { if (count >= min_fan_out && count <= max_fan_out) { const int root_node_idx = node_indices[root_node]; @@ -3795,7 +3820,7 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph //find the join point const ggml_tensor * join_node = nullptr; - auto belongs_to_branch = [&](const ggml_tensor * node, std::vector & branch) -> bool { + const auto & belongs_to_branch = [&](const ggml_tensor * node, std::vector & branch) -> bool { for (const ggml_tensor * n : branch) { if (n == node) { return false; @@ -3878,8 +3903,9 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph continue; } - GGML_ASSERT(cuda_ctx->stream_context().find(root_node) == cuda_ctx->stream_context().end()); - cuda_ctx->stream_context().emplace(root_node, concurrent_event); + std::unordered_map & concurrent_events = cuda_ctx->stream_context().concurrent_events; + GGML_ASSERT(concurrent_events.find(root_node) == concurrent_events.end()); + concurrent_events.emplace(root_node, concurrent_event); GGML_LOG_DEBUG("Adding stream at node %s %p\n", root_node->name, root_node); concurrent_node_ranges.emplace_back(fork_node_idx, join_node_idx); @@ -3887,7 +3913,6 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph // example transformation: // [attn-norm, QMul, QNorm, 
QRope, KMul, KNorm, KRope, VMul, attn] -> // [attn-norm, QMul, KMul, VMul, QNorm, VNorm, QRope, KRope, attn] - // TODO: This breaks fusion within streams, how do we fix this? while (current_node_idx < join_node_idx) { std::vector & branch_nodes = nodes_per_branch[current_branch_idx]; From 4000f113d0cea50937ceec93f96c7ff4600dcfff Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Thu, 13 Nov 2025 16:23:25 +0800 Subject: [PATCH 04/10] ggml-cuda: fix bug w.r.t first stream launch --- ggml/src/ggml-cuda/ggml-cuda.cu | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 23986446a..bf0cae1db 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3199,6 +3199,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx [[maybe_unused]] int prev_i = 0; ggml_cuda_stream_context & stream_ctx = cuda_ctx->stream_context(); + if (stream_ctx.concurrent_events.size() > 0) { cgraph->nodes = const_cast(stream_ctx.original_graph.data()); } @@ -3232,7 +3233,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx if (stream_ctx.concurrent_events.find(prev_node) != stream_ctx.concurrent_events.end()) { concurrent_event = &stream_ctx.concurrent_events[prev_node]; - GGML_LOG_DEBUG("Launching %d streams at %s\n", concurrent_event->n_streams, node->name); + GGML_LOG_DEBUG("Launching %d streams at %s\n", concurrent_event->n_streams, prev_node->name); cudaStream_t main_stream = cuda_ctx->stream(); // this should be stream 0 GGML_ASSERT(cuda_ctx->curr_stream_no == 0); @@ -3244,6 +3245,8 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx } is_concurrent_event_active = true; + cuda_ctx->curr_stream_no = concurrent_event->stream_mapping[node]; + GGML_LOG_DEBUG("Setting stream no to %d for node %s\n", cuda_ctx->curr_stream_no, node->name); } } prev_i = i; From ac95faaa854bf9fa591ee919f19d75f5fbf63e8d Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Thu, 13 Nov 2025 20:49:49 +0800 Subject: [PATCH 05/10] ggml-cuda: format --- ggml/src/ggml-cuda/ggml-cuda.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index bf0cae1db..2f09e399f 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3033,7 +3033,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, #ifndef NDEBUG const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY); GGML_ASSERT(unary_ops.size() == num_unary); -#endif; +#endif //TODO: remove special case once ggml_can_fuse can handle empty nodes std::initializer_list topk_moe_ops = From 6562b77974d47fd9eb8cf3e09168c411f41a594c Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Fri, 14 Nov 2025 10:56:30 +0800 Subject: [PATCH 06/10] ggml-cuda: improve assert message --- ggml/src/ggml-cuda/ggml-cuda.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 2f09e399f..14dd97cb5 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3736,7 +3736,7 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph return; } - GGML_ASSERT(ggml_backend_cuda_get_device_count() == 1 && "cuda graph optimization is only supported on single GPU"); + GGML_ASSERT(ggml_backend_cuda_get_device_count() == 1 && "compute graph optimization is only 
supported on single GPU in the CUDA backend"); GGML_LOG_DEBUG("Optimizing CUDA graph %p %d\n", cgraph->nodes, cgraph->n_nodes); ggml_cuda_stream_context & stream_context = cuda_ctx->stream_context(); From 1129188286612cbf40a251f3510b0419b0a85700 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Fri, 14 Nov 2025 23:26:51 +0800 Subject: [PATCH 07/10] ggml-cuda: use lambda instead of duplicating code --- ggml/src/ggml-cuda/ggml-cuda.cu | 57 ++++++++++++--------------------- 1 file changed, 20 insertions(+), 37 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 14dd97cb5..15ecf084c 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3192,6 +3192,23 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx bool is_concurrent_event_active = false; ggml_cuda_concurrent_event * concurrent_event = nullptr; + const auto try_launch_concurrent_event = [&](const ggml_tensor * node) { + if (stream_ctx.concurrent_events.find(node) != stream_ctx.concurrent_events.end()) { + concurrent_event = &stream_ctx.concurrent_events[node]; + + GGML_LOG_DEBUG("Launching %d streams at %s\n", concurrent_event->n_streams, node->name); + + cudaStream_t main_stream = cuda_ctx->stream(); // this should be stream 0 + GGML_ASSERT(cuda_ctx->curr_stream_no == 0); + CUDA_CHECK(cudaEventRecord(concurrent_event->fork_event, main_stream)); + + for (int i = 1; i <= concurrent_event->n_streams; ++i) { + cudaStream_t stream = cuda_ctx->stream(cuda_ctx->device, i); + CUDA_CHECK(cudaStreamWaitEvent(stream, concurrent_event->fork_event)); + } + } + }; + while (!graph_evaluated_or_captured) { // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph. // With the use of CUDA graphs, the execution will be performed by the graph launch. 
@@ -3212,6 +3229,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx if (node == concurrent_event->join_node) { cuda_ctx->curr_stream_no = 0; for (int i = 1; i <= concurrent_event->n_streams; ++i) { + // Wait on join events of forked streams in the main stream CUDA_CHECK(cudaEventRecord(concurrent_event->per_stream_events[i - 1], cuda_ctx->stream(cuda_ctx->device, i))); CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx->stream(), concurrent_event->per_stream_events[i - 1])); @@ -3230,24 +3248,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx //the previous node was fused const ggml_tensor * prev_node = cgraph->nodes[i - 1]; - if (stream_ctx.concurrent_events.find(prev_node) != stream_ctx.concurrent_events.end()) { - concurrent_event = &stream_ctx.concurrent_events[prev_node]; - - GGML_LOG_DEBUG("Launching %d streams at %s\n", concurrent_event->n_streams, prev_node->name); - - cudaStream_t main_stream = cuda_ctx->stream(); // this should be stream 0 - GGML_ASSERT(cuda_ctx->curr_stream_no == 0); - CUDA_CHECK(cudaEventRecord(concurrent_event->fork_event, main_stream)); - - for (int i = 1; i <= concurrent_event->n_streams; ++i) { - cudaStream_t stream = cuda_ctx->stream(cuda_ctx->device, i); - CUDA_CHECK(cudaStreamWaitEvent(stream, concurrent_event->fork_event)); - } - - is_concurrent_event_active = true; - cuda_ctx->curr_stream_no = concurrent_event->stream_mapping[node]; - GGML_LOG_DEBUG("Setting stream no to %d for node %s\n", cuda_ctx->curr_stream_no, node->name); - } + try_launch_concurrent_event(prev_node); } prev_i = i; @@ -3565,25 +3566,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx GGML_ASSERT(ok); if (!is_concurrent_event_active) { - //const ggml_tensor * adjusted_node = node; - // the forking node may have been fused, e.g (RMS_NORM_MUL + MUL + ADD), - // we can safely use the previous node to check if it can be forked - if (stream_ctx.concurrent_events.find(node) != stream_ctx.concurrent_events.end()) { - concurrent_event = &stream_ctx.concurrent_events[node]; - - GGML_LOG_DEBUG("Launching %d streams at %s\n", concurrent_event->n_streams, node->name); - - cudaStream_t main_stream = cuda_ctx->stream(); // this should be stream 0 - GGML_ASSERT(cuda_ctx->curr_stream_no == 0); - CUDA_CHECK(cudaEventRecord(concurrent_event->fork_event, main_stream)); - - for (int i = 1; i <= concurrent_event->n_streams; ++i) { - cudaStream_t stream = cuda_ctx->stream(cuda_ctx->device, i); - CUDA_CHECK(cudaStreamWaitEvent(stream, concurrent_event->fork_event)); - } - - is_concurrent_event_active = true; - } + try_launch_concurrent_event(node); } } } From cfa1a02c54e45818cba6fe9e551aaa7eef6adbd1 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Fri, 14 Nov 2025 23:38:01 +0800 Subject: [PATCH 08/10] ggml-cuda: add some more comments --- ggml/src/ggml-cuda/ggml-cuda.cu | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 15ecf084c..7e2729e03 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3189,6 +3189,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx // flag used to determine whether it is an integrated_gpu const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated; + ggml_cuda_stream_context & stream_ctx = cuda_ctx->stream_context(); bool is_concurrent_event_active = false; ggml_cuda_concurrent_event * concurrent_event = nullptr; @@ 
-3215,9 +3216,8 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx if (!use_cuda_graph || cuda_graph_update_required) { [[maybe_unused]] int prev_i = 0; - ggml_cuda_stream_context & stream_ctx = cuda_ctx->stream_context(); - if (stream_ctx.concurrent_events.size() > 0) { + //Restore the original graph to enable fusion within the streams cgraph->nodes = const_cast(stream_ctx.original_graph.data()); } From d385f760941dcf7cfb3f992b2cb92c8f2aa75d21 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Fri, 14 Nov 2025 23:53:15 +0800 Subject: [PATCH 09/10] ggml-cuda: add more detailed comments about concurrency --- ggml/src/ggml-cuda/ggml-cuda.cu | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 7e2729e03..659ae3fa1 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3764,7 +3764,15 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph } } - //Target Q, K, V + // Target Q, K, V for concurrency + // this is a more general way to find nodes which can be candidates for concurrency (although it has not been tested for anything else): + // 1. find fan-out (fork) nodes where the same input is used at least N times (in QKV, it would be "attn-norm") + // 2. find the join node, where 2 or more of the outputs are required (in QKV, this would "KQ" or "flash-attn") + // 3. account for all branches from the fork to the join + // 4. To extend lifetimes of the tensors, we interleave the branches (see below for more details) + // 5. save the original cgraph and restore it in graph_compute, to enable fusion within streams + // See discussion: https://github.com/ggml-org/llama.cpp/pull/16991#issuecomment-3522620030 + const int min_fan_out = 3; const int max_fan_out = 3; From c9b06adc6a46986371315d013543a172127537d0 Mon Sep 17 00:00:00 2001 From: Aman Gupta Date: Sat, 15 Nov 2025 09:23:58 +0800 Subject: [PATCH 10/10] ggml-cuda: rename + remove unused var --- ggml/src/ggml-cuda/common.cuh | 10 ++++------ ggml/src/ggml-cuda/ggml-cuda.cu | 6 +++--- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index ca0e54cdd..6025a3cbc 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -966,9 +966,8 @@ struct ggml_cuda_graph { }; struct ggml_cuda_concurrent_event { - std::vector per_stream_events; + std::vector join_events; cudaEvent_t fork_event; - cudaEvent_t join_event; int n_streams = 0; std::unordered_map stream_mapping; @@ -978,14 +977,13 @@ struct ggml_cuda_concurrent_event { ggml_cuda_concurrent_event() = default; explicit ggml_cuda_concurrent_event(int n_streams) : n_streams(n_streams) { - per_stream_events.resize(n_streams); + join_events.resize(n_streams); - for (size_t i = 0; i < per_stream_events.size(); ++i) { - CUDA_CHECK(cudaEventCreateWithFlags(&per_stream_events[i], cudaEventDisableTiming)); + for (size_t i = 0; i < join_events.size(); ++i) { + CUDA_CHECK(cudaEventCreateWithFlags(&join_events[i], cudaEventDisableTiming)); } CUDA_CHECK(cudaEventCreateWithFlags(&fork_event, cudaEventDisableTiming)); - CUDA_CHECK(cudaEventCreateWithFlags(&join_event, cudaEventDisableTiming)); } }; diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 659ae3fa1..84bdfbd73 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -3229,10 +3229,10 @@ static void 
evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx if (node == concurrent_event->join_node) { cuda_ctx->curr_stream_no = 0; for (int i = 1; i <= concurrent_event->n_streams; ++i) { - // Wait on join events of forked streams in the main stream - CUDA_CHECK(cudaEventRecord(concurrent_event->per_stream_events[i - 1], + // Wait on join events of forked streams in the main stream + CUDA_CHECK(cudaEventRecord(concurrent_event->join_events[i - 1], cuda_ctx->stream(cuda_ctx->device, i))); - CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx->stream(), concurrent_event->per_stream_events[i - 1])); + CUDA_CHECK(cudaStreamWaitEvent(cuda_ctx->stream(), concurrent_event->join_events[i - 1])); } is_concurrent_event_active = false;
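
The interleaving step that ggml_backend_cuda_graph_optimize performs between the fork and the join (the "[attn-norm, QMul, QNorm, QRope, KMul, KNorm, KRope, VMul, attn] -> [attn-norm, QMul, KMul, VMul, ...]" comment in patch 01) is a round-robin merge of the per-branch node lists: issuing one node from each branch in turn keeps every branch's intermediate tensors alive across the whole fork/join region, so the graph allocator does not recycle their buffers while another stream may still use them. Below is a small self-contained illustration of that reordering, with strings standing in for ggml_tensor pointers; the node names are made up for the example and are not taken from the patches.

// Sketch only: round-robin interleaving of per-branch node lists.
#include <iostream>
#include <string>
#include <vector>

int main() {
    // one list of nodes per branch, in issue order (Q, K and V branches)
    std::vector<std::vector<std::string>> branches = {
        { "Q-mul", "Q-norm", "Q-rope" },
        { "K-mul", "K-norm", "K-rope" },
        { "V-mul" },
    };

    // take one node from each branch in turn, so the first node of every branch
    // is issued before any branch's second node
    std::vector<std::string> interleaved;
    std::vector<size_t> next(branches.size(), 0);
    size_t remaining = 0;
    for (const auto & b : branches) {
        remaining += b.size();
    }

    size_t branch = 0;
    while (remaining > 0) {
        if (next[branch] < branches[branch].size()) {
            interleaved.push_back(branches[branch][next[branch]++]);
            --remaining;
        }
        branch = (branch + 1) % branches.size();
    }

    // prints: Q-mul K-mul V-mul Q-norm K-norm Q-rope K-rope
    for (const auto & name : interleaved) {
        std::cout << name << ' ';
    }
    std::cout << '\n';
}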