@@ -3153,6 +3153,9 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
31533153 [[maybe_unused]] int prev_i = 0 ;
31543154
31553155 ggml_cuda_stream_context & stream_ctx = cuda_ctx->stream_context ();
3156+ if (stream_ctx.concurrent_events .size () > 0 ) {
3157+ cgraph->nodes = const_cast <ggml_tensor **>(stream_ctx.original_graph .data ());
3158+ }
31563159
31573160 for (int i = 0 ; i < cgraph->n_nodes ; i++) {
31583161 ggml_tensor * node = cgraph->nodes [i];
@@ -3176,6 +3179,26 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
31763179 cuda_ctx->curr_stream_no = stream_mapping;
31773180 GGML_LOG_DEBUG (" Setting stream no to %d for node %s\n " , stream_mapping, node->name );
31783181 }
3182+ } else if (i - prev_i > 1 ) {
3183+
3184+ // the previous node was fused
3185+ const ggml_tensor * prev_node = cgraph->nodes [i - 1 ];
3186+ if (stream_ctx.concurrent_events .find (prev_node) != stream_ctx.concurrent_events .end ()) {
3187+ concurrent_event = &stream_ctx.concurrent_events [prev_node];
3188+
3189+ GGML_LOG_DEBUG (" Launching %d streams at %s\n " , concurrent_event->n_streams , node->name );
3190+
3191+ cudaStream_t main_stream = cuda_ctx->stream (); // this should be stream 0
3192+ GGML_ASSERT (cuda_ctx->curr_stream_no == 0 );
3193+ CUDA_CHECK (cudaEventRecord (concurrent_event->fork_event , main_stream));
3194+
3195+ for (int i = 1 ; i <= concurrent_event->n_streams ; ++i) {
3196+ cudaStream_t stream = cuda_ctx->stream (cuda_ctx->device , i);
3197+ CUDA_CHECK (cudaStreamWaitEvent (stream, concurrent_event->fork_event ));
3198+ }
3199+
3200+ is_concurrent_event_active = true ;
3201+ }
31793202 }
31803203 prev_i = i;
31813204
@@ -3446,23 +3469,16 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
34463469 continue ;
34473470 }
34483471
3449- // TODO: fix this
3450- static const bool graph_opt = (getenv (" GGML_CUDA_GRAPH_OPT" ) != nullptr ) && atoi (getenv (" GGML_CUDA_GRAPH_OPT" )) == 1 ;
3451-
34523472 if (ggml_cuda_can_fuse (cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ADD}, {})) {
3453- if (strncmp (cgraph->nodes [i+2 ]->name , " attn_norm" , strlen (" attn_norm" )) != 0 || !graph_opt) {
3454- ggml_cuda_op_rms_norm_fused_add (*cuda_ctx, node, cgraph->nodes [i+1 ], cgraph->nodes [i+2 ]);
3455- i += 2 ;
3456- continue ;
3457- }
3473+ ggml_cuda_op_rms_norm_fused_add (*cuda_ctx, node, cgraph->nodes [i+1 ], cgraph->nodes [i+2 ]);
3474+ i += 2 ;
3475+ continue ;
34583476 }
34593477
34603478 if (ggml_cuda_can_fuse (cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL}, {})) {
3461- if (strncmp (cgraph->nodes [i+1 ]->name , " attn_norm" , strlen (" attn_norm" )) != 0 || !graph_opt) {
3462- ggml_cuda_op_rms_norm_fused (*cuda_ctx, node, cgraph->nodes [i+1 ]);
3463- i++;
3464- continue ;
3465- }
3479+ ggml_cuda_op_rms_norm_fused (*cuda_ctx, node, cgraph->nodes [i+1 ]);
3480+ i++;
3481+ continue ;
34663482 }
34673483
34683484 if (ggml_cuda_can_fuse (cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) {
@@ -3494,8 +3510,8 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
34943510 // const ggml_tensor * adjusted_node = node;
34953511 // the forking node may have been fused, e.g. (RMS_NORM_MUL + MUL + ADD),
34963512 // we can safely use the previous node to check if it can be forked
3497- if (stream_ctx.find (node) != stream_ctx.end ()) {
3498- concurrent_event = &stream_ctx[node];
3513+ if (stream_ctx.concurrent_events . find (node) != stream_ctx. concurrent_events .end ()) {
3514+ concurrent_event = &stream_ctx. concurrent_events [node];
34993515
35003516 GGML_LOG_DEBUG (" Launching %d streams at %s\n " , concurrent_event->n_streams , node->name );
35013517
@@ -3666,7 +3682,7 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph
36663682 GGML_LOG_DEBUG (" Optimizing CUDA graph %p %d\n " , cgraph->nodes , cgraph->n_nodes );
36673683
36683684 ggml_cuda_stream_context & stream_context = cuda_ctx->stream_context ();
3669- stream_context.clear ();
3685+ stream_context.reset ();
36703686
36713687 std::unordered_map<const ggml_tensor *, int > fan_out;
36723688 std::unordered_map<const ggml_tensor *, int > node_indices;
@@ -3712,6 +3728,15 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph
37123728 const int max_fan_out = 3 ;
37133729
37143730 std::vector<std::pair<int , int >> concurrent_node_ranges;
3731+
3732+ // save the original graph
3733+ std::vector<const ggml_tensor *> original_graph;
3734+ original_graph.reserve (cgraph->n_nodes );
3735+ for (int i = 0 ; i < cgraph->n_nodes ; ++i) {
3736+ original_graph.push_back (cgraph->nodes [i]);
3737+ }
3738+ cuda_ctx->stream_context ().original_graph = std::move (original_graph);
3739+
37153740 for (const auto & [root_node, count] : fan_out) {
37163741 if (count >= min_fan_out && count <= max_fan_out) {
37173742 const int root_node_idx = node_indices[root_node];
@@ -3740,7 +3765,7 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph
37403765 // find the join point
37413766 const ggml_tensor * join_node = nullptr ;
37423767
3743- auto belongs_to_branch = [&](const ggml_tensor * node, std::vector<const ggml_tensor *> & branch) -> bool {
3768+ const auto & belongs_to_branch = [&](const ggml_tensor * node, std::vector<const ggml_tensor *> & branch) -> bool {
37443769 for (const ggml_tensor * n : branch) {
37453770 if (n == node) {
37463771 return false ;
@@ -3823,16 +3848,16 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph
38233848 continue ;
38243849 }
38253850
3826- GGML_ASSERT (cuda_ctx->stream_context ().find (root_node) == cuda_ctx->stream_context ().end ());
3827- cuda_ctx->stream_context ().emplace (root_node, concurrent_event);
3851+ std::unordered_map<const ggml_tensor *, ggml_cuda_concurrent_event> & concurrent_events = cuda_ctx->stream_context ().concurrent_events ;
3852+ GGML_ASSERT (concurrent_events.find (root_node) == concurrent_events.end ());
3853+ concurrent_events.emplace (root_node, concurrent_event);
38283854 GGML_LOG_DEBUG (" Adding stream at node %s %p\n " , root_node->name , root_node);
38293855 concurrent_node_ranges.emplace_back (fork_node_idx, join_node_idx);
38303856
38313857 // interleave tensors to extend lifetimes so that ggml graph doesn't recycle them
38323858 // example transformation:
38333859 // [attn-norm, QMul, QNorm, QRope, KMul, KNorm, KRope, VMul, attn] ->
38343860 // [attn-norm, QMul, KMul, VMul, QNorm, KNorm, QRope, KRope, attn]
3835- // TODO: This breaks fusion within streams, how do we fix this?
38363861 while (current_node_idx < join_node_idx) {
38373862 std::vector<const ggml_tensor *> & branch_nodes = nodes_per_branch[current_branch_idx];
38383863
0 commit comments