@@ -2998,7 +2998,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
 #ifndef NDEBUG
     const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY);
     GGML_ASSERT(unary_ops.size() == num_unary);
-#endif;
+#endif
 
     // TODO: remove special case once ggml_can_fuse can handle empty nodes
     std::initializer_list<enum ggml_op> topk_moe_ops =
@@ -3139,29 +3139,6 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
     return false;
 }
 
-static void reorder_nodes_for_stream_fusion(ggml_backend_cuda_context * cuda_ctx, ggml_tensor ** data, int n_nodes) {
-    for (const auto & [fork_node, event] : cuda_ctx->stream_context()) {
-
-        const int fork_node_idx = event.fork_node_idx;
-        const int join_node_idx = event.join_node_idx;
-
-        for (int i = fork_node_idx + 1, k = 0; i <= join_node_idx - 1; i++, k++) {
-            data[i] = const_cast<ggml_tensor *>(event.nodes[k]);
-        }
-    }
-    for (const auto & [fork_node, event] : cuda_ctx->stream_context()) {
-        bool found = false;
-        for (int i = 0; i < n_nodes; ++i) {
-            if (data[i] == fork_node) {
-                found = true;
-                break;
-            }
-        }
-
-        GGML_ASSERT(found);
-    }
-}
-
 static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
     bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
     // flag used to determine whether it is an integrated_gpu
@@ -3176,22 +3153,6 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
     if (!use_cuda_graph || cuda_graph_update_required) {
         [[maybe_unused]] int prev_i = 0;
 
-        ggml_cuda_stream_context & stream_ctx = cuda_ctx->stream_context();
-        GGML_LOG_DEBUG("Stream ctx size: %d\n", stream_ctx.size());
-        ggml_tensor ** orig_data = cgraph->nodes;
-        std::vector<ggml_tensor *> orig_graph;
-        orig_graph.resize(cgraph->n_nodes);
-        if (cuda_graph_update_required) {
-            // we are capturing so we can actually re-order
-            for (int i = 0; i < cgraph->n_nodes; ++i) {
-                orig_graph[i] = cgraph->nodes[i];
-            }
-            reorder_nodes_for_stream_fusion(cuda_ctx, orig_graph.data(), cgraph->n_nodes);
-            GGML_LOG_DEBUG("Reordered CUDA graph %p %d\n", cgraph->nodes, cgraph->n_nodes);
-            cgraph->nodes = orig_graph.data();
-        }
-
-
         for (int i = 0; i < cgraph->n_nodes; i++) {
             ggml_tensor * node = cgraph->nodes[i];
             if (is_concurrent_event_active) {
@@ -3227,6 +3188,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
                 continue;
             }
 
+            ggml_cuda_stream_context & stream_ctx = cuda_ctx->stream_context();
 
             // start of fusion operations
             static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
@@ -3535,33 +3497,31 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
             GGML_ASSERT(ok);
 
             if (!is_concurrent_event_active) {
-                // const ggml_tensor * adjusted_node = node;
+                const ggml_tensor * adjusted_node = node;
                 // the forking node may have been fused, e.g (RMS_NORM_MUL + MUL + ADD),
                 // we can safely use the previous node to check if it can be forked
-                for (int k = prev_i + 1; k < i; ++k) {
-                    const ggml_tensor * adjusted_node = cgraph->nodes[k];
-                    if (stream_ctx.find(adjusted_node) != stream_ctx.end()) {
-                        concurrent_event = &stream_ctx[adjusted_node];
+                if (i - prev_i > 1) {
+                    adjusted_node = cgraph->nodes[i - 1];
+                }
+                if (stream_ctx.find(adjusted_node) != stream_ctx.end()) {
+                    concurrent_event = &stream_ctx[adjusted_node];
 
-                        GGML_LOG_DEBUG("Launching %d streams at %s\n", concurrent_event->n_streams, node->name);
+                    GGML_LOG_DEBUG("Launching %d streams at %s\n", concurrent_event->n_streams, node->name);
 
-                        cudaStream_t main_stream = cuda_ctx->stream(); // this should be stream 0
-                        GGML_ASSERT(cuda_ctx->curr_stream_no == 0);
-                        CUDA_CHECK(cudaEventRecord(concurrent_event->fork_event, main_stream));
+                    cudaStream_t main_stream = cuda_ctx->stream(); // this should be stream 0
+                    GGML_ASSERT(cuda_ctx->curr_stream_no == 0);
+                    CUDA_CHECK(cudaEventRecord(concurrent_event->fork_event, main_stream));
 
-                        for (int i = 1; i <= concurrent_event->n_streams; ++i) {
-                            cudaStream_t stream = cuda_ctx->stream(cuda_ctx->device, i);
-                            CUDA_CHECK(cudaStreamWaitEvent(stream, concurrent_event->fork_event));
-                        }
-
-                        is_concurrent_event_active = true;
+                    for (int i = 1; i <= concurrent_event->n_streams; ++i) {
+                        cudaStream_t stream = cuda_ctx->stream(cuda_ctx->device, i);
+                        CUDA_CHECK(cudaStreamWaitEvent(stream, concurrent_event->fork_event));
                     }
-
+
+                    is_concurrent_event_active = true;
                 }
-            }
+            }
             prev_i = i;
         }
-        cgraph->nodes = orig_data;
     }
 
 #ifdef USE_CUDA_GRAPH
@@ -3713,7 +3673,7 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph
     }
 
     GGML_ASSERT(ggml_backend_cuda_get_device_count() == 1 && "cuda graph optimization is only supported on single GPU");
-    GGML_LOG_DEBUG("Optimizing CUDA graph %p %d \n", cgraph->nodes, cgraph->n_nodes);
+    GGML_LOG_DEBUG("Optimizing CUDA graph\n");
 
     ggml_cuda_stream_context & stream_context = cuda_ctx->stream_context();
     stream_context.clear();
@@ -3873,17 +3833,8 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph
             continue;
         }
 
-        for (const auto & branch: nodes_per_branch) {
-            for (const ggml_tensor * n: branch) {
-                concurrent_event.nodes.push_back(n);
-            }
-        }
-        concurrent_event.fork_node_idx = fork_node_idx;
-        concurrent_event.join_node_idx = join_node_idx;
-
         GGML_ASSERT(cuda_ctx->stream_context().find(root_node) == cuda_ctx->stream_context().end());
         cuda_ctx->stream_context().emplace(root_node, concurrent_event);
-        GGML_LOG_DEBUG("Adding stream at node %s %p\n", root_node->name, root_node);
         concurrent_node_ranges.emplace_back(fork_node_idx, join_node_idx);
 
         // interleave tensors to extend lifetimes so that ggml graph doesn't recycle them
@@ -3899,7 +3850,12 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph
                 has_node |= branch_node.size() > 0;
             }
 
-            GGML_ASSERT(has_node);
+            if (!has_node) {
+                printf("Skipping %s because it is empty %s\n", cgraph->nodes[current_node_idx]->name,
+                       ggml_op_name(cgraph->nodes[current_node_idx]->op));
+                current_node_idx++;
+                continue;
+            }
 
             if (branch_nodes.empty()) {
                 current_branch_idx = (current_branch_idx + 1) % n_branches;
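
For orientation, here is a minimal standalone sketch (not part of this diff or of llama.cpp) of the fork/join pattern that the `fork_event` logic above relies on: the main stream records an event, each worker stream waits on it before running its branch, and the main stream then waits on per-branch events before continuing. The worker-stream count, `branch_kernel`, and the explicit join events are illustrative assumptions.

```cpp
// fork/join across CUDA streams via events; compile with: nvcc -o fork_join fork_join.cu
#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>

#define CUDA_CHECK(call)                                                           \
    do {                                                                           \
        cudaError_t err_ = (call);                                                 \
        if (err_ != cudaSuccess) {                                                 \
            fprintf(stderr, "CUDA error %s at %s:%d\n",                            \
                    cudaGetErrorString(err_), __FILE__, __LINE__);                 \
            exit(1);                                                               \
        }                                                                          \
    } while (0)

// dummy per-branch work (illustrative only)
__global__ void branch_kernel(float * data, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        data[i] += 1.0f;
    }
}

int main() {
    const int n_streams = 2;   // analogous to concurrent_event->n_streams
    const int n = 1 << 20;

    float *      bufs[n_streams];
    cudaStream_t main_stream, workers[n_streams];
    cudaEvent_t  fork_event, join_events[n_streams];

    CUDA_CHECK(cudaStreamCreate(&main_stream));
    CUDA_CHECK(cudaEventCreateWithFlags(&fork_event, cudaEventDisableTiming));
    for (int s = 0; s < n_streams; ++s) {
        CUDA_CHECK(cudaStreamCreate(&workers[s]));
        CUDA_CHECK(cudaEventCreateWithFlags(&join_events[s], cudaEventDisableTiming));
        CUDA_CHECK(cudaMalloc(&bufs[s], n * sizeof(float)));
    }

    // fork: work queued on the workers is ordered after the main stream's prior work
    CUDA_CHECK(cudaEventRecord(fork_event, main_stream));
    for (int s = 0; s < n_streams; ++s) {
        CUDA_CHECK(cudaStreamWaitEvent(workers[s], fork_event, 0));
        CUDA_CHECK(cudaMemsetAsync(bufs[s], 0, n * sizeof(float), workers[s]));
        branch_kernel<<<(n + 255) / 256, 256, 0, workers[s]>>>(bufs[s], n);
        CUDA_CHECK(cudaEventRecord(join_events[s], workers[s]));
    }

    // join: the main stream resumes only after every branch has finished
    for (int s = 0; s < n_streams; ++s) {
        CUDA_CHECK(cudaStreamWaitEvent(main_stream, join_events[s], 0));
    }
    CUDA_CHECK(cudaStreamSynchronize(main_stream));
    printf("fork/join on %d worker streams completed\n", n_streams);

    for (int s = 0; s < n_streams; ++s) {
        CUDA_CHECK(cudaFree(bufs[s]));
        CUDA_CHECK(cudaStreamDestroy(workers[s]));
        CUDA_CHECK(cudaEventDestroy(join_events[s]));
    }
    CUDA_CHECK(cudaStreamDestroy(main_stream));
    CUDA_CHECK(cudaEventDestroy(fork_event));
    return 0;
}
```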