@@ -2998,7 +2998,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
 #ifndef NDEBUG
     const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY);
     GGML_ASSERT(unary_ops.size() == num_unary);
-#endif
+#endif;
 
     // TODO: remove special case once ggml_can_fuse can handle empty nodes
     std::initializer_list<enum ggml_op> topk_moe_ops =
@@ -3139,6 +3139,29 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
     return false;
 }
 
+static void reorder_nodes_for_stream_fusion(ggml_backend_cuda_context * cuda_ctx, ggml_tensor ** data, int n_nodes) {
+    for (const auto & [fork_node, event] : cuda_ctx->stream_context()) {
+
+        const int fork_node_idx = event.fork_node_idx;
+        const int join_node_idx = event.join_node_idx;
+
+        for (int i = fork_node_idx + 1, k = 0; i <= join_node_idx - 1; i++, k++) {
+            data[i] = const_cast<ggml_tensor *>(event.nodes[k]);
+        }
+    }
+    for (const auto & [fork_node, event] : cuda_ctx->stream_context()) {
+        bool found = false;
+        for (int i = 0; i < n_nodes; ++i) {
+            if (data[i] == fork_node) {
+                found = true;
+                break;
+            }
+        }
+
+        GGML_ASSERT(found);
+    }
+}
+
 static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
     bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
     // flag used to determine whether it is an integrated_gpu
@@ -3153,6 +3176,22 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
     if (!use_cuda_graph || cuda_graph_update_required) {
         [[maybe_unused]] int prev_i = 0;
 
+        ggml_cuda_stream_context & stream_ctx = cuda_ctx->stream_context();
+        GGML_LOG_DEBUG("Stream ctx size: %d\n", stream_ctx.size());
+        ggml_tensor ** orig_data = cgraph->nodes;
+        std::vector<ggml_tensor *> orig_graph;
+        orig_graph.resize(cgraph->n_nodes);
+        if (cuda_graph_update_required) {
+            // we are capturing so we can actually re-order
+            for (int i = 0; i < cgraph->n_nodes; ++i) {
+                orig_graph[i] = cgraph->nodes[i];
+            }
+            reorder_nodes_for_stream_fusion(cuda_ctx, orig_graph.data(), cgraph->n_nodes);
+            GGML_LOG_DEBUG("Reordered CUDA graph %p %d\n", cgraph->nodes, cgraph->n_nodes);
+            cgraph->nodes = orig_graph.data();
+        }
+
+
         for (int i = 0; i < cgraph->n_nodes; i++) {
             ggml_tensor * node = cgraph->nodes[i];
             if (is_concurrent_event_active) {
@@ -3188,7 +3227,6 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
                 continue;
             }
 
-            ggml_cuda_stream_context & stream_ctx = cuda_ctx->stream_context();
 
             // start of fusion operations
             static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
@@ -3497,31 +3535,33 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
             GGML_ASSERT(ok);
 
             if (!is_concurrent_event_active) {
-                const ggml_tensor * adjusted_node = node;
+                // const ggml_tensor * adjusted_node = node;
                 // the forking node may have been fused, e.g (RMS_NORM_MUL + MUL + ADD),
                 // we can safely use the previous node to check if it can be forked
-                if (i - prev_i > 1) {
-                    adjusted_node = cgraph->nodes[i - 1];
-                }
-                if (stream_ctx.find(adjusted_node) != stream_ctx.end()) {
-                    concurrent_event = &stream_ctx[adjusted_node];
+                for (int k = prev_i + 1; k < i; ++k) {
+                    const ggml_tensor * adjusted_node = cgraph->nodes[k];
+                    if (stream_ctx.find(adjusted_node) != stream_ctx.end()) {
+                        concurrent_event = &stream_ctx[adjusted_node];
 
-                    GGML_LOG_DEBUG("Launching %d streams at %s\n", concurrent_event->n_streams, node->name);
+                        GGML_LOG_DEBUG("Launching %d streams at %s\n", concurrent_event->n_streams, node->name);
 
-                    cudaStream_t main_stream = cuda_ctx->stream(); // this should be stream 0
-                    GGML_ASSERT(cuda_ctx->curr_stream_no == 0);
-                    CUDA_CHECK(cudaEventRecord(concurrent_event->fork_event, main_stream));
+                        cudaStream_t main_stream = cuda_ctx->stream(); // this should be stream 0
+                        GGML_ASSERT(cuda_ctx->curr_stream_no == 0);
+                        CUDA_CHECK(cudaEventRecord(concurrent_event->fork_event, main_stream));
 
-                    for (int i = 1; i <= concurrent_event->n_streams; ++i) {
-                        cudaStream_t stream = cuda_ctx->stream(cuda_ctx->device, i);
-                        CUDA_CHECK(cudaStreamWaitEvent(stream, concurrent_event->fork_event));
-                    }
+                        for (int i = 1; i <= concurrent_event->n_streams; ++i) {
+                            cudaStream_t stream = cuda_ctx->stream(cuda_ctx->device, i);
+                            CUDA_CHECK(cudaStreamWaitEvent(stream, concurrent_event->fork_event));
+                        }
 
-                    is_concurrent_event_active = true;
+                        is_concurrent_event_active = true;
+                    }
+
                 }
-            }
+            }
             prev_i = i;
         }
+        cgraph->nodes = orig_data;
     }
 
 #ifdef USE_CUDA_GRAPH
@@ -3673,7 +3713,7 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph
     }
 
     GGML_ASSERT(ggml_backend_cuda_get_device_count() == 1 && "cuda graph optimization is only supported on single GPU");
-    GGML_LOG_DEBUG("Optimizing CUDA graph\n");
+    GGML_LOG_DEBUG("Optimizing CUDA graph %p %d\n", cgraph->nodes, cgraph->n_nodes);
 
     ggml_cuda_stream_context & stream_context = cuda_ctx->stream_context();
     stream_context.clear();
@@ -3833,8 +3873,17 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph
             continue;
         }
 
+        for (const auto & branch : nodes_per_branch) {
+            for (const ggml_tensor * n : branch) {
+                concurrent_event.nodes.push_back(n);
+            }
+        }
+        concurrent_event.fork_node_idx = fork_node_idx;
+        concurrent_event.join_node_idx = join_node_idx;
+
         GGML_ASSERT(cuda_ctx->stream_context().find(root_node) == cuda_ctx->stream_context().end());
         cuda_ctx->stream_context().emplace(root_node, concurrent_event);
+        GGML_LOG_DEBUG("Adding stream at node %s %p\n", root_node->name, root_node);
         concurrent_node_ranges.emplace_back(fork_node_idx, join_node_idx);
 
         // interleave tensors to extend lifetimes so that ggml graph doesn't recycle them
@@ -3850,12 +3899,7 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph
             has_node |= branch_node.size() > 0;
         }
 
-        if (!has_node) {
-            printf("Skipping %s because it is empty %s\n", cgraph->nodes[current_node_idx]->name,
-                   ggml_op_name(cgraph->nodes[current_node_idx]->op));
-            current_node_idx++;
-            continue;
-        }
+        GGML_ASSERT(has_node);
 
         if (branch_nodes.empty()) {
             current_branch_idx = (current_branch_idx + 1) % n_branches;
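
For context on the hunks above: at the fork node the main stream records `fork_event` and each extra stream waits on it before running its branch; the corresponding join (not shown in this excerpt) has the main stream wait for the branches before continuing. Below is a minimal, self-contained sketch of that fork/join pattern using only the plain CUDA runtime API. It is not part of the patch; the kernel, buffer names, and the explicit per-branch join events are illustrative assumptions, not ggml code.

// fork_join_sketch.cu -- standalone illustration of fork/join stream concurrency
#include <cuda_runtime.h>
#include <cstdio>
#include <vector>

#define CUDA_CHECK(call)                                                      \
    do {                                                                      \
        cudaError_t err_ = (call);                                            \
        if (err_ != cudaSuccess) {                                            \
            fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(err_));    \
            return 1;                                                         \
        }                                                                     \
    } while (0)

// stand-in for one branch of independent work (hypothetical kernel)
__global__ void branch_kernel(float * dst, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        dst[i] += 1.0f;
    }
}

int main() {
    const int n_streams = 2;        // number of concurrent branches (illustrative)
    const int n         = 1 << 20;  // elements per branch

    cudaStream_t main_stream;
    CUDA_CHECK(cudaStreamCreate(&main_stream));

    std::vector<cudaStream_t> streams(n_streams);
    std::vector<cudaEvent_t>  join_events(n_streams);
    cudaEvent_t fork_event;
    CUDA_CHECK(cudaEventCreateWithFlags(&fork_event, cudaEventDisableTiming));
    for (int s = 0; s < n_streams; ++s) {
        CUDA_CHECK(cudaStreamCreate(&streams[s]));
        CUDA_CHECK(cudaEventCreateWithFlags(&join_events[s], cudaEventDisableTiming));
    }

    float * data;
    CUDA_CHECK(cudaMalloc(&data, (size_t) n_streams * n * sizeof(float)));
    CUDA_CHECK(cudaMemset(data, 0, (size_t) n_streams * n * sizeof(float)));

    // fork: every worker stream waits until the main stream reaches this point
    CUDA_CHECK(cudaEventRecord(fork_event, main_stream));
    for (int s = 0; s < n_streams; ++s) {
        CUDA_CHECK(cudaStreamWaitEvent(streams[s], fork_event, 0));
        branch_kernel<<<(n + 255) / 256, 256, 0, streams[s]>>>(data + (size_t) s * n, n);
        CUDA_CHECK(cudaEventRecord(join_events[s], streams[s]));
    }

    // join: the main stream waits for every branch before continuing
    for (int s = 0; s < n_streams; ++s) {
        CUDA_CHECK(cudaStreamWaitEvent(main_stream, join_events[s], 0));
    }
    CUDA_CHECK(cudaStreamSynchronize(main_stream));
    printf("all %d branches joined\n", n_streams);

    CUDA_CHECK(cudaFree(data));
    return 0;
}

The point of recording a single fork event rather than synchronizing the whole device is that only the ordering between the main stream and the worker streams is enforced; unrelated work already queued elsewhere is left untouched, which is what makes the reordering in the patch safe to capture into a CUDA graph.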