@@ -3192,6 +3192,23 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
31923192 bool is_concurrent_event_active = false ;
31933193 ggml_cuda_concurrent_event * concurrent_event = nullptr ;
31943194
3195+ const auto try_launch_concurrent_event = [&](const ggml_tensor * node) {
3196+ if (stream_ctx.concurrent_events .find (node) != stream_ctx.concurrent_events .end ()) {
3197+ concurrent_event = &stream_ctx.concurrent_events [node];
3198+
3199+ GGML_LOG_DEBUG (" Launching %d streams at %s\n " , concurrent_event->n_streams , node->name );
3200+
3201+ cudaStream_t main_stream = cuda_ctx->stream (); // this should be stream 0
3202+ GGML_ASSERT (cuda_ctx->curr_stream_no == 0 );
3203+ CUDA_CHECK (cudaEventRecord (concurrent_event->fork_event , main_stream));
3204+
3205+ for (int i = 1 ; i <= concurrent_event->n_streams ; ++i) {
3206+ cudaStream_t stream = cuda_ctx->stream (cuda_ctx->device , i);
3207+ CUDA_CHECK (cudaStreamWaitEvent (stream, concurrent_event->fork_event ));
3208+ }
3209+ }
3210+ };
3211+
31953212 while (!graph_evaluated_or_captured) {
31963213 // Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
31973214 // With the use of CUDA graphs, the execution will be performed by the graph launch.
@@ -3212,6 +3229,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
32123229 if (node == concurrent_event->join_node ) {
32133230 cuda_ctx->curr_stream_no = 0 ;
32143231 for (int i = 1 ; i <= concurrent_event->n_streams ; ++i) {
3232+ // Wait on join events of forked streams in the main stream
32153233 CUDA_CHECK (cudaEventRecord (concurrent_event->per_stream_events [i - 1 ],
32163234 cuda_ctx->stream (cuda_ctx->device , i)));
32173235 CUDA_CHECK (cudaStreamWaitEvent (cuda_ctx->stream (), concurrent_event->per_stream_events [i - 1 ]));
@@ -3230,24 +3248,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
32303248
32313249 // the previous node was fused
32323250 const ggml_tensor * prev_node = cgraph->nodes [i - 1 ];
3233- if (stream_ctx.concurrent_events .find (prev_node) != stream_ctx.concurrent_events .end ()) {
3234- concurrent_event = &stream_ctx.concurrent_events [prev_node];
3235-
3236- GGML_LOG_DEBUG (" Launching %d streams at %s\n " , concurrent_event->n_streams , prev_node->name );
3237-
3238- cudaStream_t main_stream = cuda_ctx->stream (); // this should be stream 0
3239- GGML_ASSERT (cuda_ctx->curr_stream_no == 0 );
3240- CUDA_CHECK (cudaEventRecord (concurrent_event->fork_event , main_stream));
3241-
3242- for (int i = 1 ; i <= concurrent_event->n_streams ; ++i) {
3243- cudaStream_t stream = cuda_ctx->stream (cuda_ctx->device , i);
3244- CUDA_CHECK (cudaStreamWaitEvent (stream, concurrent_event->fork_event ));
3245- }
3246-
3247- is_concurrent_event_active = true ;
3248- cuda_ctx->curr_stream_no = concurrent_event->stream_mapping [node];
3249- GGML_LOG_DEBUG (" Setting stream no to %d for node %s\n " , cuda_ctx->curr_stream_no , node->name );
3250- }
3251+ try_launch_concurrent_event (prev_node);
32513252 }
32523253 prev_i = i;
32533254
@@ -3565,25 +3566,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
35653566 GGML_ASSERT (ok);
35663567
35673568 if (!is_concurrent_event_active) {
3568- // const ggml_tensor * adjusted_node = node;
3569- // the forking node may have been fused, e.g (RMS_NORM_MUL + MUL + ADD),
3570- // we can safely use the previous node to check if it can be forked
3571- if (stream_ctx.concurrent_events .find (node) != stream_ctx.concurrent_events .end ()) {
3572- concurrent_event = &stream_ctx.concurrent_events [node];
3573-
3574- GGML_LOG_DEBUG (" Launching %d streams at %s\n " , concurrent_event->n_streams , node->name );
3575-
3576- cudaStream_t main_stream = cuda_ctx->stream (); // this should be stream 0
3577- GGML_ASSERT (cuda_ctx->curr_stream_no == 0 );
3578- CUDA_CHECK (cudaEventRecord (concurrent_event->fork_event , main_stream));
3579-
3580- for (int i = 1 ; i <= concurrent_event->n_streams ; ++i) {
3581- cudaStream_t stream = cuda_ctx->stream (cuda_ctx->device , i);
3582- CUDA_CHECK (cudaStreamWaitEvent (stream, concurrent_event->fork_event ));
3583- }
3584-
3585- is_concurrent_event_active = true ;
3586- }
3569+ try_launch_concurrent_event (node);
35873570 }
35883571 }
35893572 }
0 commit comments